mirror of
https://github.com/OMGeeky/tarpc.git
synced 2025-12-26 17:02:32 +01:00
Merge branch 'master' of ssh://git.adam-wright.net:10022/shaladdle/adamrpc-rs
This commit is contained in:
261
src/lib.rs
261
src/lib.rs
@@ -11,9 +11,9 @@ use std::io::{self, Read};
|
||||
use std::convert;
|
||||
use std::collections::HashMap;
|
||||
use std::net::{
|
||||
self,
|
||||
TcpListener,
|
||||
TcpStream,
|
||||
SocketAddr,
|
||||
};
|
||||
use std::sync::{
|
||||
self,
|
||||
@@ -22,13 +22,13 @@ use std::sync::{
|
||||
};
|
||||
use std::sync::mpsc::{
|
||||
channel,
|
||||
sync_channel,
|
||||
Sender,
|
||||
SyncSender,
|
||||
Receiver,
|
||||
TryRecvError,
|
||||
};
|
||||
use std::thread::{
|
||||
self,
|
||||
JoinHandle,
|
||||
};
|
||||
use std::time;
|
||||
use std::thread;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
@@ -62,102 +62,132 @@ impl<T> convert::From<sync::mpsc::SendError<T>> for Error {
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
pub fn handle_conn<F, Request, Reply>(mut stream: TcpStream, f: Arc<F>) -> Result<()>
|
||||
where Request: fmt::Debug + serde::de::Deserialize,
|
||||
Reply: fmt::Debug + serde::ser::Serialize,
|
||||
F: Serve<Request, Reply>
|
||||
pub fn handle_conn<F, Request, Reply>(stream: TcpStream, f: F) -> Result<()>
|
||||
where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize,
|
||||
Reply: 'static + fmt::Debug + serde::ser::Serialize,
|
||||
F: 'static + Clone + Serve<Request, Reply>
|
||||
{
|
||||
let read_stream = try!(stream.try_clone());
|
||||
let mut de = serde_json::Deserializer::new(read_stream.bytes());
|
||||
let stream = Arc::new(Mutex::new(stream));
|
||||
loop {
|
||||
println!("read");
|
||||
let request_packet: Packet<Request> = try!(Packet::deserialize(&mut de));
|
||||
match request_packet {
|
||||
Packet::Shutdown => break,
|
||||
Packet::Shutdown => {
|
||||
let stream = stream.clone();
|
||||
let mut my_stream = stream.lock().unwrap();
|
||||
try!(serde_json::to_writer(&mut *my_stream, &request_packet));
|
||||
break;
|
||||
},
|
||||
Packet::Message(id, message) => {
|
||||
let reply = try!(f.serve(&message));
|
||||
let reply_packet = Packet::Message(id, reply);
|
||||
println!("write");
|
||||
try!(serde_json::to_writer(&mut stream, &reply_packet));
|
||||
let f = f.clone();
|
||||
let arc_stream = stream.clone();
|
||||
thread::spawn(move || {
|
||||
let reply = f.serve(&message).unwrap();
|
||||
let reply_packet = Packet::Message(id, reply);
|
||||
let mut my_stream = arc_stream.lock().unwrap();
|
||||
serde_json::to_writer(&mut *my_stream, &reply_packet).unwrap();
|
||||
});
|
||||
},
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn serve<F, Request, Reply>(listener: TcpListener, f: Arc<F>) -> Error
|
||||
where Request: fmt::Debug + serde::de::Deserialize,
|
||||
Reply: fmt::Debug + serde::ser::Serialize,
|
||||
F: 'static + Serve<Request, Reply>,
|
||||
{
|
||||
for conn in listener.incoming() {
|
||||
let conn = match conn {
|
||||
Err(err) => return convert::From::from(err),
|
||||
Ok(c) => c,
|
||||
};
|
||||
let f = f.clone();
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = handle_conn(conn, f) {
|
||||
println!("error handling connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
|
||||
pub struct Shutdown {
|
||||
tx: Sender<()>,
|
||||
join_handle: JoinHandle<()>,
|
||||
addr: SocketAddr,
|
||||
}
|
||||
|
||||
|
||||
impl Shutdown {
|
||||
pub fn wait(self) {
|
||||
self.join_handle.join().unwrap();
|
||||
}
|
||||
Error::Impossible
|
||||
|
||||
pub fn shutdown(self) {
|
||||
self.tx.send(()).expect(&line!().to_string());
|
||||
TcpStream::connect(&self.addr).unwrap();
|
||||
self.join_handle.join().expect(&line!().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serve_async<F, Request, Reply>(addr: &SocketAddr, f: F) -> io::Result<Shutdown>
|
||||
where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + fmt::Debug + serde::ser::Serialize,
|
||||
Reply: 'static + fmt::Debug + serde::ser::Serialize,
|
||||
F: 'static + Clone + Serve<Request, Reply>,
|
||||
{
|
||||
let listener = try!(TcpListener::bind(addr));
|
||||
let (die_tx, die_rx) = channel();
|
||||
let join_handle = thread::spawn(move || {
|
||||
for conn in listener.incoming() {
|
||||
match die_rx.try_recv() {
|
||||
Ok(_) => break,
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
println!("serve: sender disconnected ");
|
||||
break;
|
||||
},
|
||||
_ => (),
|
||||
}
|
||||
let conn = match conn {
|
||||
Err(err) => {
|
||||
println!("I couldn't unwrap the connection :( {:?}", err);
|
||||
return;
|
||||
},
|
||||
Ok(c) => c,
|
||||
};
|
||||
let f = f.clone();
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = handle_conn(conn, f) {
|
||||
println!("error handling connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
Ok(Shutdown{
|
||||
tx: die_tx,
|
||||
join_handle: join_handle,
|
||||
addr: addr.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub trait Serve<Request, Reply>: Send + Sync {
|
||||
fn serve(&self, request: &Request) -> io::Result<Reply>;
|
||||
}
|
||||
|
||||
impl<Request, Reply, S> Serve<Request, Reply> for Arc<S>
|
||||
where S: Serve<Request, Reply>
|
||||
{
|
||||
fn serve(&self, request: &Request) -> io::Result<Reply> {
|
||||
S::serve(self, request)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
enum Packet<T> {
|
||||
Message(u64, T),
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
struct Handle<T> {
|
||||
id: u64,
|
||||
sender: Sender<T>,
|
||||
}
|
||||
|
||||
enum ReceiverMessage<Reply> {
|
||||
Handle(Handle<Reply>),
|
||||
Packet(Packet<Reply>),
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
fn receiver<Reply>(messages: Receiver<ReceiverMessage<Reply>>) -> Result<()> {
|
||||
let mut ready_handles: HashMap<u64, Handle<Reply>> = HashMap::new();
|
||||
for message in messages.into_iter() {
|
||||
match message {
|
||||
ReceiverMessage::Handle(handle) => {
|
||||
ready_handles.insert(handle.id, handle);
|
||||
},
|
||||
ReceiverMessage::Packet(Packet::Shutdown) => break,
|
||||
ReceiverMessage::Packet(Packet::Message(id, message)) => {
|
||||
let handle = ready_handles.remove(&id).unwrap();
|
||||
try!(handle.sender.send(message));
|
||||
}
|
||||
ReceiverMessage::Shutdown => break,
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn reader<Reply>(stream: TcpStream, tx: SyncSender<ReceiverMessage<Reply>>)
|
||||
fn reader<Reply>(
|
||||
stream: TcpStream,
|
||||
requests: Arc<Mutex<HashMap<u64, Sender<Reply>>>>)
|
||||
where Reply: serde::Deserialize
|
||||
{
|
||||
use serde_json::Error::SyntaxError;
|
||||
use serde_json::ErrorCode::EOFWhileParsingValue;
|
||||
let mut de = serde_json::Deserializer::new(stream.bytes());
|
||||
loop {
|
||||
match Packet::deserialize(&mut de) {
|
||||
Ok(packet) =>{
|
||||
println!("send!");
|
||||
tx.send(ReceiverMessage::Packet(packet)).unwrap();
|
||||
Ok(Packet::Message(id, reply)) => {
|
||||
let mut requests = requests.lock().unwrap();
|
||||
let reply_tx = requests.remove(&id).unwrap();
|
||||
reply_tx.send(reply).unwrap();
|
||||
},
|
||||
Ok(Packet::Shutdown) => {
|
||||
break;
|
||||
}
|
||||
// TODO: This shutdown logic is janky.. What's the right way to do this?
|
||||
Err(SyntaxError(EOFWhileParsingValue, _, _)) => break,
|
||||
Err(err) => panic!("unexpected error while parsing!: {:?}", err),
|
||||
}
|
||||
}
|
||||
@@ -169,14 +199,14 @@ fn increment(cur_id: &mut u64) -> u64 {
|
||||
id
|
||||
}
|
||||
|
||||
struct SyncedClientState<Reply> {
|
||||
struct SyncedClientState {
|
||||
next_id: u64,
|
||||
stream: TcpStream,
|
||||
handles_tx: SyncSender<ReceiverMessage<Reply>>,
|
||||
}
|
||||
|
||||
pub struct Client<Reply> {
|
||||
synced_state: Mutex<SyncedClientState<Reply>>,
|
||||
synced_state: Mutex<SyncedClientState>,
|
||||
requests: Arc<Mutex<HashMap<u64, Sender<Reply>>>>,
|
||||
reader_guard: thread::JoinHandle<()>,
|
||||
}
|
||||
|
||||
@@ -184,19 +214,18 @@ impl<Reply> Client<Reply>
|
||||
where Reply: serde::de::Deserialize + Send + 'static
|
||||
{
|
||||
pub fn new(stream: TcpStream) -> Result<Self> {
|
||||
let (handles_tx, receiver_rx) = sync_channel(0);
|
||||
let read_stream = try!(stream.try_clone());
|
||||
try!(read_stream.set_read_timeout(Some(time::Duration::from_millis(50))));
|
||||
let reader_handles_tx = handles_tx.clone();
|
||||
let guard = thread::spawn(move || reader(read_stream, reader_handles_tx));
|
||||
thread::spawn(move || receiver(receiver_rx));
|
||||
let requests = Arc::new(Mutex::new(HashMap::new()));
|
||||
let reader_stream = try!(stream.try_clone());
|
||||
let reader_requests = requests.clone();
|
||||
let reader_guard =
|
||||
thread::spawn(move || reader(reader_stream, reader_requests));
|
||||
Ok(Client{
|
||||
synced_state: Mutex::new(SyncedClientState{
|
||||
next_id: 0,
|
||||
stream: stream,
|
||||
handles_tx: handles_tx,
|
||||
}),
|
||||
reader_guard: guard,
|
||||
requests: requests,
|
||||
reader_guard: reader_guard,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -206,20 +235,22 @@ impl<Reply> Client<Reply>
|
||||
let (tx, rx) = channel();
|
||||
let mut state = self.synced_state.lock().unwrap();
|
||||
let id = increment(&mut state.next_id);
|
||||
try!(state.handles_tx.send(ReceiverMessage::Handle(Handle{
|
||||
id: id,
|
||||
sender: tx,
|
||||
})));
|
||||
{
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
requests.insert(id, tx);
|
||||
}
|
||||
let packet = Packet::Message(id, request.clone());
|
||||
try!(serde_json::to_writer(&mut state.stream, &packet));
|
||||
drop(state);
|
||||
Ok(rx.recv().unwrap())
|
||||
}
|
||||
|
||||
pub fn join<Request: serde::Serialize>(self) -> Result<()> {
|
||||
let mut state = self.synced_state.lock().unwrap();
|
||||
let packet: Packet<Request> = Packet::Shutdown;
|
||||
try!(serde_json::to_writer(&mut state.stream, &packet));
|
||||
try!(state.stream.shutdown(net::Shutdown::Both));
|
||||
pub fn disconnect<Request: serde::Serialize>(self) -> Result<()> {
|
||||
{
|
||||
let mut state = self.synced_state.lock().unwrap();
|
||||
let packet: Packet<Request> = Packet::Shutdown;
|
||||
try!(serde_json::to_writer(&mut state.stream, &packet));
|
||||
}
|
||||
self.reader_guard.join().unwrap();
|
||||
Ok(())
|
||||
}
|
||||
@@ -227,22 +258,23 @@ impl<Reply> Client<Reply>
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use serde;
|
||||
use super::*;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::net::{TcpStream, TcpListener, SocketAddr};
|
||||
use std::net::{TcpStream, TcpListener, SocketAddr, ToSocketAddrs};
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, Mutex, Barrier};
|
||||
use std::sync::mpsc::channel;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::thread;
|
||||
|
||||
const port: AtomicUsize = AtomicUsize::new(10000);
|
||||
|
||||
fn pair() -> (TcpStream, TcpListener) {
|
||||
fn next_addr() -> SocketAddr {
|
||||
let addr = format!("127.0.0.1:{}", port.fetch_add(1, Ordering::SeqCst));
|
||||
println!("what the fuck {}", &addr);
|
||||
// Do this one first so that we don't get connection refused :)
|
||||
let listener = TcpListener::bind(&*addr).unwrap();
|
||||
(TcpStream::connect(&*addr).unwrap(), listener)
|
||||
addr.to_socket_addrs().unwrap().next().unwrap()
|
||||
//ToSocketAddrs::to_socket_addrs(addr.as_ref()).unwrap().next().unwrap()
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
|
||||
@@ -278,19 +310,39 @@ mod test {
|
||||
}
|
||||
}
|
||||
|
||||
fn wtf<F, Request, Reply>(server: F) -> (SocketAddr, Shutdown)
|
||||
where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + fmt::Debug + serde::ser::Serialize,
|
||||
Reply: 'static + fmt::Debug + Send + serde::ser::Serialize,
|
||||
F: 'static + Clone + Serve<Request, Reply>
|
||||
{
|
||||
let mut addr;
|
||||
let mut shutdown;
|
||||
while let &Err(_) = {shutdown = serve_async({addr = next_addr(); &addr}, server.clone()); &shutdown} { }
|
||||
(addr, shutdown.unwrap())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle() {
|
||||
let server = Arc::new(Server::new());
|
||||
let (addr, shutdown) = wtf(server.clone());
|
||||
let client_stream = TcpStream::connect(&addr).unwrap();
|
||||
let client: Client<Reply> = Client::new(client_stream).expect(&line!().to_string());
|
||||
client.disconnect::<Request>();
|
||||
shutdown.shutdown();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
let (client_stream, server_streams) = pair();
|
||||
let server = Arc::new(Server::new());
|
||||
let thread_server = server.clone();
|
||||
let guard = thread::spawn(move || serve(server_streams, thread_server));
|
||||
let (addr, shutdown) = wtf(server.clone());
|
||||
let client_stream = TcpStream::connect(&addr).unwrap();
|
||||
let client = Client::new(client_stream).unwrap();
|
||||
assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap());
|
||||
assert_eq!(1, server.count());
|
||||
assert_eq!(Reply::Increment(1), client.rpc(&Request::Increment).unwrap());
|
||||
assert_eq!(2, server.count());
|
||||
client.join::<Request>().unwrap();
|
||||
guard.join();
|
||||
client.disconnect::<Request>().unwrap();
|
||||
shutdown.shutdown();
|
||||
}
|
||||
|
||||
struct BarrierServer {
|
||||
@@ -318,10 +370,9 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_concurrent() {
|
||||
let (client_stream, server_streams) = pair();
|
||||
let server = Arc::new(BarrierServer::new(10));
|
||||
let thread_server = server.clone();
|
||||
let guard = thread::spawn(move || serve(server_streams, thread_server));
|
||||
let (addr, shutdown) = wtf(server.clone());
|
||||
let client_stream = TcpStream::connect(&addr).unwrap();
|
||||
let client: Arc<Client<Reply>> = Arc::new(Client::new(client_stream).unwrap());
|
||||
let mut join_handles = vec![];
|
||||
for _ in 0..10 {
|
||||
@@ -336,7 +387,7 @@ mod test {
|
||||
Err(_) => panic!("couldn't unwrap arc"),
|
||||
Ok(c) => c,
|
||||
};
|
||||
client.join::<Request>().unwrap();
|
||||
guard.join();
|
||||
client.disconnect::<Request>().unwrap();
|
||||
shutdown.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user