diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index 6eeaa47..9288a1b 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -304,8 +304,55 @@ struct Packet { message: T, } +struct RpcFutures(Result>>); + +impl RpcFutures { + fn new() -> RpcFutures { + RpcFutures(Ok(HashMap::new())) + } + + fn insert_tx(&mut self, id: u64, tx: Sender) -> Result<()> { + match self.0 { + Ok(ref mut requests) => { + requests.insert(id, tx); + Ok(()) + } + Err(ref e) => Err(e.clone()), + } + } + + fn remove_tx(&mut self, id: u64) -> Result<()> { + match self.0 { + Ok(ref mut requests) => { + requests.remove(&id); + Ok(()) + } + Err(ref e) => Err(e.clone()), + } + } + + fn complete_reply(&mut self, id: u64, reply: Reply) { + self.0 + .as_mut() + .unwrap() + .remove(&id) + .unwrap() + .send(reply) + .unwrap(); + } + + fn set_error(&mut self, err: bincode::serde::DeserializeError) { + let map = mem::replace(&mut self.0, Err(err.into())); + map.unwrap().clear(); + } + + fn get_error(&self) -> Error { + self.0.as_ref().err().unwrap().clone() + } +} + struct Reader { - requests: Arc>>>>, + requests: Arc>>, } impl Reader { @@ -321,18 +368,13 @@ impl Reader { message: reply }) => { debug!("Client: received message, id={}", id); - let mut requests = self.requests.lock().unwrap(); - let mut requests = requests.as_mut().unwrap(); - let reply_tx = requests.remove(&id).unwrap(); - reply_tx.send(reply).unwrap(); + self.requests.lock().unwrap().complete_reply(id, reply); } Err(err) => { warn!("Client: reader thread encountered an unexpected error while parsing; \ returning now. Error: {:?}", err); - let mut guard = self.requests.lock().unwrap(); - let map = mem::replace(&mut *guard, Err(err.into())); - map.unwrap().clear(); + self.requests.lock().unwrap().set_error(err); break; } } @@ -356,7 +398,7 @@ pub struct Client where Request: serde::ser::Serialize { synced_state: Mutex, - requests: Arc>>>>, + requests: Arc>>, reader_guard: Option>, timeout: Option, _request: PhantomData, @@ -370,7 +412,7 @@ impl Client /// for both reads and writes. pub fn new(addr: A, timeout: Option) -> io::Result { let stream = try!(TcpStream::connect(addr)); - let requests = Arc::new(Mutex::new(Ok(HashMap::new()))); + let requests = Arc::new(Mutex::new(RpcFutures::new())); let reader_stream = try!(stream.try_clone()); let reader = Reader { requests: requests.clone() }; let reader_guard = thread::spawn(move || reader.read(reader_stream)); @@ -393,12 +435,9 @@ impl Client let (tx, rx) = channel(); let mut state = self.synced_state.lock().unwrap(); let id = increment(&mut state.next_id); - { - match *self.requests.lock().unwrap() { - Ok(ref mut requests) => { - requests.insert(id, tx); - } - Err(ref e) => return Err(e.clone()), + { // block required to drop lock asap + if let Err(e) = self.requests.lock().unwrap().insert_tx(id, tx) { + return Err(e); } } let packet = Packet { @@ -414,21 +453,14 @@ impl Client warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}", packet, err); - match *self.requests.lock().unwrap() { - Ok(ref mut requests) => { - requests.remove(&id); - return Err(err.into()); - } - Err(ref e) => return Err(e.clone()), + if let Err(e) = self.requests.lock().unwrap().remove_tx(id) { + return Err(e); } } drop(state); match rx.recv() { Ok(msg) => Ok(msg), - Err(_) => { - let guard = self.requests.lock().unwrap(); - Err(guard.as_ref().err().unwrap().clone()) - } + Err(_) => Err(self.requests.lock().unwrap().get_error()), } } }