mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-13 21:18:18 +01:00
feat: Unix domain sockets with serde transports (#380)
* adds support for Unix Domain Socket generic transports * adds a TempPathBuf that lives in temp and is removed on drop
This commit is contained in:
@@ -25,6 +25,7 @@ serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
|||||||
serde-transport-json = ["tokio-serde/json"]
|
serde-transport-json = ["tokio-serde/json"]
|
||||||
serde-transport-bincode = ["tokio-serde/bincode"]
|
serde-transport-bincode = ["tokio-serde/bincode"]
|
||||||
tcp = ["tokio/net"]
|
tcp = ["tokio/net"]
|
||||||
|
unix = ["tokio/net"]
|
||||||
|
|
||||||
full = [
|
full = [
|
||||||
"serde1",
|
"serde1",
|
||||||
@@ -33,6 +34,7 @@ full = [
|
|||||||
"serde-transport-json",
|
"serde-transport-json",
|
||||||
"serde-transport-bincode",
|
"serde-transport-bincode",
|
||||||
"tcp",
|
"tcp",
|
||||||
|
"unix",
|
||||||
]
|
]
|
||||||
|
|
||||||
[badges]
|
[badges]
|
||||||
|
|||||||
@@ -277,6 +277,270 @@ pub mod tcp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(all(unix, feature = "unix"))]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "unix"))))]
|
||||||
|
/// Unix Domain Socket support for generic transport using Tokio.
|
||||||
|
pub mod unix {
|
||||||
|
use {
|
||||||
|
super::*,
|
||||||
|
futures::ready,
|
||||||
|
std::{marker::PhantomData, path::Path},
|
||||||
|
tokio::net::{unix::SocketAddr, UnixListener, UnixStream},
|
||||||
|
tokio_util::codec::length_delimited,
|
||||||
|
};
|
||||||
|
|
||||||
|
impl<Item, SinkItem, Codec> Transport<UnixStream, Item, SinkItem, Codec> {
|
||||||
|
/// Returns the socket address of the remote half of the underlying [`UnixStream`].
|
||||||
|
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||||
|
self.inner.get_ref().get_ref().peer_addr()
|
||||||
|
}
|
||||||
|
/// Returns the socket address of the local half of the underlying [`UnixStream`].
|
||||||
|
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
||||||
|
self.inner.get_ref().get_ref().local_addr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A connection Future that also exposes the length-delimited framing config.
|
||||||
|
#[must_use]
|
||||||
|
#[pin_project]
|
||||||
|
pub struct Connect<T, Item, SinkItem, CodecFn> {
|
||||||
|
#[pin]
|
||||||
|
inner: T,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
config: length_delimited::Builder,
|
||||||
|
ghost: PhantomData<(fn(SinkItem), fn() -> Item)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, Item, SinkItem, Codec, CodecFn> Future for Connect<T, Item, SinkItem, CodecFn>
|
||||||
|
where
|
||||||
|
T: Future<Output = io::Result<UnixStream>>,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
type Output = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||||
|
let io = ready!(self.as_mut().project().inner.poll(cx))?;
|
||||||
|
Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, Item, SinkItem, CodecFn> Connect<T, Item, SinkItem, CodecFn> {
|
||||||
|
/// Returns an immutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config(&self) -> &length_delimited::Builder {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||||
|
&mut self.config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connects to socket named by `path`, wrapping the connection in a Unix Domain Socket
|
||||||
|
/// transport.
|
||||||
|
pub fn connect<P, Item, SinkItem, Codec, CodecFn>(
|
||||||
|
path: P,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
) -> Connect<impl Future<Output = io::Result<UnixStream>>, Item, SinkItem, CodecFn>
|
||||||
|
where
|
||||||
|
P: AsRef<Path>,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
Connect {
|
||||||
|
inner: UnixStream::connect(path),
|
||||||
|
codec_fn,
|
||||||
|
config: LengthDelimitedCodec::builder(),
|
||||||
|
ghost: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Listens on the socket named by `path`, wrapping accepted connections in Unix Domain Socket
|
||||||
|
/// transports.
|
||||||
|
pub async fn listen<P, Item, SinkItem, Codec, CodecFn>(
|
||||||
|
path: P,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
|
||||||
|
where
|
||||||
|
P: AsRef<Path>,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
let listener = UnixListener::bind(path)?;
|
||||||
|
let local_addr = listener.local_addr()?;
|
||||||
|
Ok(Incoming {
|
||||||
|
listener,
|
||||||
|
codec_fn,
|
||||||
|
local_addr,
|
||||||
|
config: LengthDelimitedCodec::builder(),
|
||||||
|
ghost: PhantomData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A [`UnixListener`] that wraps connections in [transports](Transport).
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||||
|
listener: UnixListener,
|
||||||
|
local_addr: SocketAddr,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
config: length_delimited::Builder,
|
||||||
|
ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||||
|
/// Returns the the socket address being listened on.
|
||||||
|
pub fn local_addr(&self) -> &SocketAddr {
|
||||||
|
&self.local_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an immutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config(&self) -> &length_delimited::Builder {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||||
|
&mut self.config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
|
||||||
|
where
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
type Item = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
let conn: UnixStream = ready!(self.as_mut().project().listener.poll_accept(cx)?).0;
|
||||||
|
Poll::Ready(Some(Ok(new(
|
||||||
|
self.config.new_framed(conn),
|
||||||
|
(self.codec_fn)(),
|
||||||
|
))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A temporary `PathBuf` that lives in `std::env::temp_dir` and is removed on drop.
|
||||||
|
pub struct TempPathBuf(std::path::PathBuf);
|
||||||
|
|
||||||
|
impl TempPathBuf {
|
||||||
|
/// A named socket that results in `<tempdir>/<name>`
|
||||||
|
pub fn new<S: AsRef<str>>(name: S) -> Self {
|
||||||
|
let mut sock = std::env::temp_dir();
|
||||||
|
sock.push(name.as_ref());
|
||||||
|
Self(sock)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Appends a random hex string to the socket name resulting in
|
||||||
|
/// `<tempdir>/<name>_<xxxxx>`
|
||||||
|
pub fn with_random<S: AsRef<str>>(name: S) -> Self {
|
||||||
|
Self::new(format!("{}_{:x}", name.as_ref(), rand::random::<u64>()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsRef<std::path::Path> for TempPathBuf {
|
||||||
|
fn as_ref(&self) -> &std::path::Path {
|
||||||
|
self.0.as_path()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for TempPathBuf {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
// This will remove the file pointed to by this PathBuf if it exists, however Err's can
|
||||||
|
// be returned such as attempting to remove a non-existing file, or one which we don't
|
||||||
|
// have permission to remove. In these cases the Err is swallowed
|
||||||
|
let _ = std::fs::remove_file(&self.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio_serde::formats::SymmetricalJson;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temp_path_buf_non_random() {
|
||||||
|
let sock = TempPathBuf::new("test");
|
||||||
|
let mut good = std::env::temp_dir();
|
||||||
|
good.push("test");
|
||||||
|
assert_eq!(sock.as_ref(), good);
|
||||||
|
assert_eq!(sock.as_ref().file_name().unwrap(), "test");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temp_path_buf_random() {
|
||||||
|
let sock = TempPathBuf::with_random("test");
|
||||||
|
let good = std::env::temp_dir();
|
||||||
|
assert!(sock.as_ref().starts_with(good));
|
||||||
|
// Since there are 16 random characters we just assert the file_name has the right name
|
||||||
|
// and starts with the correct string 'test_'
|
||||||
|
// file name: test_xxxxxxxxxxxxxxxx
|
||||||
|
// test = 4
|
||||||
|
// _ = 1
|
||||||
|
// <hex> = 16
|
||||||
|
// total = 21
|
||||||
|
let fname = sock.as_ref().file_name().unwrap().to_string_lossy();
|
||||||
|
assert!(fname.starts_with("test_"));
|
||||||
|
assert_eq!(fname.len(), 21);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temp_path_buf_non_existing() {
|
||||||
|
let sock = TempPathBuf::with_random("test");
|
||||||
|
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||||
|
|
||||||
|
// No actual file has been created yet
|
||||||
|
assert!(!sock_path.exists());
|
||||||
|
// Should not panic
|
||||||
|
std::mem::drop(sock);
|
||||||
|
assert!(!sock_path.exists());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temp_path_buf_existing_file() {
|
||||||
|
let sock = TempPathBuf::with_random("test");
|
||||||
|
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||||
|
let _file = std::fs::File::create(&sock).unwrap();
|
||||||
|
assert!(sock_path.exists());
|
||||||
|
std::mem::drop(sock);
|
||||||
|
assert!(!sock_path.exists());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temp_path_buf_preexisting_file() {
|
||||||
|
let mut pre_existing = std::env::temp_dir();
|
||||||
|
pre_existing.push("test");
|
||||||
|
let _file = std::fs::File::create(&pre_existing).unwrap();
|
||||||
|
let sock = TempPathBuf::new("test");
|
||||||
|
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||||
|
assert!(sock_path.exists());
|
||||||
|
std::mem::drop(sock);
|
||||||
|
assert!(!sock_path.exists());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn temp_path_buf_for_socket() {
|
||||||
|
let sock = TempPathBuf::with_random("test");
|
||||||
|
// Save path for testing after drop
|
||||||
|
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||||
|
// create the actual socket
|
||||||
|
let _ = listen(&sock, SymmetricalJson::<String>::default).await;
|
||||||
|
assert!(sock_path.exists());
|
||||||
|
std::mem::drop(sock);
|
||||||
|
assert!(!sock_path.exists());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::Transport;
|
use super::Transport;
|
||||||
@@ -393,4 +657,24 @@ mod tests {
|
|||||||
assert_matches!(transport.next().await, None);
|
assert_matches!(transport.next().await, None);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(all(unix, feature = "unix"))]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn uds() -> io::Result<()> {
|
||||||
|
use super::unix;
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
let sock = unix::TempPathBuf::with_random("uds");
|
||||||
|
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut transport = listener.next().await.unwrap().unwrap();
|
||||||
|
let message = transport.next().await.unwrap().unwrap();
|
||||||
|
transport.send(message).await.unwrap();
|
||||||
|
});
|
||||||
|
let mut transport = unix::connect(&sock, SymmetricalJson::<String>::default).await?;
|
||||||
|
transport.send(String::from("test")).await?;
|
||||||
|
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||||
|
assert_matches!(transport.next().await, None);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
error: unused `Connect` that must be used
|
error: unused `tarpc::serde_transport::tcp::Connect` that must be used
|
||||||
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
||||||
|
|
|
|
||||||
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn serde() -> anyhow::Result<()> {
|
async fn serde_tcp() -> anyhow::Result<()> {
|
||||||
use tarpc::serde_transport;
|
use tarpc::serde_transport;
|
||||||
use tokio_serde::formats::Json;
|
use tokio_serde::formats::Json;
|
||||||
|
|
||||||
@@ -136,6 +136,37 @@ async fn serde() -> anyhow::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(all(feature = "serde-transport", feature = "unix", unix))]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn serde_uds() -> anyhow::Result<()> {
|
||||||
|
use tarpc::serde_transport;
|
||||||
|
use tokio_serde::formats::Json;
|
||||||
|
|
||||||
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
|
let sock = tarpc::serde_transport::unix::TempPathBuf::with_random("uds");
|
||||||
|
let transport = tarpc::serde_transport::unix::listen(&sock, Json::default).await?;
|
||||||
|
tokio::spawn(
|
||||||
|
transport
|
||||||
|
.take(1)
|
||||||
|
.filter_map(|r| async { r.ok() })
|
||||||
|
.map(BaseChannel::with_defaults)
|
||||||
|
.execute(Server.serve()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
|
||||||
|
let client = ServiceClient::new(client::Config::default(), transport).spawn();
|
||||||
|
|
||||||
|
// Save results using socket so we can clean the socket even if our test assertions fail
|
||||||
|
let res1 = client.add(context::current(), 1, 2).await;
|
||||||
|
let res2 = client.hey(context::current(), "Tim".to_string()).await;
|
||||||
|
|
||||||
|
assert_matches!(res1, Ok(3));
|
||||||
|
assert_matches!(res2, Ok(ref s) if s == "Hey, Tim.");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn concurrent() -> anyhow::Result<()> {
|
async fn concurrent() -> anyhow::Result<()> {
|
||||||
let _ = tracing_subscriber::fmt::try_init();
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|||||||
Reference in New Issue
Block a user