mirror of
https://github.com/OMGeeky/tarpc.git
synced 2025-12-30 16:18:56 +01:00
Ensure no rpc calls can be started once the reader thread returns.
This commit is contained in:
@@ -271,30 +271,43 @@ enum Packet<T> {
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
fn reader<Reply>(mut stream: TcpStream, requests: Arc<Mutex<HashMap<u64, Sender<Reply>>>>)
|
||||
where Reply: serde::Deserialize
|
||||
{
|
||||
loop {
|
||||
let packet: bincode::serde::DeserializeResult<Packet<Reply>> =
|
||||
bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite);
|
||||
match packet {
|
||||
Ok(Packet::Message(id, reply)) => {
|
||||
debug!("Client: received message, id={}", id);
|
||||
let mut requests = requests.lock().unwrap();
|
||||
let reply_tx = requests.remove(&id).unwrap();
|
||||
reply_tx.send(reply).unwrap();
|
||||
struct Reader<Reply> {
|
||||
requests: Arc<Mutex<Option<HashMap<u64, Sender<Reply>>>>>
|
||||
}
|
||||
|
||||
impl<Reply> Reader<Reply> {
|
||||
fn read(self, mut stream: TcpStream) where Reply: serde::Deserialize {
|
||||
loop {
|
||||
let packet: bincode::serde::DeserializeResult<Packet<Reply>> =
|
||||
bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite);
|
||||
match packet {
|
||||
Ok(Packet::Message(id, reply)) => {
|
||||
debug!("Client: received message, id={}", id);
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
let mut requests = requests.as_mut().unwrap();
|
||||
let reply_tx = requests.remove(&id).unwrap();
|
||||
reply_tx.send(reply).unwrap();
|
||||
}
|
||||
Ok(Packet::Shutdown) => {
|
||||
info!("Client: got shutdown message.");
|
||||
break;
|
||||
}
|
||||
// TODO: This shutdown logic is janky.. What's the right way to do this?
|
||||
Err(err) => panic!("unexpected error while parsing!: {:?}", err),
|
||||
}
|
||||
Ok(Packet::Shutdown) => {
|
||||
info!("Client: got shutdown message.");
|
||||
requests.lock().unwrap().clear();
|
||||
break;
|
||||
}
|
||||
// TODO: This shutdown logic is janky.. What's the right way to do this?
|
||||
Err(err) => panic!("unexpected error while parsing!: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Reply> Drop for Reader<Reply> {
|
||||
fn drop(&mut self) {
|
||||
let mut guard = self.requests.lock().unwrap();
|
||||
guard.as_mut().unwrap().clear();
|
||||
// remove the hashmap so no one can put more senders and accidentally block
|
||||
guard.take();
|
||||
}
|
||||
}
|
||||
|
||||
fn increment(cur_id: &mut u64) -> u64 {
|
||||
let id = *cur_id;
|
||||
*cur_id += 1;
|
||||
@@ -311,7 +324,7 @@ pub struct Client<Request, Reply>
|
||||
where Request: serde::ser::Serialize
|
||||
{
|
||||
synced_state: Mutex<SyncedClientState>,
|
||||
requests: Arc<Mutex<HashMap<u64, Sender<Reply>>>>,
|
||||
requests: Arc<Mutex<Option<HashMap<u64, Sender<Reply>>>>>,
|
||||
reader_guard: Option<thread::JoinHandle<()>>,
|
||||
timeout: Option<Duration>,
|
||||
_request: PhantomData<Request>,
|
||||
@@ -324,10 +337,10 @@ impl<Request, Reply> Client<Request, Reply>
|
||||
/// Create a new client that connects to `addr`
|
||||
pub fn new<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<Self> {
|
||||
let stream = try!(TcpStream::connect(addr));
|
||||
let requests = Arc::new(Mutex::new(HashMap::new()));
|
||||
let requests = Arc::new(Mutex::new(Some(HashMap::new())));
|
||||
let reader_stream = try!(stream.try_clone());
|
||||
let reader_requests = requests.clone();
|
||||
let reader_guard = thread::spawn(move || reader(reader_stream, reader_requests));
|
||||
let reader = Reader { requests: requests.clone() };
|
||||
let reader_guard = thread::spawn(move || reader.read(reader_stream));
|
||||
Ok(Client {
|
||||
synced_state: Mutex::new(SyncedClientState {
|
||||
next_id: 0,
|
||||
@@ -348,8 +361,11 @@ impl<Request, Reply> Client<Request, Reply>
|
||||
let mut state = self.synced_state.lock().unwrap();
|
||||
let id = increment(&mut state.next_id);
|
||||
{
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
requests.insert(id, tx);
|
||||
if let Some(ref mut requests) = *self.requests.lock().unwrap() {
|
||||
requests.insert(id, tx);
|
||||
} else {
|
||||
return Err(Error::ConnectionBroken);
|
||||
}
|
||||
}
|
||||
let packet = Packet::Message(id, request);
|
||||
try!(state.stream.set_write_timeout(self.timeout));
|
||||
@@ -361,7 +377,11 @@ impl<Request, Reply> Client<Request, Reply>
|
||||
warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}",
|
||||
packet,
|
||||
err);
|
||||
self.requests.lock().unwrap().remove(&id);
|
||||
if let Some(requests) = self.requests.lock().unwrap().as_mut() {
|
||||
requests.remove(&id);
|
||||
} else {
|
||||
warn!("Client: couldn't remove sender for request {} because reader thread returned.", id);
|
||||
}
|
||||
return Err(err.into());
|
||||
}
|
||||
drop(state);
|
||||
@@ -495,7 +515,8 @@ mod test {
|
||||
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, test_timeout())
|
||||
.unwrap());
|
||||
let thread = thread::spawn(move || serve_handle.shutdown());
|
||||
info!("force_shutdown::client: {:?}", client.rpc(&Request::Increment));
|
||||
info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment));
|
||||
info!("force_shutdown:: rpc2: {:?}", client.rpc(&Request::Increment));
|
||||
thread.join().unwrap();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user