// Copyright 2019 Google LLC // // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. //! A TCP [`Transport`] that serializes as JSON. #![feature(arbitrary_self_types, async_await)] #![deny(missing_docs)] use futures::{compat::*, prelude::*, ready}; use pin_utils::unsafe_pinned; use serde::{Deserialize, Serialize}; use std::{ error::Error, io, marker::PhantomData, net::SocketAddr, pin::Pin, task::{Context, Poll}, }; use tokio::codec::{length_delimited::LengthDelimitedCodec, Framed}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_serde_json::*; use tokio_tcp::{TcpListener, TcpStream}; /// A transport that serializes to, and deserializes from, a [`TcpStream`]. pub struct Transport { inner: Compat01As03Sink< ReadJson, SinkItem>, Item>, SinkItem, >, } impl Transport { unsafe_pinned!( inner: Compat01As03Sink< ReadJson, SinkItem>, Item>, SinkItem, > ); } impl Stream for Transport where S: AsyncWrite + AsyncRead, Item: for<'a> Deserialize<'a>, { type Item = io::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { match self.inner().poll_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => Poll::Ready(None), Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))), Poll::Ready(Some(Err(e))) => { Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e)))) } } } } impl Sink for Transport where S: AsyncWrite, SinkItem: Serialize, { type Error = io::Error; fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { self.inner() .start_send(item) .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) } fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { convert(self.inner().poll_ready(cx)) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { convert(self.inner().poll_flush(cx)) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { convert(self.inner().poll_close(cx)) } } fn convert>>( poll: Poll>, ) -> Poll> { match poll { Poll::Pending => Poll::Pending, Poll::Ready(Ok(())) => Poll::Ready(Ok(())), Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), } } impl Transport { /// Returns the peer address of the underlying TcpStream. pub fn peer_addr(&self) -> io::Result { self.inner .get_ref() .get_ref() .get_ref() .get_ref() .peer_addr() } /// Returns the local address of the underlying TcpStream. pub fn local_addr(&self) -> io::Result { self.inner .get_ref() .get_ref() .get_ref() .get_ref() .local_addr() } } /// Returns a new JSON transport that reads from and writes to `io`. pub fn new(io: TcpStream) -> Transport where Item: for<'de> Deserialize<'de>, SinkItem: Serialize, { Transport::from(io) } impl From for Transport { fn from(inner: S) -> Self { Transport { inner: Compat01As03Sink::new(ReadJson::new(WriteJson::new(Framed::new( inner, LengthDelimitedCodec::new(), )))), } } } /// Connects to `addr`, wrapping the connection in a JSON transport. pub async fn connect( addr: &SocketAddr, ) -> io::Result> where Item: for<'de> Deserialize<'de>, SinkItem: Serialize, { Ok(new(TcpStream::connect(addr).compat().await?)) } /// Listens on `addr`, wrapping accepted connections in JSON transports. pub fn listen(addr: &SocketAddr) -> io::Result> where Item: for<'de> Deserialize<'de>, SinkItem: Serialize, { let listener = TcpListener::bind(addr)?; let local_addr = listener.local_addr()?; let incoming = listener.incoming().compat(); Ok(Incoming { incoming, local_addr, ghost: PhantomData, }) } /// A [`TcpListener`] that wraps connections in JSON transports. #[derive(Debug)] pub struct Incoming { incoming: Compat01As03, local_addr: SocketAddr, ghost: PhantomData<(Item, SinkItem)>, } impl Incoming { unsafe_pinned!(incoming: Compat01As03); /// Returns the address being listened on. pub fn local_addr(&self) -> SocketAddr { self.local_addr } } impl Stream for Incoming where Item: for<'a> Deserialize<'a>, SinkItem: Serialize, { type Item = io::Result>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let next = ready!(self.incoming().poll_next(cx)?); Poll::Ready(next.map(|conn| Ok(new(conn)))) } } #[cfg(test)] mod tests { use assert_matches::assert_matches; use futures::{Sink, Stream}; use futures_test::task::noop_waker_ref; use pin_utils::pin_mut; use std::{io::Cursor, task::{Context, Poll}}; use super::Transport; fn ctx() -> Context<'static> { Context::from_waker(&noop_waker_ref()) } #[test] fn test_stream() { let reader = *b"\x00\x00\x00\x18\"Test one, check check.\""; let reader: Box<[u8]> = Box::new(reader); let transport = Transport::<_, String, String>::from(Cursor::new(reader)); pin_mut!(transport); assert_matches!( transport.poll_next(&mut ctx()), Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check."); } #[test] fn test_sink() { let writer: &mut [u8] = &mut [0; 28]; let transport = Transport::<_, String, String>::from(Cursor::new(&mut *writer)); pin_mut!(transport); assert_matches!(transport.as_mut().poll_ready(&mut ctx()), Poll::Ready(Ok(()))); assert_matches!(transport.as_mut().start_send("Test one, check check.".into()), Ok(())); assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(()))); assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\""); } }