diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index d83569c..0e1cab7 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -271,30 +271,43 @@ enum Packet { Shutdown, } -fn reader(mut stream: TcpStream, requests: Arc>>>) - where Reply: serde::Deserialize -{ - loop { - let packet: bincode::serde::DeserializeResult> = - bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); - match packet { - Ok(Packet::Message(id, reply)) => { - debug!("Client: received message, id={}", id); - let mut requests = requests.lock().unwrap(); - let reply_tx = requests.remove(&id).unwrap(); - reply_tx.send(reply).unwrap(); +struct Reader { + requests: Arc>>>> +} + +impl Reader { + fn read(self, mut stream: TcpStream) where Reply: serde::Deserialize { + loop { + let packet: bincode::serde::DeserializeResult> = + bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); + match packet { + Ok(Packet::Message(id, reply)) => { + debug!("Client: received message, id={}", id); + let mut requests = self.requests.lock().unwrap(); + let mut requests = requests.as_mut().unwrap(); + let reply_tx = requests.remove(&id).unwrap(); + reply_tx.send(reply).unwrap(); + } + Ok(Packet::Shutdown) => { + info!("Client: got shutdown message."); + break; + } + // TODO: This shutdown logic is janky.. What's the right way to do this? + Err(err) => panic!("unexpected error while parsing!: {:?}", err), } - Ok(Packet::Shutdown) => { - info!("Client: got shutdown message."); - requests.lock().unwrap().clear(); - break; - } - // TODO: This shutdown logic is janky.. What's the right way to do this? - Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } } +impl Drop for Reader { + fn drop(&mut self) { + let mut guard = self.requests.lock().unwrap(); + guard.as_mut().unwrap().clear(); + // remove the hashmap so no one can put more senders and accidentally block + guard.take(); + } +} + fn increment(cur_id: &mut u64) -> u64 { let id = *cur_id; *cur_id += 1; @@ -311,7 +324,7 @@ pub struct Client where Request: serde::ser::Serialize { synced_state: Mutex, - requests: Arc>>>, + requests: Arc>>>>, reader_guard: Option>, timeout: Option, _request: PhantomData, @@ -324,10 +337,10 @@ impl Client /// Create a new client that connects to `addr` pub fn new(addr: A, timeout: Option) -> io::Result { let stream = try!(TcpStream::connect(addr)); - let requests = Arc::new(Mutex::new(HashMap::new())); + let requests = Arc::new(Mutex::new(Some(HashMap::new()))); let reader_stream = try!(stream.try_clone()); - let reader_requests = requests.clone(); - let reader_guard = thread::spawn(move || reader(reader_stream, reader_requests)); + let reader = Reader { requests: requests.clone() }; + let reader_guard = thread::spawn(move || reader.read(reader_stream)); Ok(Client { synced_state: Mutex::new(SyncedClientState { next_id: 0, @@ -348,8 +361,11 @@ impl Client let mut state = self.synced_state.lock().unwrap(); let id = increment(&mut state.next_id); { - let mut requests = self.requests.lock().unwrap(); - requests.insert(id, tx); + if let Some(ref mut requests) = *self.requests.lock().unwrap() { + requests.insert(id, tx); + } else { + return Err(Error::ConnectionBroken); + } } let packet = Packet::Message(id, request); try!(state.stream.set_write_timeout(self.timeout)); @@ -361,7 +377,11 @@ impl Client warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}", packet, err); - self.requests.lock().unwrap().remove(&id); + if let Some(requests) = self.requests.lock().unwrap().as_mut() { + requests.remove(&id); + } else { + warn!("Client: couldn't remove sender for request {} because reader thread returned.", id); + } return Err(err.into()); } drop(state); @@ -495,7 +515,8 @@ mod test { let client: Arc> = Arc::new(Client::new(addr, test_timeout()) .unwrap()); let thread = thread::spawn(move || serve_handle.shutdown()); - info!("force_shutdown::client: {:?}", client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc2: {:?}", client.rpc(&Request::Increment)); thread.join().unwrap(); }