diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index f395447..d700141 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -196,8 +196,8 @@ impl fmt::Debug for BaseChannel { /// do not have specific scheduling needs and whose services are `Send + 'static`. /// 2. [`Channel::requests`] - This method is best for those who need direct access to individual /// requests, or are not using `tokio`, or want control over [futures](Future) scheduling. -/// [`Requests`] is a stream of [`InFlightRequests`](InFlightRequest), which each have an -/// [`execute`](InFightRequest::execute) method. If using `execute`, request processing will +/// [`Requests`] is a stream of [`InFlightRequests`](InFlightRequest), each which has an +/// [`execute`](InFlightRequest::execute) method. If using `execute`, request processing will /// automatically cease when either the request deadline is reached or when a corresponding /// cancellation message is received by the Channel. /// 3. [`Sink::start_send`] - A user is free to manually send responses to requests produced by a diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 59760a2..8456b88 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -29,20 +29,17 @@ pub fn unbounded() -> ( /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). -#[pin_project] #[derive(Debug)] pub struct UnboundedChannel { - #[pin] rx: mpsc::UnboundedReceiver, - #[pin] tx: mpsc::UnboundedSender, } impl Stream for UnboundedChannel { type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo { - self.project().rx.poll_recv(cx).map(|option| option.map(Ok)) + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo { + self.rx.poll_recv(cx).map(|option| option.map(Ok)) } } @@ -50,7 +47,7 @@ impl Sink for UnboundedChannel { type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(if self.project().tx.is_closed() { + Poll::Ready(if self.tx.is_closed() { Err(io::Error::from(io::ErrorKind::NotConnected)) } else { Ok(()) @@ -58,8 +55,7 @@ impl Sink for UnboundedChannel { } fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { - self.project() - .tx + self.tx .send(item) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) } @@ -75,6 +71,77 @@ impl Sink for UnboundedChannel { } } +/// Returns two channel peers with buffer equal to `capacity`. Each [`Stream`] yields items sent +/// through the other's [`Sink`]. +pub fn bounded( + capacity: usize, +) -> (Channel, Channel) { + let (tx1, rx2) = futures::channel::mpsc::channel(capacity); + let (tx2, rx1) = futures::channel::mpsc::channel(capacity); + (Channel { tx: tx1, rx: rx1 }, Channel { tx: tx2, rx: rx2 }) +} + +/// A bi-directional channel backed by a [`Sender`](futures::channel::mpsc::Sender) +/// and [`Receiver`](futures::channel::mpsc::Receiver). +#[pin_project] +#[derive(Debug)] +pub struct Channel { + #[pin] + rx: futures::channel::mpsc::Receiver, + #[pin] + tx: futures::channel::mpsc::Sender, +} + +impl Stream for Channel { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo { + self.project().rx.poll_next(cx).map(|option| option.map(Ok)) + } +} + +impl Sink for Channel { + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .tx + .poll_ready(cx) + .map_err(convert_send_err_to_io) + } + + fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { + self.project() + .tx + .start_send(item) + .map_err(convert_send_err_to_io) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .tx + .poll_flush(cx) + .map_err(convert_send_err_to_io) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .tx + .poll_close(cx) + .map_err(convert_send_err_to_io) + } +} + +fn convert_send_err_to_io(e: futures::channel::mpsc::SendError) -> io::Error { + if e.is_disconnected() { + io::Error::from(io::ErrorKind::NotConnected) + } else if e.is_full() { + io::Error::from(io::ErrorKind::WouldBlock) + } else { + io::Error::new(io::ErrorKind::Other, e) + } +} + #[cfg(test)] #[cfg(feature = "tokio1")] mod tests {