This commit is contained in:
2026-01-26 20:23:39 +01:00
parent 2ed6cd3c06
commit 2b2f8622bf
10 changed files with 374 additions and 63 deletions

View File

@@ -14,8 +14,10 @@ tarpc.workspace = true
futures.workspace = true
chrono.workspace = true
google-drive3.workspace = true
uuid = { version = "1.7.0", features = ["v4", "serde"] }
yup-oauth2 = "12.1.2"
thiserror = "2.0.18"
[dependencies.gdriver-common]
path = "../gdriver-common"

View File

@@ -1,10 +1,10 @@
use futures::{future, prelude::*};
use std::net::SocketAddr;
use tarpc::{
context,
server::{self, incoming::Incoming, Channel},
tokio_serde::formats::Json,
};
use tokio::net::unix::SocketAddr;
mod prelude;
use crate::prelude::*;

View File

@@ -13,8 +13,8 @@ impl World for HelloServer {
pub(super) async fn main() -> Result<()> {
println!("Hello, world!");
let config = &CONFIGURATION;
let server_addr = (config.ip, config.port);
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
let server_addr = (&config.socket_path);
let mut listener = tarpc::serde_transport::unix::listen(&server_addr, Json::default).await?;
println!("Listening");
listener.config_mut().max_frame_length(usize::MAX);
@@ -23,11 +23,21 @@ pub(super) async fn main() -> Result<()> {
.filter_map(|r| future::ready(r.ok()))
.map(server::BaseChannel::with_defaults)
// Limit channels to 1 per IP.
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
.max_channels_per_key(1, |t| {
t.transport()
.peer_addr()
.unwrap()
.as_pathname()
.unwrap()
.to_str()
.unwrap()
.to_string()
})
// serve is generated by the service attribute. It takes as input any type implementing
// the generated World trait.
.map(|channel| {
let server = HelloServer(channel.transport().peer_addr().unwrap());
let c = channel.transport().peer_addr().unwrap();
let server = HelloServer(c);
channel.execute(server.serve()).for_each(spawn)
})
// Max 10 channels.

View File

@@ -10,14 +10,17 @@ use tarpc::context::Context;
use tokio::sync::Mutex;
use tracing::instrument;
use super::*;
use crate::drive::Drive;
use crate::prelude::*;
use crate::service::sample_sym::spawn_twoway;
use tarpc::server::{BaseChannel, Channel};
use super::*;
#[derive(Clone, Debug)]
struct GdriverServer {
socket_address: SocketAddr,
drive: Arc<Mutex<Drive>>,
client: GDriverClientClient,
}
impl GDriverService for GdriverServer {
async fn is_online(self, _: Context) -> bool {
@@ -168,8 +171,15 @@ impl GDriverService for GdriverServer {
self,
_: ::tarpc::context::Context,
req: BackendActionRequest,
) -> std::result::Result<String, BackendActionError> {
println!("You are connected from {}", self.socket_address);
) -> std::result::Result<AsyncResponse<String>, BackendActionError> {
println!(
"You are connected from {}",
self.socket_address
.as_pathname()
.map(|p| p.to_str())
.flatten()
.unwrap_or("unknown")
);
match req {
BackendActionRequest::ShutdownGracefully => {
@@ -182,21 +192,34 @@ impl GDriverService for GdriverServer {
let drive = &self.drive;
print_sample_tracking_state(drive).await;
Ok(String::from("OK"))
Ok(AsyncResponse::Immediate(String::from("OK")))
}
BackendActionRequest::Ping => {
println!("Ping request received");
Ok(String::from("Pong"))
Ok(AsyncResponse::Immediate(String::from("Pong")))
}
BackendActionRequest::RunLong => {
println!("RunLong request received");
long_running_task(&self.drive).await;
Ok(String::from("OK"))
Ok(AsyncResponse::Immediate(String::from("OK")))
}
BackendActionRequest::StartLong => {
println!("StartLong request received");
tokio::spawn(async move { long_running_task(&self.drive).await });
Ok(String::from("OK"))
let drive = self.drive.clone();
let client = self.client.clone();
let task_id = TaskId(uuid::Uuid::new_v4().to_string());
let task_id_clone = task_id.clone();
tokio::spawn(async move {
long_running_task(&drive).await;
let _ = client
.report_task_result(
tarpc::context::current(),
task_id_clone,
TaskResult::Success("OK".to_string()),
)
.await;
});
Ok(AsyncResponse::Pending(task_id))
}
}
}
@@ -219,30 +242,121 @@ pub async fn start() -> Result<()> {
let drive = Drive::new().await?;
let drive = Arc::new(Mutex::new(drive));
let server_addr = (config.ip, config.port);
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
let mut listener =
tarpc::serde_transport::unix::listen(&config.socket_path, Json::default).await?;
listener.config_mut().max_frame_length(usize::MAX);
println!("Listening");
listener
// Ignore accept errors.
.filter_map(|r| future::ready(r.ok()))
.map(server::BaseChannel::with_defaults)
// // Limit channels to 1 per IP.
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
// serve is generated by the service attribute. It takes as input any type implementing
// the generated World trait.
.map(|channel| {
let c = channel.transport().peer_addr().unwrap();
.map(|transport| {
let peer_addr = transport.peer_addr().unwrap();
let (server_channel, client_channel) = spawn_twoway(transport);
let client =
GDriverClientClient::new(tarpc::client::Config::default(), client_channel).spawn();
let server = GdriverServer {
socket_address: c,
socket_address: peer_addr,
drive: drive.clone(),
client,
};
channel.execute(server.serve()).for_each(spawn)
BaseChannel::with_defaults(server_channel).execute(server.serve())
})
// Max 10 channels.
.buffer_unordered(10)
// .buffer_unordered(10)
.for_each(|_| async {})
.await;
Ok(())
}
mod sample_sym {
use futures::future::{AbortHandle, Abortable};
use futures::{Sink, TryFutureExt};
use futures::{SinkExt, Stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use std::io;
use tarpc::transport::channel::{ChannelError, UnboundedChannel};
use tracing::{debug, warn};
#[derive(Debug, Serialize, Deserialize)]
pub enum TwoWayMessage<Req, Resp> {
Request(tarpc::ClientMessage<Req>),
Response(tarpc::Response<Resp>),
}
#[derive(thiserror::Error, Debug)]
pub enum ChannelOrIoError {
#[error("{0}")]
ChannelError(#[from] ChannelError),
#[error("{0}")]
IoError(#[from] io::Error),
}
/// Returns two transports that multiplex over the given transport.
/// The first transport can be used by a server: it receives requests and sends back responses.
/// The second transport can be used by a client: it sends requests and receives back responses.
pub fn spawn_twoway<Req1, Resp1, Req2, Resp2, T>(
transport: T,
) -> (
UnboundedChannel<tarpc::ClientMessage<Req1>, tarpc::Response<Resp1>>,
UnboundedChannel<tarpc::Response<Resp2>, tarpc::ClientMessage<Req2>>,
)
where
T: Stream<Item = Result<TwoWayMessage<Req1, Resp2>, io::Error>>,
T: Sink<TwoWayMessage<Req2, Resp1>, Error = io::Error>,
T: Unpin + Send + 'static,
Req1: Send + 'static + Serialize + for<'de> Deserialize<'de>,
Resp1: Send + 'static + Serialize + for<'de> Deserialize<'de>,
Req2: Send + 'static + Serialize + for<'de> Deserialize<'de>,
Resp2: Send + 'static + Serialize + for<'de> Deserialize<'de>,
{
let (server, server_ret) = tarpc::transport::channel::unbounded();
let (client, client_ret) = tarpc::transport::channel::unbounded();
let (mut server_sink, server_stream) = server.split();
let (mut client_sink, client_stream) = client.split();
let (transport_sink, mut transport_stream) = transport.split();
let (abort_handle, abort_registration) = AbortHandle::new_pair();
// Task for inbound message handling
tokio::spawn(async move {
let e: Result<(), ChannelOrIoError> = async move {
while let Some(msg) = transport_stream.next().await {
match msg? {
TwoWayMessage::Request(req) => server_sink.send(req).await?,
TwoWayMessage::Response(resp) => client_sink.send(resp).await?,
}
}
Ok(())
}
.await;
match e {
Ok(()) => debug!("transport_stream done"),
Err(e) => warn!("Error in inbound multiplexing: {}", e),
}
abort_handle.abort();
});
let abortable_sink_channel = Abortable::new(
futures::stream::select(
server_stream.map_ok(TwoWayMessage::Response),
client_stream.map_ok(TwoWayMessage::Request),
)
.map_err(ChannelOrIoError::ChannelError),
abort_registration,
);
// Task for outbound message handling
tokio::spawn(
abortable_sink_channel
.forward(transport_sink.sink_map_err(ChannelOrIoError::IoError))
.inspect_ok(|_| debug!("transport_sink done"))
.inspect_err(|e| warn!("Error in outbound multiplexing: {}", e)),
);
(server_ret, client_ret)
}
}