Ensure no rpc calls can be started once the reader thread returns.

This commit is contained in:
Tim Kuehn
2016-01-14 01:53:28 -08:00
parent 91053b96c0
commit 2644bf0d9b

View File

@@ -271,30 +271,43 @@ enum Packet<T> {
Shutdown,
}
fn reader<Reply>(mut stream: TcpStream, requests: Arc<Mutex<HashMap<u64, Sender<Reply>>>>)
where Reply: serde::Deserialize
{
loop {
let packet: bincode::serde::DeserializeResult<Packet<Reply>> =
bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite);
match packet {
Ok(Packet::Message(id, reply)) => {
debug!("Client: received message, id={}", id);
let mut requests = requests.lock().unwrap();
let reply_tx = requests.remove(&id).unwrap();
reply_tx.send(reply).unwrap();
struct Reader<Reply> {
requests: Arc<Mutex<Option<HashMap<u64, Sender<Reply>>>>>
}
impl<Reply> Reader<Reply> {
fn read(self, mut stream: TcpStream) where Reply: serde::Deserialize {
loop {
let packet: bincode::serde::DeserializeResult<Packet<Reply>> =
bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite);
match packet {
Ok(Packet::Message(id, reply)) => {
debug!("Client: received message, id={}", id);
let mut requests = self.requests.lock().unwrap();
let mut requests = requests.as_mut().unwrap();
let reply_tx = requests.remove(&id).unwrap();
reply_tx.send(reply).unwrap();
}
Ok(Packet::Shutdown) => {
info!("Client: got shutdown message.");
break;
}
// TODO: This shutdown logic is janky.. What's the right way to do this?
Err(err) => panic!("unexpected error while parsing!: {:?}", err),
}
Ok(Packet::Shutdown) => {
info!("Client: got shutdown message.");
requests.lock().unwrap().clear();
break;
}
// TODO: This shutdown logic is janky.. What's the right way to do this?
Err(err) => panic!("unexpected error while parsing!: {:?}", err),
}
}
}
impl<Reply> Drop for Reader<Reply> {
fn drop(&mut self) {
let mut guard = self.requests.lock().unwrap();
guard.as_mut().unwrap().clear();
// remove the hashmap so no one can put more senders and accidentally block
guard.take();
}
}
fn increment(cur_id: &mut u64) -> u64 {
let id = *cur_id;
*cur_id += 1;
@@ -311,7 +324,7 @@ pub struct Client<Request, Reply>
where Request: serde::ser::Serialize
{
synced_state: Mutex<SyncedClientState>,
requests: Arc<Mutex<HashMap<u64, Sender<Reply>>>>,
requests: Arc<Mutex<Option<HashMap<u64, Sender<Reply>>>>>,
reader_guard: Option<thread::JoinHandle<()>>,
timeout: Option<Duration>,
_request: PhantomData<Request>,
@@ -324,10 +337,10 @@ impl<Request, Reply> Client<Request, Reply>
/// Create a new client that connects to `addr`
pub fn new<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<Self> {
let stream = try!(TcpStream::connect(addr));
let requests = Arc::new(Mutex::new(HashMap::new()));
let requests = Arc::new(Mutex::new(Some(HashMap::new())));
let reader_stream = try!(stream.try_clone());
let reader_requests = requests.clone();
let reader_guard = thread::spawn(move || reader(reader_stream, reader_requests));
let reader = Reader { requests: requests.clone() };
let reader_guard = thread::spawn(move || reader.read(reader_stream));
Ok(Client {
synced_state: Mutex::new(SyncedClientState {
next_id: 0,
@@ -348,8 +361,11 @@ impl<Request, Reply> Client<Request, Reply>
let mut state = self.synced_state.lock().unwrap();
let id = increment(&mut state.next_id);
{
let mut requests = self.requests.lock().unwrap();
requests.insert(id, tx);
if let Some(ref mut requests) = *self.requests.lock().unwrap() {
requests.insert(id, tx);
} else {
return Err(Error::ConnectionBroken);
}
}
let packet = Packet::Message(id, request);
try!(state.stream.set_write_timeout(self.timeout));
@@ -361,7 +377,11 @@ impl<Request, Reply> Client<Request, Reply>
warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}",
packet,
err);
self.requests.lock().unwrap().remove(&id);
if let Some(requests) = self.requests.lock().unwrap().as_mut() {
requests.remove(&id);
} else {
warn!("Client: couldn't remove sender for request {} because reader thread returned.", id);
}
return Err(err.into());
}
drop(state);
@@ -495,7 +515,8 @@ mod test {
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, test_timeout())
.unwrap());
let thread = thread::spawn(move || serve_handle.shutdown());
info!("force_shutdown::client: {:?}", client.rpc(&Request::Increment));
info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment));
info!("force_shutdown:: rpc2: {:?}", client.rpc(&Request::Increment));
thread.join().unwrap();
}