diff --git a/tarpc/src/serde_transport.rs b/tarpc/src/serde_transport.rs index 51ad9fb..a8eedd5 100644 --- a/tarpc/src/serde_transport.rs +++ b/tarpc/src/serde_transport.rs @@ -210,7 +210,19 @@ pub mod tcp { Codec: Serializer + Deserializer, CodecFn: Fn() -> Codec, { - let listener = TcpListener::bind(addr).await?; + listen_on(TcpListener::bind(addr).await?, codec_fn).await + } + + /// Wrap accepted connections from `listener` in TCP transports. + pub async fn listen_on( + listener: TcpListener, + codec_fn: CodecFn, + ) -> io::Result> + where + Item: for<'de> Deserialize<'de>, + Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, + { let local_addr = listener.local_addr()?; Ok(Incoming { listener, @@ -662,6 +674,26 @@ mod tests { Ok(()) } + #[cfg(tcp)] + #[tokio::test] + async fn tcp_on_existing_transport() -> io::Result<()> { + use super::tcp; + + let transport = TcpListener::bind("0.0.0.0:0").await?; + let mut listener = tcp::listen_on(transport, SymmetricalJson::::default).await?; + let addr = listener.local_addr(); + tokio::spawn(async move { + let mut transport = listener.next().await.unwrap().unwrap(); + let message = transport.next().await.unwrap().unwrap(); + transport.send(message).await.unwrap(); + }); + let mut transport = tcp::connect(addr, SymmetricalJson::::default).await?; + transport.send(String::from("test")).await?; + assert_matches!(transport.next().await, Some(Ok(s)) if s == "test"); + assert_matches!(transport.next().await, None); + Ok(()) + } + #[cfg(all(unix, feature = "unix"))] #[tokio::test] async fn uds() -> io::Result<()> {