Unify serde transports.

This PR obsoletes the JSON and Bincode transports and instead introduces a unified transport that
is generic over any tokio-serde serialization format as well as AsyncRead + AsyncWrite medium.
This comes with a slight hit for usability (having to manually specify the underlying transport
and codec), but it can be alleviated by making custom freestanding connect and listen fns.
This commit is contained in:
Artem Vorotnikov
2019-11-29 03:30:19 +03:00
committed by Tim Kuehn
parent f945392b5a
commit bbbd43e282
12 changed files with 358 additions and 545 deletions

View File

@@ -18,6 +18,7 @@ futures = "0.3"
serde = { version = "1.0" }
tarpc = { version = "0.18", path = "../tarpc", features = ["full"] }
tokio = "0.2"
tokio-serde = { version = "0.6", features = ["json"] }
env_logger = "0.6"
[lib]

View File

@@ -7,6 +7,7 @@
use clap::{App, Arg};
use std::{io, net::SocketAddr};
use tarpc::{client, context};
use tokio_serde::formats::Json;
#[tokio::main]
async fn main() -> io::Result<()> {
@@ -40,7 +41,7 @@ async fn main() -> io::Result<()> {
let name = flags.value_of("name").unwrap().into();
let transport = tarpc::json_transport::connect(server_addr).await?;
let transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default()).await?;
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
// config and any Transport as input.

View File

@@ -18,6 +18,7 @@ use tarpc::{
context,
server::{self, Channel, Handler},
};
use tokio_serde::formats::Json;
// This is the type that implements the generated World trait. It is the business logic
// and is used to start the server.
@@ -66,7 +67,7 @@ async fn main() -> io::Result<()> {
// JSON transport is provided by the json_transport tarpc module. It makes it easy
// to start up a serde-powered json serialization strategy over TCP.
tarpc::json_transport::listen(&server_addr)
tarpc::serde_transport::tcp::listen(&server_addr, Json::default)
.await?
// Ignore accept errors.
.filter_map(|r| future::ready(r.ok()))

View File

@@ -17,10 +17,10 @@ default = []
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
tokio1 = ["tokio"]
bincode-transport = ["async-bincode", "futures-legacy", "futures-test", "futures/compat", "tokio-io", "tokio-tcp"]
json-transport = ["tokio/net", "tokio/stream", "tokio-serde/json", "tokio-util/codec"]
serde-transport = ["tokio-serde", "tokio-util/codec"]
tcp = ["tokio/net", "tokio/stream"]
full = ["serde1", "tokio1", "bincode-transport", "json-transport"]
full = ["serde1", "tokio1", "serde-transport", "tcp"]
[badges]
travis-ci = { repository = "google/tarpc" }
@@ -38,13 +38,7 @@ tokio = { optional = true, version = "0.2", features = ["time"] }
tokio-util = { optional = true, version = "0.2" }
tarpc-plugins = { path = "../plugins" }
async-bincode = { optional = true, version = "0.4" }
futures-legacy = { optional = true, version = "0.1", package = "futures" }
futures-test = { optional = true, version = "0.3" }
tokio-io = { optional = true, version = "0.1" }
tokio-tcp = { optional = true, version = "0.1" }
tokio-serde = { optional = true, version = "0.5" }
tokio-serde = { optional = true, version = "0.6" }
[dev-dependencies]
assert_matches = "1.0"
@@ -55,16 +49,17 @@ humantime = "1.0"
log = "0.4"
pin-utils = "0.1.0-alpha"
tokio = { version = "0.2", features = ["full"] }
tokio-serde = { version = "0.6", features = ["json"] }
[[example]]
name = "server_calling_server"
required-features = ["serde1"]
required-features = ["full"]
[[example]]
name = "readme"
required-features = ["serde1", "tokio1"]
required-features = ["full"]
[[example]]
name = "pubsub"
required-features = ["serde1"]
required-features = ["full"]

View File

@@ -24,6 +24,7 @@ use tarpc::{
client, context,
server::{self, Handler},
};
use tokio_serde::formats::Json;
pub mod subscriber {
#[tarpc::service]
@@ -59,7 +60,7 @@ impl subscriber::Subscriber for Subscriber {
impl Subscriber {
async fn listen(id: u32, config: server::Config) -> io::Result<SocketAddr> {
let incoming = tarpc::json_transport::listen("localhost:0")
let incoming = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let addr = incoming.get_ref().local_addr();
@@ -114,7 +115,7 @@ impl publisher::Publisher for Publisher {
id: u32,
addr: SocketAddr,
) -> io::Result<()> {
let conn = tarpc::json_transport::connect(addr).await?;
let conn = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
let subscriber =
subscriber::SubscriberClient::new(client::Config::default(), conn).spawn()?;
eprintln!("Subscribing {}.", id);
@@ -146,7 +147,7 @@ impl publisher::Publisher for Publisher {
async fn main() -> io::Result<()> {
env_logger::init();
let transport = tarpc::json_transport::listen("localhost:0")
let transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let publisher_addr = transport.get_ref().local_addr();
@@ -160,7 +161,7 @@ async fn main() -> io::Result<()> {
let subscriber1 = Subscriber::listen(0, server::Config::default()).await?;
let subscriber2 = Subscriber::listen(1, server::Config::default()).await?;
let publisher_conn = tarpc::json_transport::connect(publisher_addr);
let publisher_conn = tarpc::serde_transport::tcp::connect(publisher_addr, Json::default());
let publisher_conn = publisher_conn.await?;
let mut publisher =
publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?;

View File

@@ -13,6 +13,7 @@ use tarpc::{
client, context,
server::{BaseChannel, Channel},
};
use tokio_serde::formats::Json;
/// This is the service definition. It looks a lot like a trait definition.
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
@@ -41,7 +42,7 @@ impl World for HelloServer {
async fn main() -> io::Result<()> {
// tarpc_json_transport is provided by the associated crate json_transport. It makes it
// easy to start up a serde-powered JSON serialization strategy over TCP.
let mut transport = tarpc::json_transport::listen("localhost:0").await?;
let mut transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default).await?;
let addr = transport.local_addr();
let server = async move {
@@ -61,7 +62,7 @@ async fn main() -> io::Result<()> {
};
tokio::spawn(server);
let transport = tarpc::json_transport::connect(addr).await?;
let transport = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
// WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that
// takes a config and any Transport as input.

View File

@@ -14,6 +14,7 @@ use tarpc::{
client, context,
server::{Handler, Server},
};
use tokio_serde::formats::Json;
pub mod add {
#[tarpc::service]
@@ -66,7 +67,7 @@ impl DoubleService for DoubleServer {
async fn main() -> io::Result<()> {
env_logger::init();
let add_listener = tarpc::json_transport::listen("localhost:0")
let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let addr = add_listener.get_ref().local_addr();
@@ -76,10 +77,10 @@ async fn main() -> io::Result<()> {
.respond_with(AddServer.serve());
tokio::spawn(add_server);
let to_add_server = tarpc::json_transport::connect(addr).await?;
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?;
let double_listener = tarpc::json_transport::listen("localhost:0")
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let addr = double_listener.get_ref().local_addr();
@@ -89,7 +90,7 @@ async fn main() -> io::Result<()> {
.respond_with(DoubleServer { add_client }.serve());
tokio::spawn(double_server);
let to_double_server = tarpc::json_transport::connect(addr).await?;
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
let mut double_client =
double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;

View File

@@ -1,224 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! A TCP [`Transport`] that serializes as bincode.
#![deny(missing_docs, missing_debug_implementations)]
use async_bincode::{AsyncBincodeStream, AsyncDestination};
use futures::{compat::*, prelude::*, ready, task::*};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use std::{error::Error, io, marker::PhantomData, net::SocketAddr, pin::Pin};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_tcp::{TcpListener, TcpStream};
/// A transport that serializes to, and deserializes from, a [`TcpStream`].
#[pin_project]
#[derive(Debug)]
pub struct Transport<S, Item, SinkItem> {
#[pin]
inner: Compat01As03Sink<AsyncBincodeStream<S, Item, SinkItem, AsyncDestination>, SinkItem>,
}
impl<S, Item, SinkItem> Stream for Transport<S, Item, SinkItem>
where
S: AsyncRead,
Item: for<'a> Deserialize<'a>,
{
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
match self.project().inner.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))),
Poll::Ready(Some(Err(e))) => {
Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
}
}
}
}
impl<S, Item, SinkItem> Sink<SinkItem> for Transport<S, Item, SinkItem>
where
S: AsyncWrite,
SinkItem: Serialize,
{
type Error = io::Error;
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.project()
.inner
.start_send(item)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_ready(cx))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_flush(cx))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_close(cx))
}
}
fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
poll: Poll<Result<(), E>>,
) -> Poll<io::Result<()>> {
match poll {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
}
}
impl<Item, SinkItem> Transport<TcpStream, Item, SinkItem> {
/// Returns the address of the peer connected over the transport.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().peer_addr()
}
/// Returns the address of this end of the transport.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().local_addr()
}
}
impl<T, Item, SinkItem> AsRef<T> for Transport<T, Item, SinkItem> {
fn as_ref(&self) -> &T {
self.inner.get_ref().get_ref()
}
}
/// Returns a new bincode transport that reads from and writes to `io`.
pub fn new<Item, SinkItem>(io: TcpStream) -> Transport<TcpStream, Item, SinkItem>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
Transport::from(io)
}
impl<S, Item, SinkItem> From<S> for Transport<S, Item, SinkItem> {
fn from(inner: S) -> Self {
Transport {
inner: Compat01As03Sink::new(AsyncBincodeStream::from(inner).for_async()),
}
}
}
/// Connects to `addr`, wrapping the connection in a bincode transport.
pub async fn connect<Item, SinkItem>(
addr: &SocketAddr,
) -> io::Result<Transport<TcpStream, Item, SinkItem>>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
Ok(new(TcpStream::connect(addr).compat().await?))
}
/// Listens on `addr`, wrapping accepted connections in bincode transports.
pub fn listen<Item, SinkItem>(addr: &SocketAddr) -> io::Result<Incoming<Item, SinkItem>>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
let listener = TcpListener::bind(addr)?;
let local_addr = listener.local_addr()?;
let incoming = listener.incoming().compat();
Ok(Incoming {
incoming,
local_addr,
ghost: PhantomData,
})
}
/// A [`TcpListener`] that wraps connections in bincode transports.
#[pin_project]
#[derive(Debug)]
pub struct Incoming<Item, SinkItem> {
#[pin]
incoming: Compat01As03<tokio_tcp::Incoming>,
local_addr: SocketAddr,
ghost: PhantomData<(Item, SinkItem)>,
}
impl<Item, SinkItem> Incoming<Item, SinkItem> {
/// Returns the address being listened on.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
}
impl<Item, SinkItem> Stream for Incoming<Item, SinkItem>
where
Item: for<'a> Deserialize<'a>,
SinkItem: Serialize,
{
type Item = io::Result<Transport<TcpStream, Item, SinkItem>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let next = ready!(self.project().incoming.poll_next(cx)?);
Poll::Ready(next.map(|conn| Ok(new(conn))))
}
}
#[cfg(test)]
mod tests {
use super::Transport;
use assert_matches::assert_matches;
use futures::{task::*, Sink, Stream};
use pin_utils::pin_mut;
use std::io::Cursor;
fn ctx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
}
#[test]
fn test_stream() {
// Frame is big endian; bincode is little endian. A bit confusing!
let reader = *b"\x00\x00\x00\x1e\x16\x00\x00\x00\x00\x00\x00\x00Test one, check check.";
let reader: Box<[u8]> = Box::new(reader);
let transport = Transport::<_, String, String>::from(Cursor::new(reader));
pin_mut!(transport);
assert_matches!(
transport.poll_next(&mut ctx()),
Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
}
#[test]
fn test_sink() {
let writer: &mut [u8] = &mut [0; 34];
let transport = Transport::<_, String, String>::from(Cursor::new(&mut *writer));
pin_mut!(transport);
assert_matches!(
transport.as_mut().poll_ready(&mut ctx()),
Poll::Ready(Ok(()))
);
assert_matches!(
transport
.as_mut()
.start_send("Test one, check check.".into()),
Ok(())
);
assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(())));
assert_eq!(
writer,
<&[u8]>::from(
b"\x00\x00\x00\x1e\x16\x00\x00\x00\x00\x00\x00\x00Test one, check check."
)
);
}
}

