diff --git a/tarpc/src/serde_transport.rs b/tarpc/src/serde_transport.rs index 3aac5d5..94050d6 100644 --- a/tarpc/src/serde_transport.rs +++ b/tarpc/src/serde_transport.rs @@ -294,89 +294,66 @@ mod tests { Context::from_waker(&noop_waker_ref()) } + struct TestIo(Cursor>); + + impl AsyncRead for TestIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf) + } + } + + impl AsyncWrite for TestIo { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut self.0), cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx) + } + } + + #[test] + fn close() { + let (tx, _rx) = crate::transport::channel::bounded::<(), ()>(0); + pin_mut!(tx); + assert_matches!(tx.as_mut().poll_close(&mut ctx()), Poll::Ready(Ok(()))); + assert_matches!(tx.as_mut().start_send(()), Err(_)); + } + #[test] fn test_stream() { - struct TestIo(Cursor<&'static [u8]>); - - impl AsyncRead for TestIo { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf) - } - } - - impl AsyncWrite for TestIo { - fn poll_write( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &[u8], - ) -> Poll> { - unreachable!() - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - unreachable!() - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - unreachable!() - } - } - - let data = b"\x00\x00\x00\x18\"Test one, check check.\""; + let data: &[u8] = b"\x00\x00\x00\x18\"Test one, check check.\""; let transport = Transport::from(( - TestIo(Cursor::new(data)), + TestIo(Cursor::new(Vec::from(data))), SymmetricalJson::::default(), )); pin_mut!(transport); assert_matches!( - transport.poll_next(&mut ctx()), + transport.as_mut().poll_next(&mut ctx()), Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check."); + assert_matches!(transport.as_mut().poll_next(&mut ctx()), Poll::Ready(None)); } #[test] fn test_sink() { - struct TestIo<'a>(&'a mut Vec); - - impl<'a> AsyncRead for TestIo<'a> { - fn poll_read( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &mut ReadBuf<'_>, - ) -> Poll> { - unreachable!() - } - } - - impl<'a> AsyncWrite for TestIo<'a> { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx) - } - } - - let mut writer = vec![]; - let transport = - Transport::from((TestIo(&mut writer), SymmetricalJson::::default())); - pin_mut!(transport); + let writer = Cursor::new(vec![]); + let mut transport = Box::pin(Transport::from(( + TestIo(writer), + SymmetricalJson::::default(), + ))); assert_matches!( transport.as_mut().poll_ready(&mut ctx()), @@ -388,7 +365,32 @@ mod tests { .start_send("Test one, check check.".into()), Ok(()) ); - assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(()))); - assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\""); + assert_matches!( + transport.as_mut().poll_flush(&mut ctx()), + Poll::Ready(Ok(())) + ); + assert_eq!( + transport.get_ref().0.get_ref(), + b"\x00\x00\x00\x18\"Test one, check check.\"" + ); + } + + #[cfg(tcp)] + #[tokio::test] + async fn tcp() -> io::Result<()> { + use super::tcp; + + let mut listener = tcp::listen("0.0.0.0:0", SymmetricalJson::::default).await?; + let addr = listener.local_addr(); + tokio::spawn(async move { + let mut transport = listener.next().await.unwrap().unwrap(); + let message = transport.next().await.unwrap().unwrap(); + transport.send(message).await.unwrap(); + }); + let mut transport = tcp::connect(addr, SymmetricalJson::::default).await?; + transport.send(String::from("test")).await?; + assert_matches!(transport.next().await, Some(Ok(s)) if s == "test"); + assert_matches!(transport.next().await, None); + Ok(()) } }