diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 6b17887..a6b687a 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -10,7 +10,10 @@ use futures::{ prelude::*, }; use service::World; -use std::{io, net::SocketAddr}; +use std::{ + io, + net::{IpAddr, SocketAddr}, +}; use tarpc::{ context, server::{self, Channel, Handler}, @@ -59,11 +62,12 @@ async fn main() -> io::Result<()> { .parse() .unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e)); - let server_addr = ([0, 0, 0, 0], port).into(); + let server_addr = (IpAddr::from([0, 0, 0, 0]), port); // tarpc_json_transport is provided by the associated crate tarpc-json-transport. It makes it easy // to start up a serde-powered json serialization strategy over TCP. - tarpc_json_transport::listen(&server_addr)? + tarpc_json_transport::listen(&server_addr) + .await? // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) .map(server::BaseChannel::with_defaults) diff --git a/json-transport/Cargo.toml b/json-transport/Cargo.toml index 3f95c35..127dbc2 100644 --- a/json-transport/Cargo.toml +++ b/json-transport/Cargo.toml @@ -13,16 +13,14 @@ readme = "../README.md" description = "A JSON-based transport for tarpc services." [dependencies] -futures-preview = { version = "0.3.0-alpha.18", features = ["compat"] } -futures_legacy = { version = "0.1", package = "futures" } -pin-utils = "0.1.0-alpha.4" -serde = "1.0" -serde_json = "1.0" -tokio = { version = "0.1", default-features = false, features = ["codec"] } -tokio-io = "0.1" -tokio-serde-json = "0.2" -tokio-tcp = "0.1" +futures-preview = "0.3.0-alpha" +pin-project = "0.4" +serde = "1" +serde_json = "1" +tokio = { version = "0.2.0-alpha", default-features = false, features = ["codec", "io", "net"] } +tokio-net = "0.2.0-alpha" +tokio-serde-json = "0.3" [dev-dependencies] -futures-test-preview = { version = "0.3.0-alpha.18" } +pin-utils = "0.1.0-alpha" assert_matches = "1.0" diff --git a/json-transport/src/lib.rs b/json-transport/src/lib.rs index 4c7c3b6..5be7ffe 100644 --- a/json-transport/src/lib.rs +++ b/json-transport/src/lib.rs @@ -8,8 +8,8 @@ #![deny(missing_docs)] -use futures::{compat::*, prelude::*, ready}; -use pin_utils::unsafe_pinned; +use futures::{prelude::*, ready}; +use pin_project::pin_project; use serde::{Deserialize, Serialize}; use std::{ error::Error, @@ -20,37 +20,27 @@ use std::{ task::{Context, Poll}, }; use tokio::codec::{length_delimited::LengthDelimitedCodec, Framed}; -use tokio_io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_net::ToSocketAddrs; use tokio_serde_json::*; -use tokio_tcp::{TcpListener, TcpStream}; /// A transport that serializes to, and deserializes from, a [`TcpStream`]. -pub struct Transport { - inner: Compat01As03Sink< - ReadJson, SinkItem>, Item>, - SinkItem, - >, -} - -impl Transport { - unsafe_pinned!( - inner: - Compat01As03Sink< - ReadJson, SinkItem>, Item>, - SinkItem, - > - ); +#[pin_project] +pub struct Transport { + #[pin] + inner: ReadJson, SinkItem>, Item>, } impl Stream for Transport where - S: AsyncWrite + AsyncRead, + S: AsyncWrite + AsyncRead + Unpin, Item: for<'a> Deserialize<'a>, { type Item = io::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { - match self.inner().poll_next(cx) { + match self.project().inner.poll_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => Poll::Ready(None), Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))), @@ -63,27 +53,28 @@ where impl Sink for Transport where - S: AsyncWrite, + S: AsyncWrite + Unpin, SinkItem: Serialize, { type Error = io::Error; + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + convert(self.project().inner.poll_ready(cx)) + } + fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { - self.inner() + self.project() + .inner .start_send(item) .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) } - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.inner().poll_ready(cx)) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.inner().poll_flush(cx)) + convert(self.project().inner.poll_flush(cx)) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.inner().poll_close(cx)) + convert(self.project().inner.poll_close(cx)) } } @@ -100,22 +91,12 @@ fn convert>>( impl Transport { /// Returns the peer address of the underlying TcpStream. pub fn peer_addr(&self) -> io::Result { - self.inner - .get_ref() - .get_ref() - .get_ref() - .get_ref() - .peer_addr() + self.inner.get_ref().get_ref().get_ref().peer_addr() } /// Returns the local address of the underlying TcpStream. pub fn local_addr(&self) -> io::Result { - self.inner - .get_ref() - .get_ref() - .get_ref() - .get_ref() - .local_addr() + self.inner.get_ref().get_ref().get_ref().local_addr() } } @@ -133,10 +114,10 @@ impl Self { Transport { - inner: Compat01As03Sink::new(ReadJson::new(WriteJson::new(Framed::new( + inner: ReadJson::new(WriteJson::new(Framed::new( inner, LengthDelimitedCodec::new(), - )))), + ))), } } } @@ -149,18 +130,19 @@ where Item: for<'de> Deserialize<'de>, SinkItem: Serialize, { - Ok(new(TcpStream::connect(addr).compat().await?)) + Ok(new(TcpStream::connect(addr).await?)) } /// Listens on `addr`, wrapping accepted connections in JSON transports. -pub fn listen(addr: &SocketAddr) -> io::Result> +pub async fn listen(addr: A) -> io::Result> where + A: ToSocketAddrs, Item: for<'de> Deserialize<'de>, SinkItem: Serialize, { - let listener = TcpListener::bind(addr)?; + let listener = TcpListener::bind(addr).await?; let local_addr = listener.local_addr()?; - let incoming = listener.incoming().compat(); + let incoming = Box::pin(listener.incoming()); Ok(Incoming { incoming, local_addr, @@ -168,17 +150,20 @@ where }) } +trait IncomingTrait: Stream> + std::fmt::Debug + Send {} +impl> + std::fmt::Debug + Send> IncomingTrait for T {} + /// A [`TcpListener`] that wraps connections in JSON transports. +#[pin_project] #[derive(Debug)] pub struct Incoming { - incoming: Compat01As03, + #[pin] + incoming: Pin>, local_addr: SocketAddr, ghost: PhantomData<(Item, SinkItem)>, } impl Incoming { - unsafe_pinned!(incoming: Compat01As03); - /// Returns the address being listened on. pub fn local_addr(&self) -> SocketAddr { self.local_addr @@ -193,7 +178,7 @@ where type Item = io::Result>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = ready!(self.incoming().poll_next(cx)?); + let next = ready!(self.project().incoming.poll_next(cx)?); Poll::Ready(next.map(|conn| Ok(new(conn)))) } } @@ -202,13 +187,15 @@ where mod tests { use super::Transport; use assert_matches::assert_matches; + use futures::task::noop_waker_ref; use futures::{Sink, Stream}; - use futures_test::task::noop_waker_ref; use pin_utils::pin_mut; use std::{ - io::Cursor, + io::{self, Cursor}, + pin::Pin, task::{Context, Poll}, }; + use tokio::io::{AsyncRead, AsyncWrite}; fn ctx() -> Context<'static> { Context::from_waker(&noop_waker_ref()) @@ -216,9 +203,38 @@ mod tests { #[test] fn test_stream() { - let reader = *b"\x00\x00\x00\x18\"Test one, check check.\""; - let reader: Box<[u8]> = Box::new(reader); - let transport = Transport::<_, String, String>::from(Cursor::new(reader)); + struct TestIo(Cursor<&'static [u8]>); + + impl AsyncRead for TestIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf) + } + } + + impl AsyncWrite for TestIo { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + unreachable!() + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + unreachable!() + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + unreachable!() + } + } + + let data = b"\x00\x00\x00\x18\"Test one, check check.\""; + let transport = Transport::<_, String, String>::from(TestIo(Cursor::new(data))); pin_mut!(transport); assert_matches!( @@ -228,8 +244,41 @@ mod tests { #[test] fn test_sink() { - let writer: &mut [u8] = &mut [0; 28]; - let transport = Transport::<_, String, String>::from(Cursor::new(&mut *writer)); + struct TestIo<'a>(&'a mut Vec); + + impl<'a> AsyncRead for TestIo<'a> { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut [u8], + ) -> Poll> { + unreachable!() + } + } + + impl<'a> AsyncWrite for TestIo<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx) + } + } + + let mut writer = vec![]; + let transport = Transport::<_, String, String>::from(TestIo(&mut writer)); pin_mut!(transport); assert_matches!(