View File

@@ -1,289 +0,0 @@
// Copyright 2019 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! A TCP [`Transport`] that serializes as JSON.
#![deny(missing_docs)]
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use std::{error::Error, io, marker::PhantomData, net::SocketAddr, pin::Pin};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_serde::{formats::*, *};
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
/// A transport that serializes to, and deserializes from, a [`TcpStream`].
#[pin_project]
pub struct Transport<S, Item, SinkItem> {
#[pin]
inner: FramedRead<
FramedWrite<Framed<S, LengthDelimitedCodec>, SinkItem, Json<SinkItem>>,
Item,
Json<Item>,
>,
}
impl<S, Item, SinkItem> Stream for Transport<S, Item, SinkItem>
where
// TODO: Remove Unpin bound when tokio-rs/tokio#1272 is resolved.
S: AsyncWrite + AsyncRead + Unpin,
Item: for<'a> Deserialize<'a>,
{
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
match self.project().inner.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))),
Poll::Ready(Some(Err(e))) => {
Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
}
}
}
}
impl<S, Item, SinkItem> Sink<SinkItem> for Transport<S, Item, SinkItem>
where
S: AsyncWrite + Unpin,
SinkItem: Serialize,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_ready(cx))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.project()
.inner
.start_send(item)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_flush(cx))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_close(cx))
}
}
fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
poll: Poll<Result<(), E>>,
) -> Poll<io::Result<()>> {
match poll {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
}
}
impl<Item, SinkItem> Transport<TcpStream, Item, SinkItem> {
/// Returns the peer address of the underlying TcpStream.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().get_ref().peer_addr()
}
/// Returns the local address of the underlying TcpStream.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().get_ref().local_addr()
}
}
/// Returns a new JSON transport that reads from and writes to `io`.
pub fn new<Item, SinkItem>(io: TcpStream) -> Transport<TcpStream, Item, SinkItem>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
Transport::from(io)
}
impl<S: AsyncWrite + AsyncRead, Item: serde::de::DeserializeOwned, SinkItem: Serialize> From<S>
for Transport<S, Item, SinkItem>
{
fn from(inner: S) -> Self {
Transport {
inner: FramedRead::new(
FramedWrite::new(
Framed::new(inner, LengthDelimitedCodec::new()),
Json::default(),
),
Json::default(),
),
}
}
}
/// Connects to `addr`, wrapping the connection in a JSON transport.
pub async fn connect<A, Item, SinkItem>(addr: A) -> io::Result<Transport<TcpStream, Item, SinkItem>>
where
A: ToSocketAddrs,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
Ok(new(TcpStream::connect(addr).await?))
}
/// Listens on `addr`, wrapping accepted connections in JSON transports.
pub async fn listen<A, Item, SinkItem>(addr: A) -> io::Result<Incoming<Item, SinkItem>>
where
A: ToSocketAddrs,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
Ok(Incoming {
listener,
local_addr,
ghost: PhantomData,
})
}
/// A [`TcpListener`] that wraps connections in JSON transports.
#[pin_project]
#[derive(Debug)]
pub struct Incoming<Item, SinkItem> {
listener: TcpListener,
local_addr: SocketAddr,
ghost: PhantomData<(Item, SinkItem)>,
}
impl<Item, SinkItem> Incoming<Item, SinkItem> {
/// Returns the address being listened on.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
}
impl<Item, SinkItem> Stream for Incoming<Item, SinkItem>
where
Item: for<'a> Deserialize<'a>,
SinkItem: Serialize,
{
type Item = io::Result<Transport<TcpStream, Item, SinkItem>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let next = ready!(Pin::new(&mut self.project().listener.incoming()).poll_next(cx)?);
Poll::Ready(next.map(|conn| Ok(new(conn))))
}
}
#[cfg(test)]
mod tests {
use super::Transport;
use assert_matches::assert_matches;
use futures::{task::*, Sink, Stream};
use pin_utils::pin_mut;
use std::{
io::{self, Cursor},
pin::Pin,
};
use tokio::io::{AsyncRead, AsyncWrite};
fn ctx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
}
#[test]
fn test_stream() {
struct TestIo(Cursor<&'static [u8]>);
impl AsyncRead for TestIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf)
}
}
impl AsyncWrite for TestIo {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
unreachable!()
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
}
let data = b"\x00\x00\x00\x18\"Test one, check check.\"";
let transport = Transport::<_, String, String>::from(TestIo(Cursor::new(data)));
pin_mut!(transport);
assert_matches!(
transport.poll_next(&mut ctx()),
Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
}
#[test]
fn test_sink() {
struct TestIo<'a>(&'a mut Vec<u8>);
impl<'a> AsyncRead for TestIo<'a> {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut [u8],
) -> Poll<io::Result<usize>> {
unreachable!()
}
}
impl<'a> AsyncWrite for TestIo<'a> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx)
}
}
let mut writer = vec![];
let transport = Transport::<_, String, String>::from(TestIo(&mut writer));
pin_mut!(transport);
assert_matches!(
transport.as_mut().poll_ready(&mut ctx()),
Poll::Ready(Ok(()))
);
assert_matches!(
transport
.as_mut()
.start_send("Test one, check check.".into()),
Ok(())
);
assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(())));
assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\"");
}
}

