diff --git a/example-service/Cargo.toml b/example-service/Cargo.toml index 47bfa16..707d88a 100644 --- a/example-service/Cargo.toml +++ b/example-service/Cargo.toml @@ -18,6 +18,7 @@ futures = "0.3" serde = { version = "1.0" } tarpc = { version = "0.18", path = "../tarpc", features = ["full"] } tokio = "0.2" +tokio-serde = { version = "0.6", features = ["json"] } env_logger = "0.6" [lib] diff --git a/example-service/src/client.rs b/example-service/src/client.rs index a6b8b3c..f9c37bb 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -7,6 +7,7 @@ use clap::{App, Arg}; use std::{io, net::SocketAddr}; use tarpc::{client, context}; +use tokio_serde::formats::Json; #[tokio::main] async fn main() -> io::Result<()> { @@ -40,7 +41,7 @@ async fn main() -> io::Result<()> { let name = flags.value_of("name").unwrap().into(); - let transport = tarpc::json_transport::connect(server_addr).await?; + let transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default()).await?; // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 6543b1e..ac6d358 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -18,6 +18,7 @@ use tarpc::{ context, server::{self, Channel, Handler}, }; +use tokio_serde::formats::Json; // This is the type that implements the generated World trait. It is the business logic // and is used to start the server. @@ -66,7 +67,7 @@ async fn main() -> io::Result<()> { // JSON transport is provided by the json_transport tarpc module. It makes it easy // to start up a serde-powered json serialization strategy over TCP. - tarpc::json_transport::listen(&server_addr) + tarpc::serde_transport::tcp::listen(&server_addr, Json::default) .await? // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 526db37..6c52863 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -17,10 +17,10 @@ default = [] serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"] tokio1 = ["tokio"] -bincode-transport = ["async-bincode", "futures-legacy", "futures-test", "futures/compat", "tokio-io", "tokio-tcp"] -json-transport = ["tokio/net", "tokio/stream", "tokio-serde/json", "tokio-util/codec"] +serde-transport = ["tokio-serde", "tokio-util/codec"] +tcp = ["tokio/net", "tokio/stream"] -full = ["serde1", "tokio1", "bincode-transport", "json-transport"] +full = ["serde1", "tokio1", "serde-transport", "tcp"] [badges] travis-ci = { repository = "google/tarpc" } @@ -38,13 +38,7 @@ tokio = { optional = true, version = "0.2", features = ["time"] } tokio-util = { optional = true, version = "0.2" } tarpc-plugins = { path = "../plugins" } -async-bincode = { optional = true, version = "0.4" } -futures-legacy = { optional = true, version = "0.1", package = "futures" } -futures-test = { optional = true, version = "0.3" } -tokio-io = { optional = true, version = "0.1" } -tokio-tcp = { optional = true, version = "0.1" } - -tokio-serde = { optional = true, version = "0.5" } +tokio-serde = { optional = true, version = "0.6" } [dev-dependencies] assert_matches = "1.0" @@ -55,16 +49,17 @@ humantime = "1.0" log = "0.4" pin-utils = "0.1.0-alpha" tokio = { version = "0.2", features = ["full"] } +tokio-serde = { version = "0.6", features = ["json"] } [[example]] name = "server_calling_server" -required-features = ["serde1"] +required-features = ["full"] [[example]] name = "readme" -required-features = ["serde1", "tokio1"] +required-features = ["full"] [[example]] name = "pubsub" -required-features = ["serde1"] +required-features = ["full"] diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index c61908f..8546a1a 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -24,6 +24,7 @@ use tarpc::{ client, context, server::{self, Handler}, }; +use tokio_serde::formats::Json; pub mod subscriber { #[tarpc::service] @@ -59,7 +60,7 @@ impl subscriber::Subscriber for Subscriber { impl Subscriber { async fn listen(id: u32, config: server::Config) -> io::Result { - let incoming = tarpc::json_transport::listen("localhost:0") + let incoming = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())); let addr = incoming.get_ref().local_addr(); @@ -114,7 +115,7 @@ impl publisher::Publisher for Publisher { id: u32, addr: SocketAddr, ) -> io::Result<()> { - let conn = tarpc::json_transport::connect(addr).await?; + let conn = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?; let subscriber = subscriber::SubscriberClient::new(client::Config::default(), conn).spawn()?; eprintln!("Subscribing {}.", id); @@ -146,7 +147,7 @@ impl publisher::Publisher for Publisher { async fn main() -> io::Result<()> { env_logger::init(); - let transport = tarpc::json_transport::listen("localhost:0") + let transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())); let publisher_addr = transport.get_ref().local_addr(); @@ -160,7 +161,7 @@ async fn main() -> io::Result<()> { let subscriber1 = Subscriber::listen(0, server::Config::default()).await?; let subscriber2 = Subscriber::listen(1, server::Config::default()).await?; - let publisher_conn = tarpc::json_transport::connect(publisher_addr); + let publisher_conn = tarpc::serde_transport::tcp::connect(publisher_addr, Json::default()); let publisher_conn = publisher_conn.await?; let mut publisher = publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?; diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 527ab91..b08245c 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -13,6 +13,7 @@ use tarpc::{ client, context, server::{BaseChannel, Channel}, }; +use tokio_serde::formats::Json; /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -41,7 +42,7 @@ impl World for HelloServer { async fn main() -> io::Result<()> { // tarpc_json_transport is provided by the associated crate json_transport. It makes it // easy to start up a serde-powered JSON serialization strategy over TCP. - let mut transport = tarpc::json_transport::listen("localhost:0").await?; + let mut transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default).await?; let addr = transport.local_addr(); let server = async move { @@ -61,7 +62,7 @@ async fn main() -> io::Result<()> { }; tokio::spawn(server); - let transport = tarpc::json_transport::connect(addr).await?; + let transport = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?; // WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that // takes a config and any Transport as input. diff --git a/tarpc/examples/server_calling_server.rs b/tarpc/examples/server_calling_server.rs index 19130ed..bd10430 100644 --- a/tarpc/examples/server_calling_server.rs +++ b/tarpc/examples/server_calling_server.rs @@ -14,6 +14,7 @@ use tarpc::{ client, context, server::{Handler, Server}, }; +use tokio_serde::formats::Json; pub mod add { #[tarpc::service] @@ -66,7 +67,7 @@ impl DoubleService for DoubleServer { async fn main() -> io::Result<()> { env_logger::init(); - let add_listener = tarpc::json_transport::listen("localhost:0") + let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())); let addr = add_listener.get_ref().local_addr(); @@ -76,10 +77,10 @@ async fn main() -> io::Result<()> { .respond_with(AddServer.serve()); tokio::spawn(add_server); - let to_add_server = tarpc::json_transport::connect(addr).await?; + let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?; let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?; - let double_listener = tarpc::json_transport::listen("localhost:0") + let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); @@ -89,7 +90,7 @@ async fn main() -> io::Result<()> { .respond_with(DoubleServer { add_client }.serve()); tokio::spawn(double_server); - let to_double_server = tarpc::json_transport::connect(addr).await?; + let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?; let mut double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?; diff --git a/tarpc/src/bincode_transport/mod.rs b/tarpc/src/bincode_transport/mod.rs deleted file mode 100644 index d81e94e..0000000 --- a/tarpc/src/bincode_transport/mod.rs +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright 2018 Google LLC -// -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file or at -// https://opensource.org/licenses/MIT. - -//! A TCP [`Transport`] that serializes as bincode. - -#![deny(missing_docs, missing_debug_implementations)] - -use async_bincode::{AsyncBincodeStream, AsyncDestination}; -use futures::{compat::*, prelude::*, ready, task::*}; -use pin_project::pin_project; -use serde::{Deserialize, Serialize}; -use std::{error::Error, io, marker::PhantomData, net::SocketAddr, pin::Pin}; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_tcp::{TcpListener, TcpStream}; - -/// A transport that serializes to, and deserializes from, a [`TcpStream`]. -#[pin_project] -#[derive(Debug)] -pub struct Transport { - #[pin] - inner: Compat01As03Sink, SinkItem>, -} - -impl Stream for Transport -where - S: AsyncRead, - Item: for<'a> Deserialize<'a>, -{ - type Item = io::Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { - match self.project().inner.poll_next(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))), - Poll::Ready(Some(Err(e))) => { - Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e)))) - } - } - } -} - -impl Sink for Transport -where - S: AsyncWrite, - SinkItem: Serialize, -{ - type Error = io::Error; - - fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { - self.project() - .inner - .start_send(item) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - } - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.project().inner.poll_ready(cx)) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.project().inner.poll_flush(cx)) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.project().inner.poll_close(cx)) - } -} - -fn convert>>( - poll: Poll>, -) -> Poll> { - match poll { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), - } -} - -impl Transport { - /// Returns the address of the peer connected over the transport. - pub fn peer_addr(&self) -> io::Result { - self.inner.get_ref().get_ref().peer_addr() - } - - /// Returns the address of this end of the transport. - pub fn local_addr(&self) -> io::Result { - self.inner.get_ref().get_ref().local_addr() - } -} - -impl AsRef for Transport { - fn as_ref(&self) -> &T { - self.inner.get_ref().get_ref() - } -} - -/// Returns a new bincode transport that reads from and writes to `io`. -pub fn new(io: TcpStream) -> Transport -where - Item: for<'de> Deserialize<'de>, - SinkItem: Serialize, -{ - Transport::from(io) -} - -impl From for Transport { - fn from(inner: S) -> Self { - Transport { - inner: Compat01As03Sink::new(AsyncBincodeStream::from(inner).for_async()), - } - } -} - -/// Connects to `addr`, wrapping the connection in a bincode transport. -pub async fn connect( - addr: &SocketAddr, -) -> io::Result> -where - Item: for<'de> Deserialize<'de>, - SinkItem: Serialize, -{ - Ok(new(TcpStream::connect(addr).compat().await?)) -} - -/// Listens on `addr`, wrapping accepted connections in bincode transports. -pub fn listen(addr: &SocketAddr) -> io::Result> -where - Item: for<'de> Deserialize<'de>, - SinkItem: Serialize, -{ - let listener = TcpListener::bind(addr)?; - let local_addr = listener.local_addr()?; - let incoming = listener.incoming().compat(); - Ok(Incoming { - incoming, - local_addr, - ghost: PhantomData, - }) -} - -/// A [`TcpListener`] that wraps connections in bincode transports. -#[pin_project] -#[derive(Debug)] -pub struct Incoming { - #[pin] - incoming: Compat01As03, - local_addr: SocketAddr, - ghost: PhantomData<(Item, SinkItem)>, -} - -impl Incoming { - /// Returns the address being listened on. - pub fn local_addr(&self) -> SocketAddr { - self.local_addr - } -} - -impl Stream for Incoming -where - Item: for<'a> Deserialize<'a>, - SinkItem: Serialize, -{ - type Item = io::Result>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = ready!(self.project().incoming.poll_next(cx)?); - Poll::Ready(next.map(|conn| Ok(new(conn)))) - } -} - -#[cfg(test)] -mod tests { - use super::Transport; - use assert_matches::assert_matches; - use futures::{task::*, Sink, Stream}; - use pin_utils::pin_mut; - use std::io::Cursor; - - fn ctx() -> Context<'static> { - Context::from_waker(&noop_waker_ref()) - } - - #[test] - fn test_stream() { - // Frame is big endian; bincode is little endian. A bit confusing! - let reader = *b"\x00\x00\x00\x1e\x16\x00\x00\x00\x00\x00\x00\x00Test one, check check."; - let reader: Box<[u8]> = Box::new(reader); - let transport = Transport::<_, String, String>::from(Cursor::new(reader)); - pin_mut!(transport); - - assert_matches!( - transport.poll_next(&mut ctx()), - Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check."); - } - - #[test] - fn test_sink() { - let writer: &mut [u8] = &mut [0; 34]; - let transport = Transport::<_, String, String>::from(Cursor::new(&mut *writer)); - pin_mut!(transport); - - assert_matches!( - transport.as_mut().poll_ready(&mut ctx()), - Poll::Ready(Ok(())) - ); - assert_matches!( - transport - .as_mut() - .start_send("Test one, check check.".into()), - Ok(()) - ); - assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(()))); - assert_eq!( - writer, - <&[u8]>::from( - b"\x00\x00\x00\x1e\x16\x00\x00\x00\x00\x00\x00\x00Test one, check check." - ) - ); - } -} diff --git a/tarpc/src/json_transport/mod.rs b/tarpc/src/json_transport/mod.rs deleted file mode 100644 index 210b513..0000000 --- a/tarpc/src/json_transport/mod.rs +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2019 Google LLC -// -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file or at -// https://opensource.org/licenses/MIT. - -//! A TCP [`Transport`] that serializes as JSON. - -#![deny(missing_docs)] - -use futures::{prelude::*, ready, task::*}; -use pin_project::pin_project; -use serde::{Deserialize, Serialize}; -use std::{error::Error, io, marker::PhantomData, net::SocketAddr, pin::Pin}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tokio_serde::{formats::*, *}; -use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed}; - -/// A transport that serializes to, and deserializes from, a [`TcpStream`]. -#[pin_project] -pub struct Transport { - #[pin] - inner: FramedRead< - FramedWrite, SinkItem, Json>, - Item, - Json, - >, -} - -impl Stream for Transport -where - // TODO: Remove Unpin bound when tokio-rs/tokio#1272 is resolved. - S: AsyncWrite + AsyncRead + Unpin, - Item: for<'a> Deserialize<'a>, -{ - type Item = io::Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { - match self.project().inner.poll_next(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))), - Poll::Ready(Some(Err(e))) => { - Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e)))) - } - } - } -} - -impl Sink for Transport -where - S: AsyncWrite + Unpin, - SinkItem: Serialize, -{ - type Error = io::Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.project().inner.poll_ready(cx)) - } - - fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { - self.project() - .inner - .start_send(item) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.project().inner.poll_flush(cx)) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.project().inner.poll_close(cx)) - } -} - -fn convert>>( - poll: Poll>, -) -> Poll> { - match poll { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), - } -} - -impl Transport { - /// Returns the peer address of the underlying TcpStream. - pub fn peer_addr(&self) -> io::Result { - self.inner.get_ref().get_ref().get_ref().peer_addr() - } - - /// Returns the local address of the underlying TcpStream. - pub fn local_addr(&self) -> io::Result { - self.inner.get_ref().get_ref().get_ref().local_addr() - } -} - -/// Returns a new JSON transport that reads from and writes to `io`. -pub fn new(io: TcpStream) -> Transport -where - Item: for<'de> Deserialize<'de>, - SinkItem: Serialize, -{ - Transport::from(io) -} - -impl From - for Transport -{ - fn from(inner: S) -> Self { - Transport { - inner: FramedRead::new( - FramedWrite::new( - Framed::new(inner, LengthDelimitedCodec::new()), - Json::default(), - ), - Json::default(), - ), - } - } -} - -/// Connects to `addr`, wrapping the connection in a JSON transport. -pub async fn connect(addr: A) -> io::Result> -where - A: ToSocketAddrs, - Item: for<'de> Deserialize<'de>, - SinkItem: Serialize, -{ - Ok(new(TcpStream::connect(addr).await?)) -} - -/// Listens on `addr`, wrapping accepted connections in JSON transports. -pub async fn listen(addr: A) -> io::Result> -where - A: ToSocketAddrs, - Item: for<'de> Deserialize<'de>, - SinkItem: Serialize, -{ - let listener = TcpListener::bind(addr).await?; - let local_addr = listener.local_addr()?; - Ok(Incoming { - listener, - local_addr, - ghost: PhantomData, - }) -} - -/// A [`TcpListener`] that wraps connections in JSON transports. -#[pin_project] -#[derive(Debug)] -pub struct Incoming { - listener: TcpListener, - local_addr: SocketAddr, - ghost: PhantomData<(Item, SinkItem)>, -} - -impl Incoming { - /// Returns the address being listened on. - pub fn local_addr(&self) -> SocketAddr { - self.local_addr - } -} - -impl Stream for Incoming -where - Item: for<'a> Deserialize<'a>, - SinkItem: Serialize, -{ - type Item = io::Result>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = ready!(Pin::new(&mut self.project().listener.incoming()).poll_next(cx)?); - Poll::Ready(next.map(|conn| Ok(new(conn)))) - } -} - -#[cfg(test)] -mod tests { - use super::Transport; - use assert_matches::assert_matches; - use futures::{task::*, Sink, Stream}; - use pin_utils::pin_mut; - use std::{ - io::{self, Cursor}, - pin::Pin, - }; - use tokio::io::{AsyncRead, AsyncWrite}; - - fn ctx() -> Context<'static> { - Context::from_waker(&noop_waker_ref()) - } - - #[test] - fn test_stream() { - struct TestIo(Cursor<&'static [u8]>); - - impl AsyncRead for TestIo { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf) - } - } - - impl AsyncWrite for TestIo { - fn poll_write( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &[u8], - ) -> Poll> { - unreachable!() - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - unreachable!() - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - unreachable!() - } - } - - let data = b"\x00\x00\x00\x18\"Test one, check check.\""; - let transport = Transport::<_, String, String>::from(TestIo(Cursor::new(data))); - pin_mut!(transport); - - assert_matches!( - transport.poll_next(&mut ctx()), - Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check."); - } - - #[test] - fn test_sink() { - struct TestIo<'a>(&'a mut Vec); - - impl<'a> AsyncRead for TestIo<'a> { - fn poll_read( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &mut [u8], - ) -> Poll> { - unreachable!() - } - } - - impl<'a> AsyncWrite for TestIo<'a> { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx) - } - } - - let mut writer = vec![]; - let transport = Transport::<_, String, String>::from(TestIo(&mut writer)); - pin_mut!(transport); - - assert_matches!( - transport.as_mut().poll_ready(&mut ctx()), - Poll::Ready(Ok(())) - ); - assert_matches!( - transport - .as_mut() - .start_send("Test one, check check.".into()), - Ok(()) - ); - assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(()))); - assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\""); - } -} diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 383e8e3..ae23496 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -209,10 +209,8 @@ pub mod rpc; pub use rpc::*; -#[cfg(feature = "bincode-transport")] -pub mod bincode_transport; -#[cfg(feature = "json-transport")] -pub mod json_transport; +#[cfg(feature = "serde-transport")] +pub mod serde_transport; pub mod trace; diff --git a/tarpc/src/serde_transport/mod.rs b/tarpc/src/serde_transport/mod.rs new file mode 100644 index 0000000..7c1b7ff --- /dev/null +++ b/tarpc/src/serde_transport/mod.rs @@ -0,0 +1,326 @@ +// Copyright 2019 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! A generic Serde-based `Transport` that can serialize anything supported by `tokio-serde` via any medium that implements `AsyncRead` and `AsyncWrite`. + +#![deny(missing_docs)] + +use futures::{prelude::*, task::*}; +use pin_project::pin_project; +use serde::{Deserialize, Serialize}; +use std::{error::Error, io, pin::Pin}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_serde::{Framed as SerdeFramed, *}; +use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed}; + +/// A transport that serializes to, and deserializes from, a [`TcpStream`]. +#[pin_project] +pub struct Transport { + #[pin] + inner: SerdeFramed, Item, SinkItem, Codec>, +} + +impl Stream for Transport +where + S: AsyncWrite + AsyncRead, + Item: for<'a> Deserialize<'a>, + Codec: Deserializer, + CodecError: Into>, + SerdeFramed, Item, SinkItem, Codec>: + Stream>, +{ + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { + match self.project().inner.poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Ok::<_, CodecError>(next))) => Poll::Ready(Some(Ok(next))), + Poll::Ready(Some(Err::<_, CodecError>(e))) => { + Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e)))) + } + } + } +} + +impl Sink for Transport +where + S: AsyncWrite, + SinkItem: Serialize, + Codec: Serializer, + CodecError: Into>, + SerdeFramed, Item, SinkItem, Codec>: + Sink, +{ + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + convert(self.project().inner.poll_ready(cx)) + } + + fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { + self.project() + .inner + .start_send(item) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + convert(self.project().inner.poll_flush(cx)) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + convert(self.project().inner.poll_close(cx)) + } +} + +fn convert>>( + poll: Poll>, +) -> Poll> { + poll.map(|ready| ready.map_err(|e| io::Error::new(io::ErrorKind::Other, e))) +} + +impl From<(S, Codec)> for Transport +where + S: AsyncWrite + AsyncRead, + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, + Codec: Serializer + Deserializer, +{ + fn from((inner, codec): (S, Codec)) -> Self { + Transport { + inner: SerdeFramed::new(Framed::new(inner, LengthDelimitedCodec::new()), codec), + } + } +} + +#[cfg(feature = "tcp")] +#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))] +/// TCP support for generic transport using Tokio. +pub mod tcp { + use { + super::*, + futures::ready, + std::{marker::PhantomData, net::SocketAddr}, + tokio::net::{TcpListener, TcpStream, ToSocketAddrs}, + }; + + mod private { + use super::*; + + pub trait Sealed {} + + impl Sealed for Transport {} + } + + impl Transport { + /// Returns the peer address of the underlying TcpStream. + pub fn peer_addr(&self) -> io::Result { + self.inner.get_ref().get_ref().peer_addr() + } + /// Returns the local address of the underlying TcpStream. + pub fn local_addr(&self) -> io::Result { + self.inner.get_ref().get_ref().local_addr() + } + } + + /// Returns a new JSON transport that reads from and writes to `io`. + pub fn new( + io: TcpStream, + codec: Codec, + ) -> Transport + where + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, + Codec: Serializer + Deserializer, + { + Transport::from((io, codec)) + } + + /// Connects to `addr`, wrapping the connection in a JSON transport. + pub async fn connect( + addr: A, + codec: Codec, + ) -> io::Result> + where + A: ToSocketAddrs, + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, + Codec: Serializer + Deserializer, + { + Ok(new(TcpStream::connect(addr).await?, codec)) + } + + /// Listens on `addr`, wrapping accepted connections in JSON transports. + pub async fn listen( + addr: A, + codec_fn: CodecFn, + ) -> io::Result> + where + A: ToSocketAddrs, + Item: for<'de> Deserialize<'de>, + Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, + { + let listener = TcpListener::bind(addr).await?; + let local_addr = listener.local_addr()?; + Ok(Incoming { + listener, + codec_fn, + local_addr, + ghost: PhantomData, + }) + } + + /// A [`TcpListener`] that wraps connections in JSON transports. + #[pin_project] + #[derive(Debug)] + pub struct Incoming { + listener: TcpListener, + local_addr: SocketAddr, + codec_fn: CodecFn, + ghost: PhantomData<(Item, SinkItem, Codec)>, + } + + impl Incoming { + /// Returns the address being listened on. + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } + } + + impl Stream for Incoming + where + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, + Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, + { + type Item = io::Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let next = + ready!(Pin::new(&mut self.as_mut().project().listener.incoming()).poll_next(cx)?); + Poll::Ready(next.map(|conn| Ok(new(conn, (self.codec_fn)())))) + } + } +} + +#[cfg(test)] +mod tests { + use super::Transport; + use assert_matches::assert_matches; + use futures::{task::*, Sink, Stream}; + use pin_utils::pin_mut; + use std::{ + io::{self, Cursor}, + pin::Pin, + }; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio_serde::formats::SymmetricalJson; + + fn ctx() -> Context<'static> { + Context::from_waker(&noop_waker_ref()) + } + + #[test] + fn test_stream() { + struct TestIo(Cursor<&'static [u8]>); + + impl AsyncRead for TestIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf) + } + } + + impl AsyncWrite for TestIo { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + unreachable!() + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + unreachable!() + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + unreachable!() + } + } + + let data = b"\x00\x00\x00\x18\"Test one, check check.\""; + let transport = Transport::from(( + TestIo(Cursor::new(data)), + SymmetricalJson::::default(), + )); + pin_mut!(transport); + + assert_matches!( + transport.poll_next(&mut ctx()), + Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check."); + } + + #[test] + fn test_sink() { + struct TestIo<'a>(&'a mut Vec); + + impl<'a> AsyncRead for TestIo<'a> { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut [u8], + ) -> Poll> { + unreachable!() + } + } + + impl<'a> AsyncWrite for TestIo<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx) + } + } + + let mut writer = vec![]; + let transport = + Transport::from((TestIo(&mut writer), SymmetricalJson::::default())); + pin_mut!(transport); + + assert_matches!( + transport.as_mut().poll_ready(&mut ctx()), + Poll::Ready(Ok(())) + ); + assert_matches!( + transport + .as_mut() + .start_send("Test one, check check.".into()), + Ok(()) + ); + assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(()))); + assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\""); + } +} diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 9974d6e..241383c 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -6,10 +6,11 @@ use futures::{ use std::io; use tarpc::{ client::{self}, - context, json_transport, + context, serde_transport, server::{self, BaseChannel, Channel, Handler}, transport::channel, }; +use tokio_serde::formats::Json; #[tarpc_plugins::service] trait Service { @@ -61,7 +62,7 @@ async fn sequential() -> io::Result<()> { async fn serde() -> io::Result<()> { let _ = env_logger::try_init(); - let transport = json_transport::listen("localhost:56789").await?; + let transport = serde_transport::tcp::listen("localhost:56789", Json::default).await?; let addr = transport.local_addr(); tokio::spawn( tarpc::Server::default() @@ -69,7 +70,7 @@ async fn serde() -> io::Result<()> { .respond_with(Server.serve()), ); - let transport = json_transport::connect(addr).await?; + let transport = serde_transport::tcp::connect(addr, Json::default()).await?; let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));