diff --git a/src/lib.rs b/src/lib.rs index 222e609..b81ed00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,12 +10,14 @@ use std::io::{self, Read}; use std::convert; use std::collections::HashMap; use std::net::{ + self, TcpListener, TcpStream, }; use std::sync::{ self, Mutex, + Arc, }; use std::sync::mpsc::{ channel, @@ -24,6 +26,7 @@ use std::sync::mpsc::{ SyncSender, Receiver, }; +use std::time; use std::thread; #[derive(Debug)] @@ -58,26 +61,30 @@ impl convert::From> for Error { pub type Result = std::result::Result; -pub fn handle_conn( - mut stream: TcpStream, - f: F) -> Result<()> +pub fn handle_conn(mut stream: TcpStream, f: Arc) -> Result<()> where Request: fmt::Debug + serde::de::Deserialize, Reply: fmt::Debug + serde::ser::Serialize, - F: 'static + Serve + F: Serve { let read_stream = try!(stream.try_clone()); let mut de = serde_json::Deserializer::new(read_stream.bytes()); - let request_packet: Packet = try!(Packet::deserialize(&mut de)); - let reply = try!(f.serve(&request_packet.message)); - let reply_packet = Packet{ - id: request_packet.id, - message: reply, - }; - try!(serde_json::to_writer(&mut stream, &reply_packet)); + loop { + println!("read"); + let request_packet: Packet = try!(Packet::deserialize(&mut de)); + match request_packet { + Packet::Shutdown => break, + Packet::Message(id, message) => { + let reply = try!(f.serve(&message)); + let reply_packet = Packet::Message(id, reply); + println!("write"); + try!(serde_json::to_writer(&mut stream, &reply_packet)); + }, + } + } Ok(()) } -pub fn serve(listener: TcpListener, f: F) -> Error +pub fn serve(listener: TcpListener, f: Arc) -> Error where Request: fmt::Debug + serde::de::Deserialize, Reply: fmt::Debug + serde::ser::Serialize, F: 'static + Serve, @@ -97,14 +104,14 @@ pub fn serve(listener: TcpListener, f: F) -> Error Error::Impossible } -pub trait Serve : Sync + Send + Clone { +pub trait Serve: Send + Sync { fn serve(&self, request: &Request) -> io::Result; } #[derive(Debug, Clone, Serialize, Deserialize)] -struct Packet { - id: u64, - message: T, +enum Packet { + Message(u64, T), + Shutdown, } struct Handle { @@ -115,6 +122,7 @@ struct Handle { enum ReceiverMessage { Handle(Handle), Packet(Packet), + Shutdown, } fn receiver(messages: Receiver>) -> Result<()> { @@ -124,25 +132,32 @@ fn receiver(messages: Receiver>) -> Result<()> { ReceiverMessage::Handle(handle) => { ready_handles.insert(handle.id, handle); }, - ReceiverMessage::Packet(packet) => { - let handle = ready_handles.remove(&packet.id).unwrap(); - try!(handle.sender.send(packet.message)); + ReceiverMessage::Packet(Packet::Shutdown) => break, + ReceiverMessage::Packet(Packet::Message(id, message)) => { + let handle = ready_handles.remove(&id).unwrap(); + try!(handle.sender.send(message)); } + ReceiverMessage::Shutdown => break, } } Ok(()) } -fn reader(mut stream: TcpStream, decode: F, tx: SyncSender) - where F: Send + 'static + Fn(&mut TcpStream) -> Result, - T: Send + 'static +fn reader(stream: TcpStream, tx: SyncSender>) + where Reply: serde::Deserialize { use serde_json::Error::SyntaxError; use serde_json::ErrorCode::EOFWhileParsingValue; + let mut de = serde_json::Deserializer::new(stream.bytes()); loop { - match decode(&mut stream) { - Ok(t) => tx.send(t).unwrap(), - Err(Error::Json(SyntaxError(EOFWhileParsingValue, _, _))) => break, + match Packet::deserialize(&mut de) { + Ok(packet) =>{ + println!("send!"); + tx.send(ReceiverMessage::Packet(packet)).unwrap(); + }, + // TODO: This shutdown logic is janky.. What's the right way to do this? + Err(SyntaxError(EOFWhileParsingValue, _, _)) => break, + Err(SyntaxError(ExpectedValue, _, _)) => break, Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } @@ -170,13 +185,10 @@ impl Client { pub fn new(stream: TcpStream) -> Result { let (handles_tx, receiver_rx) = sync_channel(0); - let decode = |mut stream: &mut TcpStream| { - let packet = try!(serde_json::from_reader(&mut stream)); - Ok(ReceiverMessage::Packet(packet)) - }; let read_stream = try!(stream.try_clone()); + try!(read_stream.set_read_timeout(Some(time::Duration::from_millis(50)))); let reader_handles_tx = handles_tx.clone(); - let guard = thread::spawn(move || reader(read_stream, decode, reader_handles_tx)); + let guard = thread::spawn(move || reader(read_stream, reader_handles_tx)); thread::spawn(move || receiver(receiver_rx)); Ok(Client{ synced_state: Mutex::new(SyncedClientState{ @@ -198,24 +210,28 @@ impl Client id: id, sender: tx, }))); - try!(serde_json::to_writer(&mut state.stream, &Packet{ - id: id, - message: request.clone(), - })); + let packet = Packet::Message(id, request.clone()); + try!(serde_json::to_writer(&mut state.stream, &packet)); Ok(rx.recv().unwrap()) } - pub fn join(self) { + pub fn join(self) -> Result<()> { + let mut state = self.synced_state.lock().unwrap(); + let packet: Packet = Packet::Shutdown; + try!(serde_json::to_writer(&mut state.stream, &packet)); + try!(state.stream.shutdown(net::Shutdown::Both)); self.reader_guard.join().unwrap(); + Ok(()) } } #[cfg(test)] mod test { use super::*; - use std::thread; - use std::net::{TcpStream, TcpListener}; use std::io; + use std::net::{TcpStream, TcpListener}; + use std::sync::{Arc, Mutex}; + use std::thread; fn pair() -> (TcpStream, TcpListener) { let addr = "127.0.0.1:9000"; @@ -231,24 +247,39 @@ mod test { #[derive(Debug, PartialEq, Serialize, Deserialize)] enum Reply { - Increment + Increment(u64) } - #[derive(Clone)] - struct Server; + struct Server { + counter: Mutex, + } impl Serve for Server { fn serve(&self, _: &Request) -> io::Result { - Ok(Reply::Increment) + let mut counter = self.counter.lock().unwrap(); + let reply = Reply::Increment(*counter); + *counter += 1; + Ok(reply) + } + } + + impl Server { + fn count(&self) -> u64 { + *self.counter.lock().unwrap() } } #[test] fn test() { let (client_stream, server_streams) = pair(); - thread::spawn(|| serve(server_streams, Server)); + let server = Arc::new(Server{counter: Mutex::new(0)}); + let thread_server = server.clone(); + thread::spawn(move || serve(server_streams, thread_server)); let client = Client::new(client_stream).unwrap(); - assert_eq!(Reply::Increment, client.rpc(&Request::Increment).unwrap()); - client.join(); + assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); + assert_eq!(1, server.count()); + assert_eq!(Reply::Increment(1), client.rpc(&Request::Increment).unwrap()); + assert_eq!(2, server.count()); + client.join::().unwrap(); } }