mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-01 09:03:48 +01:00
LOL Shutdown works LOL
This commit is contained in:
97
src/lib.rs
97
src/lib.rs
@@ -13,6 +13,7 @@ use std::collections::HashMap;
|
||||
use std::net::{
|
||||
TcpListener,
|
||||
TcpStream,
|
||||
SocketAddr,
|
||||
};
|
||||
use std::sync::{
|
||||
self,
|
||||
@@ -22,8 +23,13 @@ use std::sync::{
|
||||
use std::sync::mpsc::{
|
||||
channel,
|
||||
Sender,
|
||||
Receiver,
|
||||
TryRecvError,
|
||||
};
|
||||
use std::thread::{
|
||||
self,
|
||||
JoinHandle,
|
||||
};
|
||||
use std::thread;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
@@ -83,24 +89,63 @@ pub fn handle_conn<F, Request, Reply>(mut stream: TcpStream, f: Arc<F>) -> Resul
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn serve<F, Request, Reply>(listener: TcpListener, f: Arc<F>) -> Error
|
||||
|
||||
pub struct Shutdown {
|
||||
tx: Sender<()>,
|
||||
join_handle: JoinHandle<()>,
|
||||
addr: SocketAddr,
|
||||
}
|
||||
|
||||
|
||||
impl Shutdown {
|
||||
fn wait(self) {
|
||||
self.join_handle.join().unwrap();
|
||||
}
|
||||
|
||||
fn shutdown(self) {
|
||||
self.tx.send(()).expect(&line!().to_string());
|
||||
TcpStream::connect(&self.addr).unwrap();
|
||||
self.join_handle.join().expect(&line!().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serve_async<F, Request, Reply>(addr: &SocketAddr, f: Arc<F>) -> io::Result<Shutdown>
|
||||
where Request: fmt::Debug + serde::de::Deserialize + fmt::Debug + serde::ser::Serialize,
|
||||
Reply: fmt::Debug + serde::ser::Serialize,
|
||||
F: 'static + Serve<Request, Reply>,
|
||||
{
|
||||
for conn in listener.incoming() {
|
||||
let conn = match conn {
|
||||
Err(err) => return convert::From::from(err),
|
||||
Ok(c) => c,
|
||||
};
|
||||
let f = f.clone();
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = handle_conn(conn, f) {
|
||||
println!("error handling connection: {:?}", err);
|
||||
let listener = try!(TcpListener::bind(addr));
|
||||
let (die_tx, die_rx) = channel();
|
||||
let join_handle = thread::spawn(move || {
|
||||
for conn in listener.incoming() {
|
||||
match die_rx.try_recv() {
|
||||
Ok(_) => break,
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
println!("serve: sender disconnected ");
|
||||
break;
|
||||
},
|
||||
_ => (),
|
||||
}
|
||||
});
|
||||
}
|
||||
Error::Impossible
|
||||
let conn = match conn {
|
||||
Err(err) => {
|
||||
println!("I couldn't unwrap the connection :( {:?}", err);
|
||||
return;
|
||||
},
|
||||
Ok(c) => c,
|
||||
};
|
||||
let f = f.clone();
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = handle_conn(conn, f) {
|
||||
println!("error handling connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
Ok(Shutdown{
|
||||
tx: die_tx,
|
||||
join_handle: join_handle,
|
||||
addr: addr.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub trait Serve<Request, Reply>: Send + Sync {
|
||||
@@ -187,7 +232,7 @@ impl<Reply> Client<Reply>
|
||||
Ok(rx.recv().unwrap())
|
||||
}
|
||||
|
||||
pub fn join<Request: serde::Serialize>(self) -> Result<()> {
|
||||
pub fn disconnect<Request: serde::Serialize>(self) -> Result<()> {
|
||||
{
|
||||
let mut state = self.synced_state.lock().unwrap();
|
||||
let packet: Packet<Request> = Packet::Shutdown;
|
||||
@@ -202,14 +247,21 @@ impl<Reply> Client<Reply>
|
||||
mod test {
|
||||
use super::*;
|
||||
use std::io;
|
||||
use std::net::{TcpStream, TcpListener, SocketAddr};
|
||||
use std::net::{TcpStream, TcpListener, SocketAddr, ToSocketAddrs};
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, Mutex, Barrier};
|
||||
use std::sync::mpsc::channel;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::thread;
|
||||
|
||||
const port: AtomicUsize = AtomicUsize::new(10000);
|
||||
|
||||
fn next_addr() -> SocketAddr {
|
||||
let addr = format!("127.0.0.1:{}", port.fetch_add(1, Ordering::SeqCst));
|
||||
addr.to_socket_addrs().unwrap().next().unwrap()
|
||||
//ToSocketAddrs::to_socket_addrs(addr.as_ref()).unwrap().next().unwrap()
|
||||
}
|
||||
|
||||
fn pair() -> (TcpStream, TcpListener) {
|
||||
let addr = format!("127.0.0.1:{}", port.fetch_add(1, Ordering::SeqCst));
|
||||
println!("what the fuck {}", &addr);
|
||||
@@ -251,6 +303,18 @@ mod test {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle() {
|
||||
let addr = next_addr();
|
||||
let server = Arc::new(Server::new());
|
||||
let srv_shutdown = serve_async(&addr, server).unwrap();
|
||||
let client_stream = TcpStream::connect(&addr).unwrap();
|
||||
let client: Client<Reply> = Client::new(client_stream).expect(&line!().to_string());
|
||||
client.disconnect::<Request>();
|
||||
srv_shutdown.shutdown();
|
||||
}
|
||||
|
||||
/*
|
||||
#[test]
|
||||
fn test() {
|
||||
let (client_stream, server_streams) = pair();
|
||||
@@ -268,7 +332,6 @@ mod test {
|
||||
guard.join();
|
||||
}
|
||||
|
||||
/*
|
||||
struct BarrierServer {
|
||||
barrier: Barrier,
|
||||
inner: Server,
|
||||
|
||||
Reference in New Issue
Block a user