diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 9a8af27..b2f037b 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -25,6 +25,7 @@ serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"] serde-transport-json = ["tokio-serde/json"] serde-transport-bincode = ["tokio-serde/bincode"] tcp = ["tokio/net"] +unix = ["tokio/net"] full = [ "serde1", @@ -33,6 +34,7 @@ full = [ "serde-transport-json", "serde-transport-bincode", "tcp", + "unix", ] [badges] diff --git a/tarpc/src/serde_transport.rs b/tarpc/src/serde_transport.rs index 07d860a..55eb8d9 100644 --- a/tarpc/src/serde_transport.rs +++ b/tarpc/src/serde_transport.rs @@ -277,6 +277,270 @@ pub mod tcp { } } +#[cfg(all(unix, feature = "unix"))] +#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "unix"))))] +/// Unix Domain Socket support for generic transport using Tokio. +pub mod unix { + use { + super::*, + futures::ready, + std::{marker::PhantomData, path::Path}, + tokio::net::{unix::SocketAddr, UnixListener, UnixStream}, + tokio_util::codec::length_delimited, + }; + + impl Transport { + /// Returns the socket address of the remote half of the underlying [`UnixStream`]. + pub fn peer_addr(&self) -> io::Result { + self.inner.get_ref().get_ref().peer_addr() + } + /// Returns the socket address of the local half of the underlying [`UnixStream`]. + pub fn local_addr(&self) -> io::Result { + self.inner.get_ref().get_ref().local_addr() + } + } + + /// A connection Future that also exposes the length-delimited framing config. + #[must_use] + #[pin_project] + pub struct Connect { + #[pin] + inner: T, + codec_fn: CodecFn, + config: length_delimited::Builder, + ghost: PhantomData<(fn(SinkItem), fn() -> Item)>, + } + + impl Future for Connect + where + T: Future>, + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, + Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, + { + type Output = io::Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let io = ready!(self.as_mut().project().inner.poll(cx))?; + Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)()))) + } + } + + impl Connect { + /// Returns an immutable reference to the length-delimited codec's config. + pub fn config(&self) -> &length_delimited::Builder { + &self.config + } + + /// Returns a mutable reference to the length-delimited codec's config. + pub fn config_mut(&mut self) -> &mut length_delimited::Builder { + &mut self.config + } + } + + /// Connects to socket named by `path`, wrapping the connection in a Unix Domain Socket + /// transport. + pub fn connect( + path: P, + codec_fn: CodecFn, + ) -> Connect>, Item, SinkItem, CodecFn> + where + P: AsRef, + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, + Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, + { + Connect { + inner: UnixStream::connect(path), + codec_fn, + config: LengthDelimitedCodec::builder(), + ghost: PhantomData, + } + } + + /// Listens on the socket named by `path`, wrapping accepted connections in Unix Domain Socket + /// transports. + pub async fn listen( + path: P, + codec_fn: CodecFn, + ) -> io::Result> + where + P: AsRef, + Item: for<'de> Deserialize<'de>, + Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, + { + let listener = UnixListener::bind(path)?; + let local_addr = listener.local_addr()?; + Ok(Incoming { + listener, + codec_fn, + local_addr, + config: LengthDelimitedCodec::builder(), + ghost: PhantomData, + }) + } + + /// A [`UnixListener`] that wraps connections in [transports](Transport). + #[pin_project] + #[derive(Debug)] + pub struct Incoming { + listener: UnixListener, + local_addr: SocketAddr, + codec_fn: CodecFn, + config: length_delimited::Builder, + ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>, + } + + impl Incoming { + /// Returns the the socket address being listened on. + pub fn local_addr(&self) -> &SocketAddr { + &self.local_addr + } + + /// Returns an immutable reference to the length-delimited codec's config. + pub fn config(&self) -> &length_delimited::Builder { + &self.config + } + + /// Returns a mutable reference to the length-delimited codec's config. + pub fn config_mut(&mut self) -> &mut length_delimited::Builder { + &mut self.config + } + } + + impl Stream for Incoming + where + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, + Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, + { + type Item = io::Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let conn: UnixStream = ready!(self.as_mut().project().listener.poll_accept(cx)?).0; + Poll::Ready(Some(Ok(new( + self.config.new_framed(conn), + (self.codec_fn)(), + )))) + } + } + + /// A temporary `PathBuf` that lives in `std::env::temp_dir` and is removed on drop. + pub struct TempPathBuf(std::path::PathBuf); + + impl TempPathBuf { + /// A named socket that results in `/` + pub fn new>(name: S) -> Self { + let mut sock = std::env::temp_dir(); + sock.push(name.as_ref()); + Self(sock) + } + + /// Appends a random hex string to the socket name resulting in + /// `/_` + pub fn with_random>(name: S) -> Self { + Self::new(format!("{}_{:x}", name.as_ref(), rand::random::())) + } + } + + impl AsRef for TempPathBuf { + fn as_ref(&self) -> &std::path::Path { + self.0.as_path() + } + } + + impl Drop for TempPathBuf { + fn drop(&mut self) { + // This will remove the file pointed to by this PathBuf if it exists, however Err's can + // be returned such as attempting to remove a non-existing file, or one which we don't + // have permission to remove. In these cases the Err is swallowed + let _ = std::fs::remove_file(&self.0); + } + } + + #[cfg(test)] + mod tests { + use super::*; + use tokio_serde::formats::SymmetricalJson; + + #[test] + fn temp_path_buf_non_random() { + let sock = TempPathBuf::new("test"); + let mut good = std::env::temp_dir(); + good.push("test"); + assert_eq!(sock.as_ref(), good); + assert_eq!(sock.as_ref().file_name().unwrap(), "test"); + } + + #[test] + fn temp_path_buf_random() { + let sock = TempPathBuf::with_random("test"); + let good = std::env::temp_dir(); + assert!(sock.as_ref().starts_with(good)); + // Since there are 16 random characters we just assert the file_name has the right name + // and starts with the correct string 'test_' + // file name: test_xxxxxxxxxxxxxxxx + // test = 4 + // _ = 1 + // = 16 + // total = 21 + let fname = sock.as_ref().file_name().unwrap().to_string_lossy(); + assert!(fname.starts_with("test_")); + assert_eq!(fname.len(), 21); + } + + #[test] + fn temp_path_buf_non_existing() { + let sock = TempPathBuf::with_random("test"); + let sock_path = std::path::PathBuf::from(sock.as_ref()); + + // No actual file has been created yet + assert!(!sock_path.exists()); + // Should not panic + std::mem::drop(sock); + assert!(!sock_path.exists()); + } + + #[test] + fn temp_path_buf_existing_file() { + let sock = TempPathBuf::with_random("test"); + let sock_path = std::path::PathBuf::from(sock.as_ref()); + let _file = std::fs::File::create(&sock).unwrap(); + assert!(sock_path.exists()); + std::mem::drop(sock); + assert!(!sock_path.exists()); + } + + #[test] + fn temp_path_buf_preexisting_file() { + let mut pre_existing = std::env::temp_dir(); + pre_existing.push("test"); + let _file = std::fs::File::create(&pre_existing).unwrap(); + let sock = TempPathBuf::new("test"); + let sock_path = std::path::PathBuf::from(sock.as_ref()); + assert!(sock_path.exists()); + std::mem::drop(sock); + assert!(!sock_path.exists()); + } + + #[tokio::test] + async fn temp_path_buf_for_socket() { + let sock = TempPathBuf::with_random("test"); + // Save path for testing after drop + let sock_path = std::path::PathBuf::from(sock.as_ref()); + // create the actual socket + let _ = listen(&sock, SymmetricalJson::::default).await; + assert!(sock_path.exists()); + std::mem::drop(sock); + assert!(!sock_path.exists()); + } + } +} + #[cfg(test)] mod tests { use super::Transport; @@ -393,4 +657,24 @@ mod tests { assert_matches!(transport.next().await, None); Ok(()) } + + #[cfg(all(unix, feature = "unix"))] + #[tokio::test] + async fn uds() -> io::Result<()> { + use super::unix; + use super::*; + + let sock = unix::TempPathBuf::with_random("uds"); + let mut listener = unix::listen(&sock, SymmetricalJson::::default).await?; + tokio::spawn(async move { + let mut transport = listener.next().await.unwrap().unwrap(); + let message = transport.next().await.unwrap().unwrap(); + transport.send(message).await.unwrap(); + }); + let mut transport = unix::connect(&sock, SymmetricalJson::::default).await?; + transport.send(String::from("test")).await?; + assert_matches!(transport.next().await, Some(Ok(s)) if s == "test"); + assert_matches!(transport.next().await, None); + Ok(()) + } } diff --git a/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr b/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr index e668e80..b1be874 100644 --- a/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr +++ b/tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr @@ -1,4 +1,4 @@ -error: unused `Connect` that must be used +error: unused `tarpc::serde_transport::tcp::Connect` that must be used --> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9 | 7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index a4939ee..50d19b0 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -108,7 +108,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { #[cfg(all(feature = "serde-transport", feature = "tcp"))] #[tokio::test] -async fn serde() -> anyhow::Result<()> { +async fn serde_tcp() -> anyhow::Result<()> { use tarpc::serde_transport; use tokio_serde::formats::Json; @@ -136,6 +136,37 @@ async fn serde() -> anyhow::Result<()> { Ok(()) } +#[cfg(all(feature = "serde-transport", feature = "unix", unix))] +#[tokio::test] +async fn serde_uds() -> anyhow::Result<()> { + use tarpc::serde_transport; + use tokio_serde::formats::Json; + + let _ = tracing_subscriber::fmt::try_init(); + + let sock = tarpc::serde_transport::unix::TempPathBuf::with_random("uds"); + let transport = tarpc::serde_transport::unix::listen(&sock, Json::default).await?; + tokio::spawn( + transport + .take(1) + .filter_map(|r| async { r.ok() }) + .map(BaseChannel::with_defaults) + .execute(Server.serve()), + ); + + let transport = serde_transport::unix::connect(&sock, Json::default).await?; + let client = ServiceClient::new(client::Config::default(), transport).spawn(); + + // Save results using socket so we can clean the socket even if our test assertions fail + let res1 = client.add(context::current(), 1, 2).await; + let res2 = client.hey(context::current(), "Tim".to_string()).await; + + assert_matches!(res1, Ok(3)); + assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); + + Ok(()) +} + #[tokio::test] async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init();