diff --git a/src/lib.rs b/src/lib.rs index 0571d95..409e6ad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,12 +22,8 @@ use std::sync::{ }; use std::sync::mpsc::{ channel, - sync_channel, Sender, - SyncSender, - Receiver, }; -use std::time; use std::thread; #[derive(Debug)] @@ -63,7 +59,7 @@ impl convert::From> for Error { pub type Result = std::result::Result; pub fn handle_conn(mut stream: TcpStream, f: Arc) -> Result<()> - where Request: fmt::Debug + serde::de::Deserialize, + where Request: fmt::Debug + serde::de::Deserialize + serde::ser::Serialize, Reply: fmt::Debug + serde::ser::Serialize, F: Serve { @@ -73,7 +69,10 @@ pub fn handle_conn(mut stream: TcpStream, f: Arc) -> Resul println!("read"); let request_packet: Packet = try!(Packet::deserialize(&mut de)); match request_packet { - Packet::Shutdown => break, + Packet::Shutdown => { + try!(serde_json::to_writer(&mut stream, &request_packet)); + break; + }, Packet::Message(id, message) => { let reply = try!(f.serve(&message)); let reply_packet = Packet::Message(id, reply); @@ -86,7 +85,7 @@ pub fn handle_conn(mut stream: TcpStream, f: Arc) -> Resul } pub fn serve(listener: TcpListener, f: Arc) -> Error - where Request: fmt::Debug + serde::de::Deserialize, + where Request: fmt::Debug + serde::de::Deserialize + fmt::Debug + serde::ser::Serialize, Reply: fmt::Debug + serde::ser::Serialize, F: 'static + Serve, { @@ -115,49 +114,21 @@ enum Packet { Shutdown, } -struct Handle { - id: u64, - sender: Sender, -} - -enum ReceiverMessage { - Handle(Handle), - Packet(Packet), - Shutdown, -} - -fn receiver(messages: Receiver>) -> Result<()> { - let mut ready_handles: HashMap> = HashMap::new(); - for message in messages.into_iter() { - match message { - ReceiverMessage::Handle(handle) => { - ready_handles.insert(handle.id, handle); - }, - ReceiverMessage::Packet(Packet::Shutdown) => break, - ReceiverMessage::Packet(Packet::Message(id, message)) => { - let handle = ready_handles.remove(&id).unwrap(); - try!(handle.sender.send(message)); - } - ReceiverMessage::Shutdown => break, - } - } - Ok(()) -} - -fn reader(stream: TcpStream, tx: SyncSender>) +fn reader( + stream: TcpStream, + requests: Arc>>>) where Reply: serde::Deserialize { - use serde_json::Error::SyntaxError; - use serde_json::ErrorCode::EOFWhileParsingValue; let mut de = serde_json::Deserializer::new(stream.bytes()); loop { match Packet::deserialize(&mut de) { - Ok(packet) =>{ - println!("send!"); - tx.send(ReceiverMessage::Packet(packet)).unwrap(); + Ok(Packet::Message(id, reply)) => { + let mut requests = requests.lock().unwrap(); + let reply_tx = requests.remove(&id).unwrap(); + reply_tx.send(reply).unwrap(); }, + Ok(Packet::Shutdown) => break, // TODO: This shutdown logic is janky.. What's the right way to do this? - Err(SyntaxError(EOFWhileParsingValue, _, _)) => break, Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } @@ -169,14 +140,14 @@ fn increment(cur_id: &mut u64) -> u64 { id } -struct SyncedClientState { +struct SyncedClientState { next_id: u64, stream: TcpStream, - handles_tx: SyncSender>, } pub struct Client { - synced_state: Mutex>, + synced_state: Mutex, + requests: Arc>>>, reader_guard: thread::JoinHandle<()>, } @@ -184,19 +155,18 @@ impl Client where Reply: serde::de::Deserialize + Send + 'static { pub fn new(stream: TcpStream) -> Result { - let (handles_tx, receiver_rx) = sync_channel(0); - let read_stream = try!(stream.try_clone()); - try!(read_stream.set_read_timeout(Some(time::Duration::from_millis(50)))); - let reader_handles_tx = handles_tx.clone(); - let guard = thread::spawn(move || reader(read_stream, reader_handles_tx)); - thread::spawn(move || receiver(receiver_rx)); + let requests = Arc::new(Mutex::new(HashMap::new())); + let reader_stream = try!(stream.try_clone()); + let reader_requests = requests.clone(); + let reader_guard = + thread::spawn(move || reader(reader_stream, reader_requests)); Ok(Client{ synced_state: Mutex::new(SyncedClientState{ next_id: 0, stream: stream, - handles_tx: handles_tx, }), - reader_guard: guard, + requests: requests, + reader_guard: reader_guard, }) } @@ -206,10 +176,10 @@ impl Client let (tx, rx) = channel(); let mut state = self.synced_state.lock().unwrap(); let id = increment(&mut state.next_id); - try!(state.handles_tx.send(ReceiverMessage::Handle(Handle{ - id: id, - sender: tx, - }))); + { + let mut requests = self.requests.lock().unwrap(); + requests.insert(id, tx); + } let packet = Packet::Message(id, request.clone()); try!(serde_json::to_writer(&mut state.stream, &packet)); Ok(rx.recv().unwrap())