new communication structure

This commit is contained in:
2026-01-26 21:10:26 +01:00
parent 2b2f8622bf
commit 09c314463b
2 changed files with 74 additions and 33 deletions

View File

@@ -9,6 +9,7 @@ use std::{path::PathBuf, sync::Arc, thread};
use tarpc::context::Context; use tarpc::context::Context;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::instrument; use tracing::instrument;
use futures::prelude::*;
use super::*; use super::*;
use crate::drive::Drive; use crate::drive::Drive;
@@ -200,8 +201,15 @@ impl GDriverService for GdriverServer {
} }
BackendActionRequest::RunLong => { BackendActionRequest::RunLong => {
println!("RunLong request received"); println!("RunLong request received");
long_running_task(&self.drive).await; let drive = self.drive.clone();
Ok(AsyncResponse::Immediate(String::from("OK"))) let client = self.client.clone();
let task_id = TaskId(uuid::Uuid::new_v4().to_string());
let task_id_clone = task_id.clone();
tokio::spawn(async move {
long_running_task(&drive).await;
let _ = client.report_task_result(tarpc::context::current(), task_id_clone, TaskResult::Success("OK".to_string())).await;
});
Ok(AsyncResponse::Pending(task_id))
} }
BackendActionRequest::StartLong => { BackendActionRequest::StartLong => {
println!("StartLong request received"); println!("StartLong request received");
@@ -211,13 +219,7 @@ impl GDriverService for GdriverServer {
let task_id_clone = task_id.clone(); let task_id_clone = task_id.clone();
tokio::spawn(async move { tokio::spawn(async move {
long_running_task(&drive).await; long_running_task(&drive).await;
let _ = client let _ = client.report_task_result(tarpc::context::current(), task_id_clone, TaskResult::Success("OK".to_string())).await;
.report_task_result(
tarpc::context::current(),
task_id_clone,
TaskResult::Success("OK".to_string()),
)
.await;
}); });
Ok(AsyncResponse::Pending(task_id)) Ok(AsyncResponse::Pending(task_id))
} }
@@ -251,21 +253,25 @@ pub async fn start() -> Result<()> {
listener listener
// Ignore accept errors. // Ignore accept errors.
.filter_map(|r| future::ready(r.ok())) .filter_map(|r| future::ready(r.ok()))
.map(|transport| { .for_each_concurrent(10, |transport| {
let peer_addr = transport.peer_addr().unwrap(); let drive = drive.clone();
let (server_channel, client_channel) = spawn_twoway(transport); async move {
let client = let peer_addr = transport.peer_addr().unwrap();
GDriverClientClient::new(tarpc::client::Config::default(), client_channel).spawn(); let (server_channel, client_channel) = spawn_twoway(transport);
let server = GdriverServer { let client = GDriverClientClient::new(tarpc::client::Config::default(), client_channel).spawn();
socket_address: peer_addr, let server = GdriverServer {
drive: drive.clone(), socket_address: peer_addr,
client, drive,
}; client,
BaseChannel::with_defaults(server_channel).execute(server.serve()) };
BaseChannel::with_defaults(server_channel)
.execute(server.serve())
.for_each(|f| async move {
tokio::spawn(f);
})
.await
}
}) })
// Max 10 channels.
// .buffer_unordered(10)
.for_each(|_| async {})
.await; .await;
Ok(()) Ok(())
} }
@@ -273,11 +279,11 @@ pub async fn start() -> Result<()> {
mod sample_sym { mod sample_sym {
use futures::future::{AbortHandle, Abortable}; use futures::future::{AbortHandle, Abortable};
use futures::{Sink, TryFutureExt}; use futures::{Sink, TryFutureExt};
use futures::{SinkExt, Stream, StreamExt, TryStreamExt}; use futures::{ SinkExt, Stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use std::io; use std::io;
use tarpc::transport::channel::{ChannelError, UnboundedChannel}; use tarpc::transport::channel::{ChannelError, UnboundedChannel};
use tracing::{debug, warn}; use tracing::{debug, warn};
use serde::{Serialize, Deserialize};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum TwoWayMessage<Req, Resp> { pub enum TwoWayMessage<Req, Resp> {

View File

@@ -21,8 +21,11 @@ pub(crate) struct ClientHandler {
impl GDriverClient for ClientHandler { impl GDriverClient for ClientHandler {
async fn report_task_result(self, _: Context, id: TaskId, result: TaskResult) { async fn report_task_result(self, _: Context, id: TaskId, result: TaskResult) {
println!("Received task result for task {}: {:?}", id.0, result); println!("Received task result for task {}: {:?}", id.0, result);
let mut tasks = self.pending_tasks.lock().unwrap(); let sender = {
if let Some(sender) = tasks.remove(&id.0) { let mut tasks = self.pending_tasks.lock().unwrap();
tasks.remove(&id.0)
};
if let Some(sender) = sender {
let _ = sender.send(result).await; let _ = sender.send(result).await;
} }
} }
@@ -52,6 +55,8 @@ pub(crate) async fn start_with_client(
} }
} }
let start = time::SystemTime::now(); let start = time::SystemTime::now();
// Use AsyncResponse for RunLong as well
let hello = client let hello = client
.do_something2(tarpc::context::current(), BackendActionRequest::RunLong) .do_something2(tarpc::context::current(), BackendActionRequest::RunLong)
.await; .await;
@@ -61,9 +66,32 @@ pub(crate) async fn start_with_client(
.as_secs(); .as_secs();
match hello { match hello {
Ok(hello) => println!("Run Long returned after {} seconds: {:?}", seconds, hello), Ok(Ok(AsyncResponse::Pending(task_id))) => {
Err(e) => println!(":( {:?}", (e)), println!("Run Long returned pending task: {:?}", task_id);
let (tx, mut rx) = mpsc::channel(1);
handler
.pending_tasks
.lock()
.unwrap()
.insert(task_id.0.clone(), tx);
if let Some(result) = rx.recv().await {
println!(
"Received async result after {} seconds: {:?}",
(time::SystemTime::now().duration_since(start))
.unwrap()
.as_secs(),
result
);
}
}
Ok(Ok(AsyncResponse::Immediate(res))) => {
println!("Run Long returned immediate: {:?}", res)
}
Ok(Err(e)) => println!("Backend Error: {:?}", e),
Err(e) => println!("RPC Error: {:?}", e),
} }
let start = time::SystemTime::now(); let start = time::SystemTime::now();
let hello = client let hello = client
.do_something2(tarpc::context::current(), BackendActionRequest::StartLong) .do_something2(tarpc::context::current(), BackendActionRequest::StartLong)
@@ -111,9 +139,13 @@ pub async fn create_client(
e e
})?; })?;
let (client_channel, server_channel) = spawn_twoway(transport); // spawn_twoway returns (server_channel, client_channel)
// server_channel: UnboundedChannel<ClientMessage<Req1>, Response<Resp1>> -> For Server
// client_channel: UnboundedChannel<Response<Resp2>, ClientMessage<Req2>> -> For Client
let (server_channel, client_channel) = spawn_twoway(transport);
let client = GDriverServiceClient::new(client::Config::default(), server_channel).spawn(); // GDriverServiceClient needs a channel to send requests (Req2) and receive responses (Resp2)
let client = GDriverServiceClient::new(client::Config::default(), client_channel).spawn();
let handler = ClientHandler { let handler = ClientHandler {
pending_tasks: Arc::new(Mutex::new(HashMap::new())), pending_tasks: Arc::new(Mutex::new(HashMap::new())),
@@ -121,9 +153,12 @@ pub async fn create_client(
let handler_clone = handler.clone(); let handler_clone = handler.clone();
tokio::spawn(async move { tokio::spawn(async move {
BaseChannel::with_defaults(client_channel) // BaseChannel needs a channel to receive requests (Req1) and send responses (Resp1)
BaseChannel::with_defaults(server_channel)
.execute(handler_clone.serve()) .execute(handler_clone.serve())
.for_each(|_| async {}) .for_each(|f| async move {
tokio::spawn(f);
})
.await; .await;
}); });