View File

@@ -209,10 +209,8 @@
pub mod rpc;
pub use rpc::*;
#[cfg(feature = "bincode-transport")]
pub mod bincode_transport;
#[cfg(feature = "json-transport")]
pub mod json_transport;
#[cfg(feature = "serde-transport")]
pub mod serde_transport;
pub mod trace;

View File

@@ -0,0 +1,326 @@
// Copyright 2019 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! A generic Serde-based `Transport` that can serialize anything supported by `tokio-serde` via any medium that implements `AsyncRead` and `AsyncWrite`.
#![deny(missing_docs)]
use futures::{prelude::*, task::*};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use std::{error::Error, io, pin::Pin};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_serde::{Framed as SerdeFramed, *};
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
/// A transport that serializes to, and deserializes from, a [`TcpStream`].
#[pin_project]
pub struct Transport<S, Item, SinkItem, Codec> {
#[pin]
inner: SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>,
}
impl<S, Item, SinkItem, Codec, CodecError> Stream for Transport<S, Item, SinkItem, Codec>
where
S: AsyncWrite + AsyncRead,
Item: for<'a> Deserialize<'a>,
Codec: Deserializer<Item>,
CodecError: Into<Box<dyn std::error::Error + Send + Sync>>,
SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>:
Stream<Item = Result<Item, CodecError>>,
{
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
match self.project().inner.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok::<_, CodecError>(next))) => Poll::Ready(Some(Ok(next))),
Poll::Ready(Some(Err::<_, CodecError>(e))) => {
Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
}
}
}
}
impl<S, Item, SinkItem, Codec, CodecError> Sink<SinkItem> for Transport<S, Item, SinkItem, Codec>
where
S: AsyncWrite,
SinkItem: Serialize,
Codec: Serializer<SinkItem>,
CodecError: Into<Box<dyn Error + Send + Sync>>,
SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>:
Sink<SinkItem, Error = CodecError>,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_ready(cx))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.project()
.inner
.start_send(item)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_flush(cx))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_close(cx))
}
}
fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
poll: Poll<Result<(), E>>,
) -> Poll<io::Result<()>> {
poll.map(|ready| ready.map_err(|e| io::Error::new(io::ErrorKind::Other, e)))
}
impl<S, Item, SinkItem, Codec> From<(S, Codec)> for Transport<S, Item, SinkItem, Codec>
where
S: AsyncWrite + AsyncRead,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
{
fn from((inner, codec): (S, Codec)) -> Self {
Transport {
inner: SerdeFramed::new(Framed::new(inner, LengthDelimitedCodec::new()), codec),
}
}
}
#[cfg(feature = "tcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
/// TCP support for generic transport using Tokio.
pub mod tcp {
use {
super::*,
futures::ready,
std::{marker::PhantomData, net::SocketAddr},
tokio::net::{TcpListener, TcpStream, ToSocketAddrs},
};
mod private {
use super::*;
pub trait Sealed {}
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
}
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
/// Returns the peer address of the underlying TcpStream.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().peer_addr()
}
/// Returns the local address of the underlying TcpStream.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().local_addr()
}
}
/// Returns a new JSON transport that reads from and writes to `io`.
pub fn new<Item, SinkItem, Codec>(
io: TcpStream,
codec: Codec,
) -> Transport<TcpStream, Item, SinkItem, Codec>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
{
Transport::from((io, codec))
}
/// Connects to `addr`, wrapping the connection in a JSON transport.
pub async fn connect<A, Item, SinkItem, Codec>(
addr: A,
codec: Codec,
) -> io::Result<Transport<TcpStream, Item, SinkItem, Codec>>
where
A: ToSocketAddrs,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
{
Ok(new(TcpStream::connect(addr).await?, codec))
}
/// Listens on `addr`, wrapping accepted connections in JSON transports.
pub async fn listen<A, Item, SinkItem, Codec, CodecFn>(
addr: A,
codec_fn: CodecFn,
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
where
A: ToSocketAddrs,
Item: for<'de> Deserialize<'de>,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
Ok(Incoming {
listener,
codec_fn,
local_addr,
ghost: PhantomData,
})
}
/// A [`TcpListener`] that wraps connections in JSON transports.
#[pin_project]
#[derive(Debug)]
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
listener: TcpListener,
local_addr: SocketAddr,
codec_fn: CodecFn,
ghost: PhantomData<(Item, SinkItem, Codec)>,
}
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
/// Returns the address being listened on.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
}
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
type Item = io::Result<Transport<TcpStream, Item, SinkItem, Codec>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let next =
ready!(Pin::new(&mut self.as_mut().project().listener.incoming()).poll_next(cx)?);
Poll::Ready(next.map(|conn| Ok(new(conn, (self.codec_fn)()))))
}
}
}
#[cfg(test)]
mod tests {
use super::Transport;
use assert_matches::assert_matches;
use futures::{task::*, Sink, Stream};
use pin_utils::pin_mut;
use std::{
io::{self, Cursor},
pin::Pin,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_serde::formats::SymmetricalJson;
fn ctx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
}
#[test]
fn test_stream() {
struct TestIo(Cursor<&'static [u8]>);
impl AsyncRead for TestIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf)
}
}
impl AsyncWrite for TestIo {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
unreachable!()
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
}
let data = b"\x00\x00\x00\x18\"Test one, check check.\"";
let transport = Transport::from((
TestIo(Cursor::new(data)),
SymmetricalJson::<String>::default(),
));
pin_mut!(transport);
assert_matches!(
transport.poll_next(&mut ctx()),
Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
}
#[test]
fn test_sink() {
struct TestIo<'a>(&'a mut Vec<u8>);
impl<'a> AsyncRead for TestIo<'a> {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut [u8],
) -> Poll<io::Result<usize>> {
unreachable!()
}
}
impl<'a> AsyncWrite for TestIo<'a> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx)
}
}
let mut writer = vec![];
let transport =
Transport::from((TestIo(&mut writer), SymmetricalJson::<String>::default()));
pin_mut!(transport);
assert_matches!(
transport.as_mut().poll_ready(&mut ctx()),
Poll::Ready(Ok(()))
);
assert_matches!(
transport
.as_mut()
.start_send("Test one, check check.".into()),
Ok(())
);
assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(())));
assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\"");
}
}

View File

@@ -6,10 +6,11 @@ use futures::{
use std::io;
use tarpc::{
client::{self},
context, json_transport,
context, serde_transport,
server::{self, BaseChannel, Channel, Handler},
transport::channel,
};
use tokio_serde::formats::Json;
#[tarpc_plugins::service]
trait Service {
@@ -61,7 +62,7 @@ async fn sequential() -> io::Result<()> {
async fn serde() -> io::Result<()> {
let _ = env_logger::try_init();
let transport = json_transport::listen("localhost:56789").await?;
let transport = serde_transport::tcp::listen("localhost:56789", Json::default).await?;
let addr = transport.local_addr();
tokio::spawn(
tarpc::Server::default()
@@ -69,7 +70,7 @@ async fn serde() -> io::Result<()> {
.respond_with(Server.serve()),
);
let transport = json_transport::connect(addr).await?;
let transport = serde_transport::tcp::connect(addr, Json::default()).await?;
let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?;
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));