diff --git a/Cargo.lock b/Cargo.lock index 25b44b3..71b2494 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -421,8 +421,10 @@ dependencies = [ "google-drive3", "serde", "tarpc", + "thiserror", "tokio", "tracing", + "uuid", "yup-oauth2", ] @@ -1801,6 +1803,18 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "uuid" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f" +dependencies = [ + "getrandom 0.3.4", + "js-sys", + "serde_core", + "wasm-bindgen", +] + [[package]] name = "valuable" version = "0.1.1" diff --git a/gdriver-backend/Cargo.toml b/gdriver-backend/Cargo.toml index 6f5dc87..308c9f2 100644 --- a/gdriver-backend/Cargo.toml +++ b/gdriver-backend/Cargo.toml @@ -14,8 +14,10 @@ tarpc.workspace = true futures.workspace = true chrono.workspace = true google-drive3.workspace = true +uuid = { version = "1.7.0", features = ["v4", "serde"] } yup-oauth2 = "12.1.2" +thiserror = "2.0.18" [dependencies.gdriver-common] path = "../gdriver-common" diff --git a/gdriver-backend/src/main.rs b/gdriver-backend/src/main.rs index 195c7fe..28f2dc6 100644 --- a/gdriver-backend/src/main.rs +++ b/gdriver-backend/src/main.rs @@ -1,10 +1,10 @@ use futures::{future, prelude::*}; -use std::net::SocketAddr; use tarpc::{ context, server::{self, incoming::Incoming, Channel}, tokio_serde::formats::Json, }; +use tokio::net::unix::SocketAddr; mod prelude; use crate::prelude::*; diff --git a/gdriver-backend/src/sample.rs b/gdriver-backend/src/sample.rs index 955b62d..2656df4 100644 --- a/gdriver-backend/src/sample.rs +++ b/gdriver-backend/src/sample.rs @@ -13,8 +13,8 @@ impl World for HelloServer { pub(super) async fn main() -> Result<()> { println!("Hello, world!"); let config = &CONFIGURATION; - let server_addr = (config.ip, config.port); - let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?; + let server_addr = (&config.socket_path); + let mut listener = tarpc::serde_transport::unix::listen(&server_addr, Json::default).await?; println!("Listening"); listener.config_mut().max_frame_length(usize::MAX); @@ -23,11 +23,21 @@ pub(super) async fn main() -> Result<()> { .filter_map(|r| future::ready(r.ok())) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip()) + .max_channels_per_key(1, |t| { + t.transport() + .peer_addr() + .unwrap() + .as_pathname() + .unwrap() + .to_str() + .unwrap() + .to_string() + }) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = HelloServer(channel.transport().peer_addr().unwrap()); + let c = channel.transport().peer_addr().unwrap(); + let server = HelloServer(c); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/gdriver-backend/src/service.rs b/gdriver-backend/src/service.rs index 9193f72..5dae856 100644 --- a/gdriver-backend/src/service.rs +++ b/gdriver-backend/src/service.rs @@ -10,14 +10,17 @@ use tarpc::context::Context; use tokio::sync::Mutex; use tracing::instrument; +use super::*; use crate::drive::Drive; use crate::prelude::*; +use crate::service::sample_sym::spawn_twoway; +use tarpc::server::{BaseChannel, Channel}; -use super::*; #[derive(Clone, Debug)] struct GdriverServer { socket_address: SocketAddr, drive: Arc>, + client: GDriverClientClient, } impl GDriverService for GdriverServer { async fn is_online(self, _: Context) -> bool { @@ -168,8 +171,15 @@ impl GDriverService for GdriverServer { self, _: ::tarpc::context::Context, req: BackendActionRequest, - ) -> std::result::Result { - println!("You are connected from {}", self.socket_address); + ) -> std::result::Result, BackendActionError> { + println!( + "You are connected from {}", + self.socket_address + .as_pathname() + .map(|p| p.to_str()) + .flatten() + .unwrap_or("unknown") + ); match req { BackendActionRequest::ShutdownGracefully => { @@ -182,21 +192,34 @@ impl GDriverService for GdriverServer { let drive = &self.drive; print_sample_tracking_state(drive).await; - Ok(String::from("OK")) + Ok(AsyncResponse::Immediate(String::from("OK"))) } BackendActionRequest::Ping => { println!("Ping request received"); - Ok(String::from("Pong")) + Ok(AsyncResponse::Immediate(String::from("Pong"))) } BackendActionRequest::RunLong => { println!("RunLong request received"); long_running_task(&self.drive).await; - Ok(String::from("OK")) + Ok(AsyncResponse::Immediate(String::from("OK"))) } BackendActionRequest::StartLong => { println!("StartLong request received"); - tokio::spawn(async move { long_running_task(&self.drive).await }); - Ok(String::from("OK")) + let drive = self.drive.clone(); + let client = self.client.clone(); + let task_id = TaskId(uuid::Uuid::new_v4().to_string()); + let task_id_clone = task_id.clone(); + tokio::spawn(async move { + long_running_task(&drive).await; + let _ = client + .report_task_result( + tarpc::context::current(), + task_id_clone, + TaskResult::Success("OK".to_string()), + ) + .await; + }); + Ok(AsyncResponse::Pending(task_id)) } } } @@ -219,30 +242,121 @@ pub async fn start() -> Result<()> { let drive = Drive::new().await?; let drive = Arc::new(Mutex::new(drive)); - let server_addr = (config.ip, config.port); - let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?; + let mut listener = + tarpc::serde_transport::unix::listen(&config.socket_path, Json::default).await?; listener.config_mut().max_frame_length(usize::MAX); println!("Listening"); + listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) - .map(server::BaseChannel::with_defaults) - // // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip()) - // serve is generated by the service attribute. It takes as input any type implementing - // the generated World trait. - .map(|channel| { - let c = channel.transport().peer_addr().unwrap(); + .map(|transport| { + let peer_addr = transport.peer_addr().unwrap(); + let (server_channel, client_channel) = spawn_twoway(transport); + let client = + GDriverClientClient::new(tarpc::client::Config::default(), client_channel).spawn(); let server = GdriverServer { - socket_address: c, + socket_address: peer_addr, drive: drive.clone(), + client, }; - channel.execute(server.serve()).for_each(spawn) + BaseChannel::with_defaults(server_channel).execute(server.serve()) }) // Max 10 channels. - .buffer_unordered(10) + // .buffer_unordered(10) .for_each(|_| async {}) .await; Ok(()) } + +mod sample_sym { + use futures::future::{AbortHandle, Abortable}; + use futures::{Sink, TryFutureExt}; + use futures::{SinkExt, Stream, StreamExt, TryStreamExt}; + use serde::{Deserialize, Serialize}; + use std::io; + use tarpc::transport::channel::{ChannelError, UnboundedChannel}; + use tracing::{debug, warn}; + + #[derive(Debug, Serialize, Deserialize)] + pub enum TwoWayMessage { + Request(tarpc::ClientMessage), + Response(tarpc::Response), + } + + #[derive(thiserror::Error, Debug)] + pub enum ChannelOrIoError { + #[error("{0}")] + ChannelError(#[from] ChannelError), + #[error("{0}")] + IoError(#[from] io::Error), + } + + /// Returns two transports that multiplex over the given transport. + /// The first transport can be used by a server: it receives requests and sends back responses. + /// The second transport can be used by a client: it sends requests and receives back responses. + pub fn spawn_twoway( + transport: T, + ) -> ( + UnboundedChannel, tarpc::Response>, + UnboundedChannel, tarpc::ClientMessage>, + ) + where + T: Stream, io::Error>>, + T: Sink, Error = io::Error>, + T: Unpin + Send + 'static, + Req1: Send + 'static + Serialize + for<'de> Deserialize<'de>, + Resp1: Send + 'static + Serialize + for<'de> Deserialize<'de>, + Req2: Send + 'static + Serialize + for<'de> Deserialize<'de>, + Resp2: Send + 'static + Serialize + for<'de> Deserialize<'de>, + { + let (server, server_ret) = tarpc::transport::channel::unbounded(); + let (client, client_ret) = tarpc::transport::channel::unbounded(); + let (mut server_sink, server_stream) = server.split(); + let (mut client_sink, client_stream) = client.split(); + let (transport_sink, mut transport_stream) = transport.split(); + + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + + // Task for inbound message handling + tokio::spawn(async move { + let e: Result<(), ChannelOrIoError> = async move { + while let Some(msg) = transport_stream.next().await { + match msg? { + TwoWayMessage::Request(req) => server_sink.send(req).await?, + TwoWayMessage::Response(resp) => client_sink.send(resp).await?, + } + } + Ok(()) + } + .await; + + match e { + Ok(()) => debug!("transport_stream done"), + Err(e) => warn!("Error in inbound multiplexing: {}", e), + } + + abort_handle.abort(); + }); + + let abortable_sink_channel = Abortable::new( + futures::stream::select( + server_stream.map_ok(TwoWayMessage::Response), + client_stream.map_ok(TwoWayMessage::Request), + ) + .map_err(ChannelOrIoError::ChannelError), + abort_registration, + ); + + // Task for outbound message handling + tokio::spawn( + abortable_sink_channel + .forward(transport_sink.sink_map_err(ChannelOrIoError::IoError)) + .inspect_ok(|_| debug!("transport_sink done")) + .inspect_err(|e| warn!("Error in outbound multiplexing: {}", e)), + ); + + (server_ret, client_ret) + } +} diff --git a/gdriver-client/src/main.rs b/gdriver-client/src/main.rs index 70f5420..a790211 100644 --- a/gdriver-client/src/main.rs +++ b/gdriver-client/src/main.rs @@ -1,4 +1,4 @@ -use fuser::{mount2, MountOption}; +use fuser::{mount2, spawn_mount2, MountOption}; use std::path::Path; use std::{error::Error, net::IpAddr, result::Result as StdResult}; @@ -18,13 +18,15 @@ async fn main() -> Result<()> { println!("Hello, world!"); let config = &CONFIGURATION; println!("Config: {:?}", **config); - let client: GDriverServiceClient = create_client(config.ip, config.port).await?; - let fs = DriveFilesystem::new(client); - let mountpoint = Path::new("/tmp/gdriver"); - std::fs::create_dir_all(mountpoint)?; - mount2(fs, mountpoint, &[MountOption::RW, MountOption::AutoUnmount])?; + let (client, _handler) = create_client(&config.socket_path).await?; - // service::start().await?; + // let fs = DriveFilesystem::new(client); + // let mountpoint = Path::new("/tmp/gdriver"); + // std::fs::create_dir_all(mountpoint)?; + // // let mount_handle = + // mount2(fs, mountpoint, &[MountOption::RW, MountOption::AutoUnmount])?; + // mount_handle.guard.join().unwrap()?; + service::start_with_client(client, _handler).await?; Ok(()) } pub mod prelude; diff --git a/gdriver-client/src/sample.rs b/gdriver-client/src/sample.rs index e0d0840..5df641a 100644 --- a/gdriver-client/src/sample.rs +++ b/gdriver-client/src/sample.rs @@ -6,7 +6,7 @@ pub async fn start() -> Result<()> { let name = "test1".to_string(); let config = &CONFIGURATION; - let client: WorldClient = create_client(config.ip, config.port).await?; + let client: WorldClient = create_client(&config.socket_path).await?; let hello = client .hello(tarpc::context::current(), name.to_string()) @@ -18,14 +18,10 @@ pub async fn start() -> Result<()> { } Ok(()) } -pub async fn create_client(ip: IpAddr, port: u16) -> Result { - let server_addr = (ip, port); - let transport = tarpc::serde_transport::tcp::connect(&server_addr, Json::default) +pub async fn create_client(socket_path: impl AsRef) -> Result { + let transport = tarpc::serde_transport::unix::connect(socket_path, Json::default) .await - .map_err(|e| { - println!("Could not connect"); - e - })?; + .inspect_err(|_| println!("Could not connect"))?; let var_name = WorldClient::new(client::Config::default(), transport); let client = var_name.spawn(); Ok(client) diff --git a/gdriver-client/src/service.rs b/gdriver-client/src/service.rs index 528b489..16a16c6 100644 --- a/gdriver-client/src/service.rs +++ b/gdriver-client/src/service.rs @@ -1,15 +1,46 @@ use std::time; -use gdriver_common::ipc::gdriver_service::{BackendActionRequest, GDriverServiceClient}; +use futures::prelude::*; +use gdriver_common::ipc::gdriver_service::{ + AsyncResponse, BackendActionRequest, GDriverClient, GDriverServiceClient, TaskId, TaskResult, +}; +use sample_sym::spawn_twoway; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use tarpc::context::Context; +use tarpc::server::{BaseChannel, Channel}; +use tokio::sync::mpsc; use super::*; +#[derive(Clone)] +pub(crate) struct ClientHandler { + pending_tasks: Arc>>>, +} + +impl GDriverClient for ClientHandler { + async fn report_task_result(self, _: Context, id: TaskId, result: TaskResult) { + println!("Received task result for task {}: {:?}", id.0, result); + let mut tasks = self.pending_tasks.lock().unwrap(); + if let Some(sender) = tasks.remove(&id.0) { + let _ = sender.send(result).await; + } + } +} + pub async fn start() -> Result<()> { println!("Hello, world!"); let config = &CONFIGURATION; println!("Config: {:?}", **config); - let client: GDriverServiceClient = create_client(config.ip, config.port).await?; + let (client, handler) = create_client(&config.socket_path).await?; + start_with_client(client, handler).await +} + +pub(crate) async fn start_with_client( + client: GDriverServiceClient, + handler: ClientHandler, +) -> Result<()> { let hello = client .do_something2(tarpc::context::current(), BackendActionRequest::Ping) .await; @@ -42,20 +73,150 @@ pub async fn start() -> Result<()> { .as_secs(); match hello { - Ok(hello) => println!("Start Long returned after {} seconds: {:?}", seconds, hello), - Err(e) => println!(":( {:?}", (e)), + Ok(Ok(AsyncResponse::Pending(task_id))) => { + println!("Start Long returned pending task: {:?}", task_id); + let (tx, mut rx) = mpsc::channel(1); + handler + .pending_tasks + .lock() + .unwrap() + .insert(task_id.0.clone(), tx); + + if let Some(result) = rx.recv().await { + println!( + "Received async result after {} seconds: {:?}", + (time::SystemTime::now().duration_since(start)) + .unwrap() + .as_secs(), + result + ); + } + } + Ok(Ok(AsyncResponse::Immediate(res))) => { + println!("Start Long returned immediate: {:?}", res) + } + Ok(Err(e)) => println!("Backend Error: {:?}", e), + Err(e) => println!("RPC Error: {:?}", e), } Ok(()) } -pub async fn create_client(ip: IpAddr, port: u16) -> Result { - let server_addr = (ip, port); - let transport = tarpc::serde_transport::tcp::connect(&server_addr, Json::default) + +pub async fn create_client( + socket_path: impl AsRef, +) -> Result<(GDriverServiceClient, ClientHandler)> { + let transport = tarpc::serde_transport::unix::connect(socket_path, Json::default) .await .map_err(|e| { println!("Could not connect"); e })?; - let service = GDriverServiceClient::new(client::Config::default(), transport); - let client = service.spawn(); - Ok(client) + + let (client_channel, server_channel) = spawn_twoway(transport); + + let client = GDriverServiceClient::new(client::Config::default(), server_channel).spawn(); + + let handler = ClientHandler { + pending_tasks: Arc::new(Mutex::new(HashMap::new())), + }; + + let handler_clone = handler.clone(); + tokio::spawn(async move { + BaseChannel::with_defaults(client_channel) + .execute(handler_clone.serve()) + .for_each(|_| async {}) + .await; + }); + + Ok((client, handler)) +} + +mod sample_sym { + use futures::future::{AbortHandle, Abortable}; + use futures::{Sink, TryFutureExt}; + use futures::{SinkExt, Stream, StreamExt, TryStreamExt}; + use serde::{Deserialize, Serialize}; + use std::io; + use tarpc::transport::channel::{ChannelError, UnboundedChannel}; + use tracing::{debug, warn}; + + #[derive(Debug, Serialize, Deserialize)] + pub enum TwoWayMessage { + Request(tarpc::ClientMessage), + Response(tarpc::Response), + } + + #[derive(thiserror::Error, Debug)] + pub enum ChannelOrIoError { + #[error("{0}")] + ChannelError(#[from] ChannelError), + #[error("{0}")] + IoError(#[from] io::Error), + } + + /// Returns two transports that multiplex over the given transport. + /// The first transport can be used by a server: it receives requests and sends back responses. + /// The second transport can be used by a client: it sends requests and receives back responses. + pub fn spawn_twoway( + transport: T, + ) -> ( + UnboundedChannel, tarpc::Response>, + UnboundedChannel, tarpc::ClientMessage>, + ) + where + T: Stream, io::Error>>, + T: Sink, Error = io::Error>, + T: Unpin + Send + 'static, + Req1: Send + 'static + Serialize + for<'de> Deserialize<'de>, + Resp1: Send + 'static + Serialize + for<'de> Deserialize<'de>, + Req2: Send + 'static + Serialize + for<'de> Deserialize<'de>, + Resp2: Send + 'static + Serialize + for<'de> Deserialize<'de>, + { + let (server, server_ret) = tarpc::transport::channel::unbounded(); + let (client, client_ret) = tarpc::transport::channel::unbounded(); + let (mut server_sink, server_stream) = server.split(); + let (mut client_sink, client_stream) = client.split(); + let (transport_sink, mut transport_stream) = transport.split(); + + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + + // Task for inbound message handling + tokio::spawn(async move { + let e: Result<(), ChannelOrIoError> = async move { + while let Some(msg) = transport_stream.next().await { + match msg? { + TwoWayMessage::Request(req) => server_sink.send(req).await?, + TwoWayMessage::Response(resp) => client_sink.send(resp).await?, + } + } + Ok(()) + } + .await; + + match e { + Ok(()) => debug!("transport_stream done"), + Err(e) => warn!("Error in inbound multiplexing: {}", e), + } + + abort_handle.abort(); + }); + + let abortable_sink_channel = Abortable::new( + futures::stream::select( + server_stream.map_ok(TwoWayMessage::Response), + client_stream.map_ok(TwoWayMessage::Request), + ) + .map_err(ChannelOrIoError::ChannelError), + abort_registration, + ); + + // Task for outbound message handling + tokio::spawn( + abortable_sink_channel + .forward(transport_sink.sink_map_err(ChannelOrIoError::IoError)) + .inspect_ok(|_| debug!("transport_sink done")) + .inspect_err(|e| warn!("Error in outbound multiplexing: {}", e)), + ); + + (server_ret, client_ret) + } } diff --git a/gdriver-common/src/config.rs b/gdriver-common/src/config.rs index f7fd590..450bc1b 100644 --- a/gdriver-common/src/config.rs +++ b/gdriver-common/src/config.rs @@ -1,14 +1,10 @@ use super::*; use crate::prelude::*; use confique::{Config, Layer}; -use std::net::{IpAddr, Ipv6Addr}; -const IP_DEFAULT: IpAddr = IpAddr::V6(Ipv6Addr::LOCALHOST); #[derive(Debug, Serialize, Deserialize, Config, Clone)] pub struct Configuration { - #[config(default = 33333)] - pub port: u16, - // #[config(default = Test)] - pub ip: std::net::IpAddr, + #[config(default = "\0gdriver2_v3.sock")] + pub socket_path: String, } pub fn load_config() -> Result { Ok(add_default_locations(Config::builder()).load()?) @@ -19,12 +15,7 @@ pub fn load_config_with_path(path: &Path) -> Result { fn add_default_locations( builder: confique::Builder, ) -> confique::Builder { - type P = ::Layer; - let prebuilt = P { - ip: Some(IP_DEFAULT), - ..P::empty() - }; - builder.env().file("config.toml").preloaded(prebuilt) + builder.env().file("config.toml") } use lazy_static::lazy_static; diff --git a/gdriver-common/src/ipc/gdriver_service.rs b/gdriver-common/src/ipc/gdriver_service.rs index 7bae237..e079cf5 100644 --- a/gdriver-common/src/ipc/gdriver_service.rs +++ b/gdriver-common/src/ipc/gdriver_service.rs @@ -29,8 +29,29 @@ pub trait GDriverService { /// Returns true if the file was had remote changes and was updated async fn update_changes_for_file(id: DriveId) -> StdResult; async fn update_changes() -> StdResult<(), UpdateChangesError>; - async fn do_something2(req: BackendActionRequest) -> StdResult; + async fn do_something2(req: BackendActionRequest) -> StdResult, BackendActionError>; } + +#[tarpc::service] +pub trait GDriverClient { + async fn report_task_result(id: TaskId, result: TaskResult); +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct TaskId(pub String); + +#[derive(Debug, Serialize, Deserialize)] +pub enum AsyncResponse { + Immediate(T), + Pending(TaskId), +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum TaskResult { + Success(String), + Error(String), +} + #[derive(Debug, Serialize, Deserialize)] pub enum BackendActionRequest { ShutdownGracefully,