diff --git a/rpc/src/client/channel.rs b/rpc/src/client/channel.rs index b888efc..6b75a18 100644 --- a/rpc/src/client/channel.rs +++ b/rpc/src/client/channel.rs @@ -258,6 +258,7 @@ where { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); + let canceled_requests = canceled_requests.fuse(); crate::spawn( RequestDispatch { @@ -296,7 +297,7 @@ struct RequestDispatch { /// Requests waiting to be written to the wire. pending_requests: Fuse>>, /// Requests that were dropped. - canceled_requests: CanceledRequests, + canceled_requests: Fuse, /// Requests already written to the wire that haven't yet received responses. in_flight_requests: FnvHashMap>, /// Configures limits to prevent unlimited resource usage. @@ -313,7 +314,7 @@ where { unsafe_pinned!(server_addr: SocketAddr); unsafe_pinned!(in_flight_requests: FnvHashMap>); - unsafe_pinned!(canceled_requests: CanceledRequests); + unsafe_pinned!(canceled_requests: Fuse); unsafe_pinned!(pending_requests: Fuse>>); unsafe_pinned!(transport: Fuse); @@ -423,7 +424,8 @@ where } loop { - match ready!(self.as_mut().canceled_requests().poll_next_unpin(waker)) { + let cancellation = self.as_mut().canceled_requests().poll_next_unpin(waker); + match ready!(cancellation) { Some(request_id) => { if let Some(in_flight_data) = self.as_mut().in_flight_requests().remove(&request_id) @@ -807,7 +809,9 @@ where #[cfg(test)] mod tests { - use super::{CanceledRequests, Channel, RequestCancellation, RequestDispatch}; + use super::{ + CanceledRequests, Channel, DispatchResponse, RequestCancellation, RequestDispatch, + }; use crate::{ client::Config, context, @@ -828,19 +832,11 @@ mod tests { #[test] fn stage_request() { let (mut dispatch, mut channel, _server_channel) = set_up(); - - // Test that a request future dropped before it's processed by dispatch will cause the request - // to not be added to the in-flight request map. - let _resp = tokio::runtime::current_thread::block_on_all( - channel - .send(context::current(), "hi".to_string()) - .boxed() - .compat(), - ); - let mut dispatch = Pin::new(&mut dispatch); let waker = &noop_waker_ref(); + let _resp = send_request(&mut channel, "hi"); + let req = dispatch.poll_next_request(waker).ready(); assert!(req.is_some()); @@ -849,49 +845,77 @@ mod tests { assert_eq!(req.request, "hi".to_string()); } + // Regression test for https://github.com/google/tarpc/issues/220 #[test] - fn stage_request_response_future_dropped() { - let (mut dispatch, mut channel, _server_channel) = set_up(); - - // Test that a request future dropped before it's processed by dispatch will cause the request - // to not be added to the in-flight request map. - let resp = tokio::runtime::current_thread::block_on_all( - channel - .send(context::current(), "hi".into()) - .boxed() - .compat(), - ) - .unwrap(); - drop(resp); - drop(channel); - + fn stage_request_channel_dropped_doesnt_panic() { + let (mut dispatch, mut channel, mut server_channel) = set_up(); let mut dispatch = Pin::new(&mut dispatch); let waker = &noop_waker_ref(); - dispatch.poll_next_cancellation(waker).unwrap(); + let _ = send_request(&mut channel, "hi"); + drop(channel); + + assert!(dispatch.as_mut().poll(waker).is_ready()); + send_response( + &mut server_channel, + Response { + request_id: 0, + message: Ok("hello".into()), + }, + ); + tokio::runtime::current_thread::block_on_all(dispatch.boxed().compat()).unwrap(); + } + + #[test] + fn stage_request_response_future_dropped_is_canceled_before_sending() { + let (mut dispatch, mut channel, _server_channel) = set_up(); + let mut dispatch = Pin::new(&mut dispatch); + let waker = &noop_waker_ref(); + + let _ = send_request(&mut channel, "hi"); + + + // Drop the channel so polling returns none if no requests are currently ready. + drop(channel); + // Test that a request future dropped before it's processed by dispatch will cause the request + // to not be added to the in-flight request map. assert!(dispatch.poll_next_request(waker).ready().is_none()); } #[test] - fn stage_request_response_future_closed() { + fn stage_request_response_future_dropped_is_canceled_after_sending() { let (mut dispatch, mut channel, _server_channel) = set_up(); + let waker = &noop_waker_ref(); + let mut dispatch = Pin::new(&mut dispatch); + + let req = send_request(&mut channel, "hi"); + + assert!(dispatch.as_mut().pump_write(waker).ready().is_some()); + assert!(!dispatch.as_mut().in_flight_requests().is_empty()); + + + // Test that a request future dropped after it's processed by dispatch will cause the request + // to be removed from the in-flight request map. + drop(req); + if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(waker).unwrap() { + // ok + } else { panic!("Expected request to be cancelled")}; + assert!(dispatch.in_flight_requests().is_empty()); + } + + #[test] + fn stage_request_response_closed_skipped() { + let (mut dispatch, mut channel, _server_channel) = set_up(); + let mut dispatch = Pin::new(&mut dispatch); + let waker = &noop_waker_ref(); // Test that a request future that's closed its receiver but not yet canceled its request -- // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request // map. - let resp = tokio::runtime::current_thread::block_on_all( - channel - .send(context::current(), "hi".into()) - .boxed() - .compat(), - ) - .unwrap(); - drop(resp); - drop(channel); + let mut resp = send_request(&mut channel, "hi"); + resp.response.get_mut().close(); - let mut dispatch = Pin::new(&mut dispatch); - let waker = &noop_waker_ref(); - assert!(dispatch.poll_next_request(waker).ready().is_none()); + assert!(dispatch.poll_next_request(waker).is_pending()); } fn set_up() -> ( @@ -908,7 +932,7 @@ mod tests { let dispatch = RequestDispatch:: { transport: client_channel.fuse(), pending_requests: pending_requests.fuse(), - canceled_requests: CanceledRequests(canceled_requests), + canceled_requests: CanceledRequests(canceled_requests).fuse(), in_flight_requests: FnvHashMap::default(), config: Config::default(), server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), @@ -925,6 +949,26 @@ mod tests { (dispatch, channel, server_channel) } + fn send_request( + channel: &mut Channel, + request: &str, + ) -> DispatchResponse { + tokio::runtime::current_thread::block_on_all( + channel + .send(context::current(), request.to_string()) + .boxed() + .compat(), + ).unwrap() + } + + fn send_response( + channel: &mut UnboundedChannel, Response>, + response: Response, + ) { + tokio::runtime::current_thread::block_on_all(channel.send(response).boxed().compat()) + .unwrap(); + } + trait PollTest { type T; fn unwrap(self) -> Poll; @@ -955,5 +999,4 @@ mod tests { } } } - }