diff --git a/gdriver-backend/src/service.rs b/gdriver-backend/src/service.rs index 5dae856..ca62803 100644 --- a/gdriver-backend/src/service.rs +++ b/gdriver-backend/src/service.rs @@ -9,6 +9,7 @@ use std::{path::PathBuf, sync::Arc, thread}; use tarpc::context::Context; use tokio::sync::Mutex; use tracing::instrument; +use futures::prelude::*; use super::*; use crate::drive::Drive; @@ -200,8 +201,15 @@ impl GDriverService for GdriverServer { } BackendActionRequest::RunLong => { println!("RunLong request received"); - long_running_task(&self.drive).await; - Ok(AsyncResponse::Immediate(String::from("OK"))) + let drive = self.drive.clone(); + let client = self.client.clone(); + let task_id = TaskId(uuid::Uuid::new_v4().to_string()); + let task_id_clone = task_id.clone(); + tokio::spawn(async move { + long_running_task(&drive).await; + let _ = client.report_task_result(tarpc::context::current(), task_id_clone, TaskResult::Success("OK".to_string())).await; + }); + Ok(AsyncResponse::Pending(task_id)) } BackendActionRequest::StartLong => { println!("StartLong request received"); @@ -211,13 +219,7 @@ impl GDriverService for GdriverServer { let task_id_clone = task_id.clone(); tokio::spawn(async move { long_running_task(&drive).await; - let _ = client - .report_task_result( - tarpc::context::current(), - task_id_clone, - TaskResult::Success("OK".to_string()), - ) - .await; + let _ = client.report_task_result(tarpc::context::current(), task_id_clone, TaskResult::Success("OK".to_string())).await; }); Ok(AsyncResponse::Pending(task_id)) } @@ -251,21 +253,25 @@ pub async fn start() -> Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) - .map(|transport| { - let peer_addr = transport.peer_addr().unwrap(); - let (server_channel, client_channel) = spawn_twoway(transport); - let client = - GDriverClientClient::new(tarpc::client::Config::default(), client_channel).spawn(); - let server = GdriverServer { - socket_address: peer_addr, - drive: drive.clone(), - client, - }; - BaseChannel::with_defaults(server_channel).execute(server.serve()) + .for_each_concurrent(10, |transport| { + let drive = drive.clone(); + async move { + let peer_addr = transport.peer_addr().unwrap(); + let (server_channel, client_channel) = spawn_twoway(transport); + let client = GDriverClientClient::new(tarpc::client::Config::default(), client_channel).spawn(); + let server = GdriverServer { + socket_address: peer_addr, + drive, + client, + }; + BaseChannel::with_defaults(server_channel) + .execute(server.serve()) + .for_each(|f| async move { + tokio::spawn(f); + }) + .await + } }) - // Max 10 channels. - // .buffer_unordered(10) - .for_each(|_| async {}) .await; Ok(()) } @@ -273,11 +279,11 @@ pub async fn start() -> Result<()> { mod sample_sym { use futures::future::{AbortHandle, Abortable}; use futures::{Sink, TryFutureExt}; - use futures::{SinkExt, Stream, StreamExt, TryStreamExt}; - use serde::{Deserialize, Serialize}; + use futures::{ SinkExt, Stream, StreamExt, TryStreamExt}; use std::io; use tarpc::transport::channel::{ChannelError, UnboundedChannel}; use tracing::{debug, warn}; + use serde::{Serialize, Deserialize}; #[derive(Debug, Serialize, Deserialize)] pub enum TwoWayMessage { diff --git a/gdriver-client/src/service.rs b/gdriver-client/src/service.rs index 16a16c6..2772a56 100644 --- a/gdriver-client/src/service.rs +++ b/gdriver-client/src/service.rs @@ -21,8 +21,11 @@ pub(crate) struct ClientHandler { impl GDriverClient for ClientHandler { async fn report_task_result(self, _: Context, id: TaskId, result: TaskResult) { println!("Received task result for task {}: {:?}", id.0, result); - let mut tasks = self.pending_tasks.lock().unwrap(); - if let Some(sender) = tasks.remove(&id.0) { + let sender = { + let mut tasks = self.pending_tasks.lock().unwrap(); + tasks.remove(&id.0) + }; + if let Some(sender) = sender { let _ = sender.send(result).await; } } @@ -52,6 +55,8 @@ pub(crate) async fn start_with_client( } } let start = time::SystemTime::now(); + + // Use AsyncResponse for RunLong as well let hello = client .do_something2(tarpc::context::current(), BackendActionRequest::RunLong) .await; @@ -61,9 +66,32 @@ pub(crate) async fn start_with_client( .as_secs(); match hello { - Ok(hello) => println!("Run Long returned after {} seconds: {:?}", seconds, hello), - Err(e) => println!(":( {:?}", (e)), + Ok(Ok(AsyncResponse::Pending(task_id))) => { + println!("Run Long returned pending task: {:?}", task_id); + let (tx, mut rx) = mpsc::channel(1); + handler + .pending_tasks + .lock() + .unwrap() + .insert(task_id.0.clone(), tx); + + if let Some(result) = rx.recv().await { + println!( + "Received async result after {} seconds: {:?}", + (time::SystemTime::now().duration_since(start)) + .unwrap() + .as_secs(), + result + ); + } + } + Ok(Ok(AsyncResponse::Immediate(res))) => { + println!("Run Long returned immediate: {:?}", res) + } + Ok(Err(e)) => println!("Backend Error: {:?}", e), + Err(e) => println!("RPC Error: {:?}", e), } + let start = time::SystemTime::now(); let hello = client .do_something2(tarpc::context::current(), BackendActionRequest::StartLong) @@ -111,9 +139,13 @@ pub async fn create_client( e })?; - let (client_channel, server_channel) = spawn_twoway(transport); + // spawn_twoway returns (server_channel, client_channel) + // server_channel: UnboundedChannel, Response> -> For Server + // client_channel: UnboundedChannel, ClientMessage> -> For Client + let (server_channel, client_channel) = spawn_twoway(transport); - let client = GDriverServiceClient::new(client::Config::default(), server_channel).spawn(); + // GDriverServiceClient needs a channel to send requests (Req2) and receive responses (Resp2) + let client = GDriverServiceClient::new(client::Config::default(), client_channel).spawn(); let handler = ClientHandler { pending_tasks: Arc::new(Mutex::new(HashMap::new())), @@ -121,9 +153,12 @@ pub async fn create_client( let handler_clone = handler.clone(); tokio::spawn(async move { - BaseChannel::with_defaults(client_channel) + // BaseChannel needs a channel to receive requests (Req1) and send responses (Resp1) + BaseChannel::with_defaults(server_channel) .execute(handler_clone.serve()) - .for_each(|_| async {}) + .for_each(|f| async move { + tokio::spawn(f); + }) .await; });