diff --git a/src/lib.rs b/src/lib.rs index 0571d95..496ab40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,9 +11,9 @@ use std::io::{self, Read}; use std::convert; use std::collections::HashMap; use std::net::{ - self, TcpListener, TcpStream, + SocketAddr, }; use std::sync::{ self, @@ -22,13 +22,13 @@ use std::sync::{ }; use std::sync::mpsc::{ channel, - sync_channel, Sender, - SyncSender, - Receiver, + TryRecvError, +}; +use std::thread::{ + self, + JoinHandle, }; -use std::time; -use std::thread; #[derive(Debug)] pub enum Error { @@ -62,102 +62,132 @@ impl convert::From> for Error { pub type Result = std::result::Result; -pub fn handle_conn(mut stream: TcpStream, f: Arc) -> Result<()> - where Request: fmt::Debug + serde::de::Deserialize, - Reply: fmt::Debug + serde::ser::Serialize, - F: Serve +pub fn handle_conn(stream: TcpStream, f: F) -> Result<()> + where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, + Reply: 'static + fmt::Debug + serde::ser::Serialize, + F: 'static + Clone + Serve { let read_stream = try!(stream.try_clone()); let mut de = serde_json::Deserializer::new(read_stream.bytes()); + let stream = Arc::new(Mutex::new(stream)); loop { - println!("read"); let request_packet: Packet = try!(Packet::deserialize(&mut de)); match request_packet { - Packet::Shutdown => break, + Packet::Shutdown => { + let stream = stream.clone(); + let mut my_stream = stream.lock().unwrap(); + try!(serde_json::to_writer(&mut *my_stream, &request_packet)); + break; + }, Packet::Message(id, message) => { - let reply = try!(f.serve(&message)); - let reply_packet = Packet::Message(id, reply); - println!("write"); - try!(serde_json::to_writer(&mut stream, &reply_packet)); + let f = f.clone(); + let arc_stream = stream.clone(); + thread::spawn(move || { + let reply = f.serve(&message).unwrap(); + let reply_packet = Packet::Message(id, reply); + let mut my_stream = arc_stream.lock().unwrap(); + serde_json::to_writer(&mut *my_stream, &reply_packet).unwrap(); + }); }, } } Ok(()) } -pub fn serve(listener: TcpListener, f: Arc) -> Error - where Request: fmt::Debug + serde::de::Deserialize, - Reply: fmt::Debug + serde::ser::Serialize, - F: 'static + Serve, -{ - for conn in listener.incoming() { - let conn = match conn { - Err(err) => return convert::From::from(err), - Ok(c) => c, - }; - let f = f.clone(); - thread::spawn(move || { - if let Err(err) = handle_conn(conn, f) { - println!("error handling connection: {:?}", err); - } - }); + +pub struct Shutdown { + tx: Sender<()>, + join_handle: JoinHandle<()>, + addr: SocketAddr, +} + + +impl Shutdown { + pub fn wait(self) { + self.join_handle.join().unwrap(); } - Error::Impossible + + pub fn shutdown(self) { + self.tx.send(()).expect(&line!().to_string()); + TcpStream::connect(&self.addr).unwrap(); + self.join_handle.join().expect(&line!().to_string()); + } +} + +pub fn serve_async(addr: &SocketAddr, f: F) -> io::Result + where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + fmt::Debug + serde::ser::Serialize, + Reply: 'static + fmt::Debug + serde::ser::Serialize, + F: 'static + Clone + Serve, +{ + let listener = try!(TcpListener::bind(addr)); + let (die_tx, die_rx) = channel(); + let join_handle = thread::spawn(move || { + for conn in listener.incoming() { + match die_rx.try_recv() { + Ok(_) => break, + Err(TryRecvError::Disconnected) => { + println!("serve: sender disconnected "); + break; + }, + _ => (), + } + let conn = match conn { + Err(err) => { + println!("I couldn't unwrap the connection :( {:?}", err); + return; + }, + Ok(c) => c, + }; + let f = f.clone(); + thread::spawn(move || { + if let Err(err) = handle_conn(conn, f) { + println!("error handling connection: {:?}", err); + } + }); + } + }); + Ok(Shutdown{ + tx: die_tx, + join_handle: join_handle, + addr: addr.clone(), + }) } pub trait Serve: Send + Sync { fn serve(&self, request: &Request) -> io::Result; } +impl Serve for Arc + where S: Serve +{ + fn serve(&self, request: &Request) -> io::Result { + S::serve(self, request) + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] enum Packet { Message(u64, T), Shutdown, } -struct Handle { - id: u64, - sender: Sender, -} - -enum ReceiverMessage { - Handle(Handle), - Packet(Packet), - Shutdown, -} - -fn receiver(messages: Receiver>) -> Result<()> { - let mut ready_handles: HashMap> = HashMap::new(); - for message in messages.into_iter() { - match message { - ReceiverMessage::Handle(handle) => { - ready_handles.insert(handle.id, handle); - }, - ReceiverMessage::Packet(Packet::Shutdown) => break, - ReceiverMessage::Packet(Packet::Message(id, message)) => { - let handle = ready_handles.remove(&id).unwrap(); - try!(handle.sender.send(message)); - } - ReceiverMessage::Shutdown => break, - } - } - Ok(()) -} - -fn reader(stream: TcpStream, tx: SyncSender>) +fn reader( + stream: TcpStream, + requests: Arc>>>) where Reply: serde::Deserialize { - use serde_json::Error::SyntaxError; - use serde_json::ErrorCode::EOFWhileParsingValue; let mut de = serde_json::Deserializer::new(stream.bytes()); loop { match Packet::deserialize(&mut de) { - Ok(packet) =>{ - println!("send!"); - tx.send(ReceiverMessage::Packet(packet)).unwrap(); + Ok(Packet::Message(id, reply)) => { + let mut requests = requests.lock().unwrap(); + let reply_tx = requests.remove(&id).unwrap(); + reply_tx.send(reply).unwrap(); }, + Ok(Packet::Shutdown) => { + break; + } // TODO: This shutdown logic is janky.. What's the right way to do this? - Err(SyntaxError(EOFWhileParsingValue, _, _)) => break, Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } @@ -169,14 +199,14 @@ fn increment(cur_id: &mut u64) -> u64 { id } -struct SyncedClientState { +struct SyncedClientState { next_id: u64, stream: TcpStream, - handles_tx: SyncSender>, } pub struct Client { - synced_state: Mutex>, + synced_state: Mutex, + requests: Arc>>>, reader_guard: thread::JoinHandle<()>, } @@ -184,19 +214,18 @@ impl Client where Reply: serde::de::Deserialize + Send + 'static { pub fn new(stream: TcpStream) -> Result { - let (handles_tx, receiver_rx) = sync_channel(0); - let read_stream = try!(stream.try_clone()); - try!(read_stream.set_read_timeout(Some(time::Duration::from_millis(50)))); - let reader_handles_tx = handles_tx.clone(); - let guard = thread::spawn(move || reader(read_stream, reader_handles_tx)); - thread::spawn(move || receiver(receiver_rx)); + let requests = Arc::new(Mutex::new(HashMap::new())); + let reader_stream = try!(stream.try_clone()); + let reader_requests = requests.clone(); + let reader_guard = + thread::spawn(move || reader(reader_stream, reader_requests)); Ok(Client{ synced_state: Mutex::new(SyncedClientState{ next_id: 0, stream: stream, - handles_tx: handles_tx, }), - reader_guard: guard, + requests: requests, + reader_guard: reader_guard, }) } @@ -206,20 +235,22 @@ impl Client let (tx, rx) = channel(); let mut state = self.synced_state.lock().unwrap(); let id = increment(&mut state.next_id); - try!(state.handles_tx.send(ReceiverMessage::Handle(Handle{ - id: id, - sender: tx, - }))); + { + let mut requests = self.requests.lock().unwrap(); + requests.insert(id, tx); + } let packet = Packet::Message(id, request.clone()); try!(serde_json::to_writer(&mut state.stream, &packet)); + drop(state); Ok(rx.recv().unwrap()) } - pub fn join(self) -> Result<()> { - let mut state = self.synced_state.lock().unwrap(); - let packet: Packet = Packet::Shutdown; - try!(serde_json::to_writer(&mut state.stream, &packet)); - try!(state.stream.shutdown(net::Shutdown::Both)); + pub fn disconnect(self) -> Result<()> { + { + let mut state = self.synced_state.lock().unwrap(); + let packet: Packet = Packet::Shutdown; + try!(serde_json::to_writer(&mut state.stream, &packet)); + } self.reader_guard.join().unwrap(); Ok(()) } @@ -227,22 +258,23 @@ impl Client #[cfg(test)] mod test { + use serde; use super::*; + use std::fmt; use std::io; - use std::net::{TcpStream, TcpListener, SocketAddr}; + use std::net::{TcpStream, TcpListener, SocketAddr, ToSocketAddrs}; use std::str::FromStr; use std::sync::{Arc, Mutex, Barrier}; + use std::sync::mpsc::channel; use std::sync::atomic::{AtomicUsize, Ordering}; use std::thread; const port: AtomicUsize = AtomicUsize::new(10000); - fn pair() -> (TcpStream, TcpListener) { + fn next_addr() -> SocketAddr { let addr = format!("127.0.0.1:{}", port.fetch_add(1, Ordering::SeqCst)); - println!("what the fuck {}", &addr); - // Do this one first so that we don't get connection refused :) - let listener = TcpListener::bind(&*addr).unwrap(); - (TcpStream::connect(&*addr).unwrap(), listener) + addr.to_socket_addrs().unwrap().next().unwrap() + //ToSocketAddrs::to_socket_addrs(addr.as_ref()).unwrap().next().unwrap() } #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] @@ -278,19 +310,39 @@ mod test { } } + fn wtf(server: F) -> (SocketAddr, Shutdown) + where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + fmt::Debug + serde::ser::Serialize, + Reply: 'static + fmt::Debug + Send + serde::ser::Serialize, + F: 'static + Clone + Serve + { + let mut addr; + let mut shutdown; + while let &Err(_) = {shutdown = serve_async({addr = next_addr(); &addr}, server.clone()); &shutdown} { } + (addr, shutdown.unwrap()) + } + + #[test] + fn test_handle() { + let server = Arc::new(Server::new()); + let (addr, shutdown) = wtf(server.clone()); + let client_stream = TcpStream::connect(&addr).unwrap(); + let client: Client = Client::new(client_stream).expect(&line!().to_string()); + client.disconnect::(); + shutdown.shutdown(); + } + #[test] fn test() { - let (client_stream, server_streams) = pair(); let server = Arc::new(Server::new()); - let thread_server = server.clone(); - let guard = thread::spawn(move || serve(server_streams, thread_server)); + let (addr, shutdown) = wtf(server.clone()); + let client_stream = TcpStream::connect(&addr).unwrap(); let client = Client::new(client_stream).unwrap(); assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); assert_eq!(1, server.count()); assert_eq!(Reply::Increment(1), client.rpc(&Request::Increment).unwrap()); assert_eq!(2, server.count()); - client.join::().unwrap(); - guard.join(); + client.disconnect::().unwrap(); + shutdown.shutdown(); } struct BarrierServer { @@ -318,10 +370,9 @@ mod test { #[test] fn test_concurrent() { - let (client_stream, server_streams) = pair(); let server = Arc::new(BarrierServer::new(10)); - let thread_server = server.clone(); - let guard = thread::spawn(move || serve(server_streams, thread_server)); + let (addr, shutdown) = wtf(server.clone()); + let client_stream = TcpStream::connect(&addr).unwrap(); let client: Arc> = Arc::new(Client::new(client_stream).unwrap()); let mut join_handles = vec![]; for _ in 0..10 { @@ -336,7 +387,7 @@ mod test { Err(_) => panic!("couldn't unwrap arc"), Ok(c) => c, }; - client.join::().unwrap(); - guard.join(); + client.disconnect::().unwrap(); + shutdown.shutdown(); } }