mirror of
https://github.com/OMGeeky/tarpc.git
synced 2025-12-26 17:02:32 +01:00
feat: Unix domain sockets with serde transports (#380)
* adds support for Unix Domain Socket generic transports * adds a TempPathBuf that lives in temp and is removed on drop
This commit is contained in:
@@ -25,6 +25,7 @@ serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||
serde-transport-json = ["tokio-serde/json"]
|
||||
serde-transport-bincode = ["tokio-serde/bincode"]
|
||||
tcp = ["tokio/net"]
|
||||
unix = ["tokio/net"]
|
||||
|
||||
full = [
|
||||
"serde1",
|
||||
@@ -33,6 +34,7 @@ full = [
|
||||
"serde-transport-json",
|
||||
"serde-transport-bincode",
|
||||
"tcp",
|
||||
"unix",
|
||||
]
|
||||
|
||||
[badges]
|
||||
|
||||
@@ -277,6 +277,270 @@ pub mod tcp {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(unix, feature = "unix"))]
|
||||
#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "unix"))))]
|
||||
/// Unix Domain Socket support for generic transport using Tokio.
|
||||
pub mod unix {
|
||||
use {
|
||||
super::*,
|
||||
futures::ready,
|
||||
std::{marker::PhantomData, path::Path},
|
||||
tokio::net::{unix::SocketAddr, UnixListener, UnixStream},
|
||||
tokio_util::codec::length_delimited,
|
||||
};
|
||||
|
||||
impl<Item, SinkItem, Codec> Transport<UnixStream, Item, SinkItem, Codec> {
|
||||
/// Returns the socket address of the remote half of the underlying [`UnixStream`].
|
||||
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||
self.inner.get_ref().get_ref().peer_addr()
|
||||
}
|
||||
/// Returns the socket address of the local half of the underlying [`UnixStream`].
|
||||
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
||||
self.inner.get_ref().get_ref().local_addr()
|
||||
}
|
||||
}
|
||||
|
||||
/// A connection Future that also exposes the length-delimited framing config.
|
||||
#[must_use]
|
||||
#[pin_project]
|
||||
pub struct Connect<T, Item, SinkItem, CodecFn> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
codec_fn: CodecFn,
|
||||
config: length_delimited::Builder,
|
||||
ghost: PhantomData<(fn(SinkItem), fn() -> Item)>,
|
||||
}
|
||||
|
||||
impl<T, Item, SinkItem, Codec, CodecFn> Future for Connect<T, Item, SinkItem, CodecFn>
|
||||
where
|
||||
T: Future<Output = io::Result<UnixStream>>,
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
SinkItem: Serialize,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
type Output = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||
let io = ready!(self.as_mut().project().inner.poll(cx))?;
|
||||
Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)())))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, Item, SinkItem, CodecFn> Connect<T, Item, SinkItem, CodecFn> {
|
||||
/// Returns an immutable reference to the length-delimited codec's config.
|
||||
pub fn config(&self) -> &length_delimited::Builder {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the length-delimited codec's config.
|
||||
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||
&mut self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Connects to socket named by `path`, wrapping the connection in a Unix Domain Socket
|
||||
/// transport.
|
||||
pub fn connect<P, Item, SinkItem, Codec, CodecFn>(
|
||||
path: P,
|
||||
codec_fn: CodecFn,
|
||||
) -> Connect<impl Future<Output = io::Result<UnixStream>>, Item, SinkItem, CodecFn>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
SinkItem: Serialize,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
Connect {
|
||||
inner: UnixStream::connect(path),
|
||||
codec_fn,
|
||||
config: LengthDelimitedCodec::builder(),
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Listens on the socket named by `path`, wrapping accepted connections in Unix Domain Socket
|
||||
/// transports.
|
||||
pub async fn listen<P, Item, SinkItem, Codec, CodecFn>(
|
||||
path: P,
|
||||
codec_fn: CodecFn,
|
||||
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
let listener = UnixListener::bind(path)?;
|
||||
let local_addr = listener.local_addr()?;
|
||||
Ok(Incoming {
|
||||
listener,
|
||||
codec_fn,
|
||||
local_addr,
|
||||
config: LengthDelimitedCodec::builder(),
|
||||
ghost: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// A [`UnixListener`] that wraps connections in [transports](Transport).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||
listener: UnixListener,
|
||||
local_addr: SocketAddr,
|
||||
codec_fn: CodecFn,
|
||||
config: length_delimited::Builder,
|
||||
ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>,
|
||||
}
|
||||
|
||||
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||
/// Returns the the socket address being listened on.
|
||||
pub fn local_addr(&self) -> &SocketAddr {
|
||||
&self.local_addr
|
||||
}
|
||||
|
||||
/// Returns an immutable reference to the length-delimited codec's config.
|
||||
pub fn config(&self) -> &length_delimited::Builder {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the length-delimited codec's config.
|
||||
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||
&mut self.config
|
||||
}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
|
||||
where
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
SinkItem: Serialize,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
type Item = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let conn: UnixStream = ready!(self.as_mut().project().listener.poll_accept(cx)?).0;
|
||||
Poll::Ready(Some(Ok(new(
|
||||
self.config.new_framed(conn),
|
||||
(self.codec_fn)(),
|
||||
))))
|
||||
}
|
||||
}
|
||||
|
||||
/// A temporary `PathBuf` that lives in `std::env::temp_dir` and is removed on drop.
|
||||
pub struct TempPathBuf(std::path::PathBuf);
|
||||
|
||||
impl TempPathBuf {
|
||||
/// A named socket that results in `<tempdir>/<name>`
|
||||
pub fn new<S: AsRef<str>>(name: S) -> Self {
|
||||
let mut sock = std::env::temp_dir();
|
||||
sock.push(name.as_ref());
|
||||
Self(sock)
|
||||
}
|
||||
|
||||
/// Appends a random hex string to the socket name resulting in
|
||||
/// `<tempdir>/<name>_<xxxxx>`
|
||||
pub fn with_random<S: AsRef<str>>(name: S) -> Self {
|
||||
Self::new(format!("{}_{:x}", name.as_ref(), rand::random::<u64>()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<std::path::Path> for TempPathBuf {
|
||||
fn as_ref(&self) -> &std::path::Path {
|
||||
self.0.as_path()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TempPathBuf {
|
||||
fn drop(&mut self) {
|
||||
// This will remove the file pointed to by this PathBuf if it exists, however Err's can
|
||||
// be returned such as attempting to remove a non-existing file, or one which we don't
|
||||
// have permission to remove. In these cases the Err is swallowed
|
||||
let _ = std::fs::remove_file(&self.0);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio_serde::formats::SymmetricalJson;
|
||||
|
||||
#[test]
|
||||
fn temp_path_buf_non_random() {
|
||||
let sock = TempPathBuf::new("test");
|
||||
let mut good = std::env::temp_dir();
|
||||
good.push("test");
|
||||
assert_eq!(sock.as_ref(), good);
|
||||
assert_eq!(sock.as_ref().file_name().unwrap(), "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn temp_path_buf_random() {
|
||||
let sock = TempPathBuf::with_random("test");
|
||||
let good = std::env::temp_dir();
|
||||
assert!(sock.as_ref().starts_with(good));
|
||||
// Since there are 16 random characters we just assert the file_name has the right name
|
||||
// and starts with the correct string 'test_'
|
||||
// file name: test_xxxxxxxxxxxxxxxx
|
||||
// test = 4
|
||||
// _ = 1
|
||||
// <hex> = 16
|
||||
// total = 21
|
||||
let fname = sock.as_ref().file_name().unwrap().to_string_lossy();
|
||||
assert!(fname.starts_with("test_"));
|
||||
assert_eq!(fname.len(), 21);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn temp_path_buf_non_existing() {
|
||||
let sock = TempPathBuf::with_random("test");
|
||||
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||
|
||||
// No actual file has been created yet
|
||||
assert!(!sock_path.exists());
|
||||
// Should not panic
|
||||
std::mem::drop(sock);
|
||||
assert!(!sock_path.exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn temp_path_buf_existing_file() {
|
||||
let sock = TempPathBuf::with_random("test");
|
||||
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||
let _file = std::fs::File::create(&sock).unwrap();
|
||||
assert!(sock_path.exists());
|
||||
std::mem::drop(sock);
|
||||
assert!(!sock_path.exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn temp_path_buf_preexisting_file() {
|
||||
let mut pre_existing = std::env::temp_dir();
|
||||
pre_existing.push("test");
|
||||
let _file = std::fs::File::create(&pre_existing).unwrap();
|
||||
let sock = TempPathBuf::new("test");
|
||||
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||
assert!(sock_path.exists());
|
||||
std::mem::drop(sock);
|
||||
assert!(!sock_path.exists());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn temp_path_buf_for_socket() {
|
||||
let sock = TempPathBuf::with_random("test");
|
||||
// Save path for testing after drop
|
||||
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||
// create the actual socket
|
||||
let _ = listen(&sock, SymmetricalJson::<String>::default).await;
|
||||
assert!(sock_path.exists());
|
||||
std::mem::drop(sock);
|
||||
assert!(!sock_path.exists());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Transport;
|
||||
@@ -393,4 +657,24 @@ mod tests {
|
||||
assert_matches!(transport.next().await, None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(unix, feature = "unix"))]
|
||||
#[tokio::test]
|
||||
async fn uds() -> io::Result<()> {
|
||||
use super::unix;
|
||||
use super::*;
|
||||
|
||||
let sock = unix::TempPathBuf::with_random("uds");
|
||||
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
|
||||
tokio::spawn(async move {
|
||||
let mut transport = listener.next().await.unwrap().unwrap();
|
||||
let message = transport.next().await.unwrap().unwrap();
|
||||
transport.send(message).await.unwrap();
|
||||
});
|
||||
let mut transport = unix::connect(&sock, SymmetricalJson::<String>::default).await?;
|
||||
transport.send(String::from("test")).await?;
|
||||
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||
assert_matches!(transport.next().await, None);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
error: unused `Connect` that must be used
|
||||
error: unused `tarpc::serde_transport::tcp::Connect` that must be used
|
||||
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
||||
|
|
||||
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||
|
||||
@@ -108,7 +108,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
|
||||
|
||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||
#[tokio::test]
|
||||
async fn serde() -> anyhow::Result<()> {
|
||||
async fn serde_tcp() -> anyhow::Result<()> {
|
||||
use tarpc::serde_transport;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
@@ -136,6 +136,37 @@ async fn serde() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "serde-transport", feature = "unix", unix))]
|
||||
#[tokio::test]
|
||||
async fn serde_uds() -> anyhow::Result<()> {
|
||||
use tarpc::serde_transport;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let sock = tarpc::serde_transport::unix::TempPathBuf::with_random("uds");
|
||||
let transport = tarpc::serde_transport::unix::listen(&sock, Json::default).await?;
|
||||
tokio::spawn(
|
||||
transport
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
|
||||
let client = ServiceClient::new(client::Config::default(), transport).spawn();
|
||||
|
||||
// Save results using socket so we can clean the socket even if our test assertions fail
|
||||
let res1 = client.add(context::current(), 1, 2).await;
|
||||
let res2 = client.hey(context::current(), "Tim".to_string()).await;
|
||||
|
||||
assert_matches!(res1, Ok(3));
|
||||
assert_matches!(res2, Ok(ref s) if s == "Hey, Tim.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
Reference in New Issue
Block a user