Tear out requirement that Transport's error type is io::Error.

Tim Kuehn
2021-03-28 23:28:01 -07:00
parent 7b7c182411
commit 21e2f7ca62
19 changed files with 300 additions and 207 deletions
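The gist of the change, sketched as a before/after of the sealed Transport trait (simplified from the hunks below; this summary is editorial and not itself part of the diff):

// Before: a transport had to be a Stream/Sink pair whose error type was io::Error.
pub trait Transport<SinkItem, Item>:
    Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>
{
}

// After: any shared error type that is Error + Send + Sync + 'static is accepted,
// and it is exposed to bounds through the TransportError associated type.
pub trait Transport<SinkItem, Item>
where
    Self: Stream<Item = Result<Item, <Self as Sink<SinkItem>>::Error>>,
    Self: Sink<SinkItem, Error = <Self as Transport<SinkItem, Item>>::TransportError>,
{
    type TransportError: Error + Send + Sync + 'static;
}

A user-visible knock-on effect is that the generated clients' spawn() helper no longer returns io::Result; it simply spawns the dispatch and hands back the client.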

View File

@@ -30,7 +30,7 @@ async fn main() -> anyhow::Result<()> {
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
// config and any Transport as input.
let client = WorldClient::new(client::Config::default(), transport.await?).spawn()?;
let client = WorldClient::new(client::Config::default(), transport.await?).spawn();
let hello = async move {
// Send the request twice, just to be safe! ;)

View File

@@ -12,7 +12,7 @@ use rand::{
};
use service::{init_tracing, World};
use std::{
net::{IpAddr, SocketAddr},
net::{IpAddr, Ipv6Addr, SocketAddr},
time::Duration,
};
use tarpc::{
@@ -40,7 +40,7 @@ impl World for HelloServer {
let sleep_time =
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
time::sleep(sleep_time).await;
format!("Hello, {}! You are connected from {:?}.", name, self.0)
format!("Hello, {}! You are connected from {}", name, self.0)
}
}
@@ -49,7 +49,7 @@ async fn main() -> anyhow::Result<()> {
let flags = Flags::parse();
let _uninstall = init_tracing("Tarpc Example Server")?;
let server_addr = (IpAddr::from([0, 0, 0, 0]), flags.port);
let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), flags.port);
// JSON transport is provided by the json_transport tarpc module. It makes it easy
// to start up a serde-powered json serialization strategy over TCP.

View File

@@ -35,6 +35,7 @@ rand = "0.8"
serde = { optional = true, version = "1.0", features = ["derive"] }
static_assertions = "1.1.0"
tarpc-plugins = { path = "../plugins", version = "0.10" }
thiserror = "1.0"
tokio = { version = "1", features = ["time"] }
tokio-util = { version = "0.6.3", features = ["time"] }
tokio-serde = { optional = true, version = "0.8" }

View File

@@ -118,7 +118,7 @@ async fn main() -> anyhow::Result<()> {
});
let transport = tcp::connect(addr, Bincode::default).await?;
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?;
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn();
println!(
"{}",

View File

@@ -44,7 +44,7 @@ async fn main() -> std::io::Result<()> {
let conn = UnixStream::connect(bind_addr).await?;
let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
PingServiceClient::new(Default::default(), transport)
.spawn()?
.spawn()
.ping(tarpc::context::current())
.await
}

View File

@@ -41,7 +41,9 @@ use futures::{
use publisher::Publisher as _;
use std::{
collections::HashMap,
env, io,
env,
error::Error,
io,
net::SocketAddr,
sync::{Arc, Mutex, RwLock},
};
@@ -224,10 +226,10 @@ impl Publisher {
}
}
fn start_subscriber_gc(
fn start_subscriber_gc<E: Error>(
self,
subscriber_addr: SocketAddr,
client_dispatch: impl Future<Output = anyhow::Result<()>> + Send + 'static,
client_dispatch: impl Future<Output = Result<(), E>> + Send + 'static,
subscriber_ready: oneshot::Receiver<()>,
) {
tokio::spawn(async move {
@@ -325,7 +327,7 @@ async fn main() -> anyhow::Result<()> {
client::Config::default(),
tcp::connect(addrs.publisher, Json::default).await?,
)
.spawn()?;
.spawn();
publisher
.publish(context::current(), "calculus".into(), "sqrt(2)".into())

View File

@@ -43,7 +43,7 @@ async fn main() -> io::Result<()> {
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
// that takes a config and any Transport as input.
let client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
let client = WorldClient::new(client::Config::default(), client_transport).spawn();
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context

View File

@@ -88,7 +88,7 @@ async fn main() -> anyhow::Result<()> {
tokio::spawn(add_server);
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?;
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn();
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
@@ -102,7 +102,7 @@ async fn main() -> anyhow::Result<()> {
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
let double_client =
double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;
double::DoubleClient::new(client::Config::default(), to_double_server).spawn();
let ctx = context::current();
for _ in 1..=5 {

View File

@@ -8,12 +8,13 @@
mod in_flight_requests;
use crate::{context, trace, ClientMessage, PollContext, PollIo, Request, Response, Transport};
use crate::{context, trace, ClientMessage, Request, Response, Transport};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use in_flight_requests::InFlightRequests;
use pin_project::pin_project;
use std::{
convert::TryFrom,
error::Error,
fmt, io, mem,
pin::Pin,
sync::{
@@ -64,12 +65,12 @@ where
/// Helper method to spawn the dispatch on the default executor.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub fn spawn(self) -> io::Result<C> {
pub fn spawn(self) -> C {
let dispatch = self
.dispatch
.unwrap_or_else(move |e| tracing::warn!("Connection broken: {}", e));
tokio::spawn(dispatch);
Ok(self.client)
self.client
}
}
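(Editorial aside, not part of the diff.) Since spawn() now just logs dispatch failures and returns the client directly, a caller that wants to react to a broken connection can drive the dispatch itself. A hedged sketch, assuming the NewClient { client, dispatch } shape returned by the generated clients and the WorldClient from the examples above:

let tarpc::client::NewClient { client, dispatch } =
    WorldClient::new(client::Config::default(), transport);
tokio::spawn(async move {
    // The dispatch future now resolves to Result<(), ChannelError<C::Error>>
    // (see the RequestDispatch changes below) rather than anyhow::Result<()>.
    if let Err(e) = dispatch.await {
        tracing::warn!("connection broken: {}", e);
    }
});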
@@ -250,6 +251,20 @@ pub struct RequestDispatch<Req, Resp, C> {
config: Config,
}
/// Critical errors that result in a Channel disconnecting.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// An error occurred reading from, or writing to, the transport.
#[error("an error occurred in the transport: {0}")]
Transport(#[source] E),
/// An error occurred while polling expired requests.
#[error("an error occurred while polling expired requests: {0}")]
Timer(#[source] tokio::time::error::Error),
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
@@ -272,29 +287,42 @@ where
self.as_mut().project().pending_requests
}
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
Poll::Ready(match ready!(self.transport_pin_mut().poll_next(cx)?) {
Some(response) => {
fn pump_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
self.transport_pin_mut()
.poll_next(cx)
.map_err(ChannelError::Transport)
.map_ok(|response| {
self.complete(response);
Some(Ok(()))
}
None => None,
})
})
}
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
fn pump_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
enum ReceiverStatus {
Pending,
Closed,
}
let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
let pending_requests_status = match self
.as_mut()
.poll_write_request(cx)
.map_err(ChannelError::Transport)?
{
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::Pending,
};
let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
let canceled_requests_status = match self
.as_mut()
.poll_write_cancel(cx)
.map_err(ChannelError::Transport)?
{
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::Pending,
@@ -303,7 +331,11 @@ where
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
// because there can temporarily be zero in-flight requests. Therefore, there is no need to
// track the status as is done with pending and canceled requests.
if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx)? {
if let Poll::Ready(Some(_)) = self
.in_flight_requests()
.poll_expired(cx)
.map_err(ChannelError::Timer)?
{
// Expired requests are considered complete; there is no compelling reason to send a
// cancellation message to the server, since it will have already exhausted its
// allotted processing time.
@@ -312,12 +344,18 @@ where
match (pending_requests_status, canceled_requests_status) {
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
ready!(self.transport_pin_mut().poll_flush(cx)?);
ready!(self
.transport_pin_mut()
.poll_flush(cx)
.map_err(ChannelError::Transport)?);
Poll::Ready(None)
}
(ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self.transport_pin_mut().poll_flush(cx)?);
ready!(self
.transport_pin_mut()
.poll_flush(cx)
.map_err(ChannelError::Transport)?);
// Even if we fully flush, we return Pending, because we have no more requests
// or cancellations right now.
@@ -333,7 +371,7 @@ where
fn poll_next_request(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<DispatchRequest<Req, Resp>> {
) -> Poll<Option<Result<DispatchRequest<Req, Resp>, C::Error>>> {
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
tracing::info!(
"At in-flight request capacity ({}/{}).",
@@ -371,7 +409,7 @@ where
fn poll_next_cancellation(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<(context::Context, Span, u64)> {
) -> Poll<Option<Result<(context::Context, Span, u64), C::Error>>> {
ready!(self.ensure_writeable(cx)?);
loop {
@@ -390,14 +428,20 @@ where
/// Returns Ready if writing a message to the transport (i.e. via write_request or
/// write_cancel) would not fail due to a full buffer. If the transport is not ready to be
/// written to, flushes it until it is ready.
fn ensure_writeable<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
fn ensure_writeable<'a>(
self: &'a mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), C::Error>>> {
while self.transport_pin_mut().poll_ready(cx)?.is_pending() {
ready!(self.transport_pin_mut().poll_flush(cx)?);
}
Poll::Ready(Some(Ok(())))
}
fn poll_write_request<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
fn poll_write_request<'a>(
self: &'a mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), C::Error>>> {
let DispatchRequest {
ctx,
span,
@@ -435,7 +479,10 @@ where
Poll::Ready(Some(Ok(())))
}
fn poll_write_cancel<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
fn poll_write_cancel<'a>(
self: &'a mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), C::Error>>> {
let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
Some(triple) => triple,
None => return Poll::Ready(None),
@@ -461,18 +508,14 @@ impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
type Output = anyhow::Result<()>;
type Output = Result<(), ChannelError<C::Error>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
fn poll(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), ChannelError<C::Error>>> {
loop {
match (
self.as_mut()
.pump_read(cx)
.context("failed to read from transport")?,
self.as_mut()
.pump_write(cx)
.context("failed to write to transport")?,
) {
match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
(Poll::Ready(None), _) => {
tracing::info!("Shutdown: read half closed, so shutting down.");
return Poll::Ready(Ok(()));

View File

@@ -1,10 +1,9 @@
use crate::{
context,
util::{Compact, TimeUntil},
PollIo, Response, ServerError,
Response, ServerError,
};
use fnv::FnvHashMap;
use futures::ready;
use std::{
collections::hash_map,
io,
@@ -113,22 +112,21 @@ impl<Resp> InFlightRequests<Resp> {
/// Yields a request that has expired, completing it with a TimedOut error.
/// The caller should send cancellation messages for any yielded request ID.
pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
Some(Ok(expired)) => {
let request_id = expired.into_inner();
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::error!("DeadlineExceeded");
self.request_data.compact(0.1);
let _ = request_data
.response_completion
.send(Self::deadline_exceeded_error(request_id));
}
Some(Ok(request_id))
pub fn poll_expired(
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
self.deadlines.poll_expired(cx).map_ok(|expired| {
let request_id = expired.into_inner();
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::error!("DeadlineExceeded");
self.request_data.compact(0.1);
let _ = request_data
.response_completion
.send(Self::deadline_exceeded_error(request_id));
}
Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
None => None,
request_id
})
}

View File

@@ -177,7 +177,7 @@
//!
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
//! // that takes a config and any Transport as input.
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
//!
//! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
//! // args as defined, with the addition of a Context, which is always the first arg. The Context
@@ -304,7 +304,7 @@ pub use crate::transport::sealed::Transport;
use anyhow::Context as _;
use futures::task::*;
use std::{fmt::Display, io, time::SystemTime};
use std::{error::Error, fmt::Display, io, time::SystemTime};
/// A message from a client to a server.
#[derive(Debug)]
@@ -388,7 +388,6 @@ impl<T> Request<T> {
}
}
pub(crate) type PollIo<T> = Poll<Option<io::Result<T>>>;
pub(crate) trait PollContext<T> {
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
where
@@ -400,7 +399,10 @@ pub(crate) trait PollContext<T> {
F: FnOnce() -> C;
}
impl<T> PollContext<T> for PollIo<T> {
impl<T, E> PollContext<T> for Poll<Option<Result<T, E>>>
where
E: Error + Send + Sync + 'static,
{
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
where
C: Display + Send + Sync + 'static,

View File

@@ -8,7 +8,7 @@
#![deny(missing_docs)]
use futures::{prelude::*, ready, task::*};
use futures::{prelude::*, task::*};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use std::{error::Error, io, pin::Pin};
@@ -42,15 +42,12 @@ where
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
let next = ready!(self.project().inner.poll_next(cx)).map(|next| {
next.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while reading from transport: {}", e.into()),
)
})
});
Poll::Ready(next)
self.project().inner.poll_next(cx).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while reading from transport: {}", e.into()),
)
})
}
}
@@ -66,13 +63,11 @@ where
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_ready(cx).map(|ready| {
ready.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while readying write half of transport: {}", e.into()),
)
})
self.project().inner.poll_ready(cx).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while readying write half of transport: {}", e.into()),
)
})
}
@@ -86,24 +81,20 @@ where
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_flush(cx).map(|ready| {
ready.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while flushing transport: {}", e.into()),
)
})
self.project().inner.poll_flush(cx).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while flushing transport: {}", e.into()),
)
})
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_close(cx).map(|ready| {
ready.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while closing write half of transport: {}", e.into()),
)
})
self.project().inner.poll_close(cx).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while closing write half of transport: {}", e.into()),
)
})
}
}

View File

@@ -8,7 +8,7 @@
use crate::{
context::{self, SpanExt},
trace, ClientMessage, PollIo, Request, Response, Transport,
trace, ClientMessage, Request, Response, Transport,
};
use futures::{
future::{AbortRegistration, Abortable},
@@ -19,7 +19,10 @@ use futures::{
};
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
use pin_project::pin_project;
use std::{convert::TryFrom, fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
use std::{
convert::TryFrom, error::Error, fmt, hash::Hash, marker::PhantomData, pin::Pin,
time::SystemTime,
};
use tokio::sync::mpsc;
use tracing::{info_span, instrument::Instrument, Span};
@@ -248,10 +251,10 @@ where
/// Tells the Channel that request with ID `request_id` is being handled.
/// The request will be tracked until a response with the same ID is sent
/// to the Channel.
/// to the Channel or the deadline expires, whichever happens first.
fn start_request(
self: Pin<&mut Self>,
id: u64,
request_id: u64,
deadline: SystemTime,
span: Span,
) -> Result<AbortRegistration, AlreadyExistsError>;
@@ -292,11 +295,25 @@ where
}
}
/// Critical errors that result in a Channel disconnecting.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// An error occurred reading from, or writing to, the transport.
#[error("an error occurred in the transport: {0}")]
Transport(#[source] E),
/// An error occurred while polling expired requests.
#[error("an error occurred while polling expired requests: {0}")]
Timer(#[source] tokio::time::error::Error),
}
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
type Item = io::Result<Request<Req>>;
type Item = Result<Request<Req>, ChannelError<T::Error>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
enum ReceiverStatus {
@@ -307,7 +324,11 @@ where
use ReceiverStatus::*;
loop {
let expiration_status = match self.in_flight_requests_mut().poll_expired(cx)? {
let expiration_status = match self
.in_flight_requests_mut()
.poll_expired(cx)
.map_err(ChannelError::Timer)?
{
// No need to send a response, since the client wouldn't be waiting for one
// anymore.
Poll::Ready(Some(_)) => Ready,
@@ -315,7 +336,11 @@ where
Poll::Pending => Pending,
};
let request_status = match self.transport_pin_mut().poll_next(cx)? {
let request_status = match self
.transport_pin_mut()
.poll_next(cx)
.map_err(ChannelError::Transport)?
{
Poll::Ready(Some(message)) => match message {
ClientMessage::Request(request) => {
return Poll::Ready(Some(Ok(request)));
@@ -349,11 +374,15 @@ where
impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
T::Error: Error,
{
type Error = io::Error;
type Error = ChannelError<T::Error>;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().transport.poll_ready(cx)
self.project()
.transport
.poll_ready(cx)
.map_err(ChannelError::Transport)
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
@@ -365,7 +394,10 @@ where
{
let _entered = span.enter();
tracing::info!("SendResponse");
self.project().transport.start_send(response)
self.project()
.transport
.start_send(response)
.map_err(ChannelError::Transport)
} else {
// If the request isn't tracked anymore, there's no need to send the response.
Ok(())
@@ -373,11 +405,17 @@ where
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().transport.poll_flush(cx)
self.project()
.transport
.poll_flush(cx)
.map_err(ChannelError::Transport)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().transport.poll_close(cx)
self.project()
.transport
.poll_close(cx)
.map_err(ChannelError::Transport)
}
}
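(Editorial aside, not part of the diff.) With BaseChannel's Stream and Sink error type now ChannelError<T::Error>, code that polls a channel by hand can tell transport failures apart from deadline-timer failures. A rough sketch for illustration only, using the two variants added in this file and assuming a pinned channel with futures::StreamExt in scope:

while let Some(next) = channel.next().await {
    match next {
        Ok(request) => {
            // Hand the Request<Req> off for serving.
        }
        Err(ChannelError::Transport(e)) => {
            tracing::warn!("transport error: {}", e);
            break;
        }
        Err(ChannelError::Timer(e)) => {
            tracing::warn!("deadline timer error: {}", e);
            break;
        }
    }
}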
@@ -409,13 +447,13 @@ where
fn start_request(
self: Pin<&mut Self>,
id: u64,
request_id: u64,
deadline: SystemTime,
span: Span,
) -> Result<AbortRegistration, AlreadyExistsError> {
self.project()
.in_flight_requests
.start_request(id, deadline, span)
.start_request(request_id, deadline, span)
}
}
@@ -453,7 +491,7 @@ where
fn pump_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<InFlightRequest<C::Req, C::Resp>> {
) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
loop {
match ready!(self.channel_pin_mut().poll_next(cx)?) {
Some(mut request) => {
@@ -508,7 +546,7 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
read_half_closed: bool,
) -> PollIo<()> {
) -> Poll<Option<Result<(), C::Error>>> {
match self.as_mut().poll_next_response(cx)? {
Poll::Ready(Some(response)) => {
// A Ready result from poll_next_response means the Channel is ready to be written
@@ -544,7 +582,7 @@ where
fn poll_next_response(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<Response<C::Resp>> {
) -> Poll<Option<Result<Response<C::Resp>, C::Error>>> {
ready!(self.ensure_writeable(cx)?);
match ready!(self.pending_responses_mut().poll_recv(cx)) {
@@ -558,7 +596,10 @@ where
/// Returns Ready if writing a message to the Channel would not fail due to a full buffer. If
/// the Channel is not ready to be written to, flushes it until it is ready.
fn ensure_writeable<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
fn ensure_writeable<'a>(
self: &'a mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), C::Error>>> {
while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
ready!(self.channel_pin_mut().poll_flush(cx)?);
}
@@ -620,6 +661,7 @@ impl<Req, Res> InFlightRequest<Req, Res> {
span.record("otel.name", &method.unwrap_or(""));
let _ = Abortable::new(
async move {
tracing::info!("BeginRequest");
let response = serve.serve(context, message).await;
tracing::info!("CompleteRequest");
let response = Response {
@@ -640,7 +682,7 @@ impl<C> Stream for Requests<C>
where
C: Channel,
{
type Item = io::Result<InFlightRequest<C::Req, C::Resp>>;
type Item = Result<InFlightRequest<C::Req, C::Resp>, C::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
@@ -764,7 +806,7 @@ where
});
}
Err(e) => {
tracing::info!("Requests stream errored out: {}", e);
tracing::warn!("Requests stream errored out: {}", e);
break;
}
}

View File

@@ -1,15 +1,8 @@
use crate::{
util::{Compact, TimeUntil},
PollIo,
};
use crate::util::{Compact, TimeUntil};
use fnv::FnvHashMap;
use futures::{
future::{AbortHandle, AbortRegistration},
ready,
};
use futures::future::{AbortHandle, AbortRegistration};
use std::{
collections::hash_map,
io,
task::{Context, Poll},
time::SystemTime,
};
@@ -101,22 +94,21 @@ impl InFlightRequests {
}
/// Yields a request that has expired, aborting any ongoing processing of that request.
pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
Some(Ok(expired)) => {
if let Some(RequestData {
abort_handle, span, ..
}) = self.request_data.remove(expired.get_ref())
{
let _entered = span.enter();
self.request_data.compact(0.1);
abort_handle.abort();
tracing::error!("DeadlineExceeded");
}
Some(Ok(expired.into_inner()))
pub fn poll_expired(
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
self.deadlines.poll_expired(cx).map_ok(|expired| {
if let Some(RequestData {
abort_handle, span, ..
}) = self.request_data.remove(expired.get_ref())
{
let _entered = span.enter();
self.request_data.compact(0.1);
abort_handle.abort();
tracing::error!("DeadlineExceeded");
}
Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
None => None,
expired.into_inner()
})
}
}

View File

@@ -81,21 +81,24 @@ impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
where
C: Channel,
{
type Error = io::Error;
type Error = C::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
fn start_send(
self: Pin<&mut Self>,
item: Response<<C as Channel>::Resp>,
) -> Result<(), Self::Error> {
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx)
}
}

View File

@@ -9,22 +9,32 @@
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`](sealed::Transport)
//! can be plugged in, using whatever protocol it wants.
use futures::prelude::*;
use std::io;
pub mod channel;
pub(crate) mod sealed {
use super::*;
use futures::prelude::*;
use std::error::Error;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport<SinkItem, Item>:
Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>
pub trait Transport<SinkItem, Item>
where
Self: Stream<Item = Result<Item, <Self as Sink<SinkItem>>::Error>>,
Self: Sink<SinkItem, Error = <Self as Transport<SinkItem, Item>>::TransportError>,
<Self as Sink<SinkItem>>::Error: Error,
{
/// Where clauses on associated types are not elaborated for users of the trait; this
/// associated type lets users who bound types by `Transport` avoid having to explicitly add
/// `T::Error: Error` to their own bounds.
type TransportError: Error + Send + Sync + 'static;
}
impl<T, SinkItem, Item> Transport<SinkItem, Item> for T where
T: Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error> + ?Sized
impl<T, SinkItem, Item, E> Transport<SinkItem, Item> for T
where
T: ?Sized,
T: Stream<Item = Result<Item, E>>,
T: Sink<SinkItem, Error = E>,
T::Error: Error + Send + Sync + 'static,
{
type TransportError = E;
}
}
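(Editorial aside, not part of the diff.) Because the blanket impl now covers any Stream/Sink pair whose shared error type is Error + Send + Sync + 'static, downstream code can stay generic over the transport's error instead of assuming io::Error. A hypothetical helper for illustration only; send_one is not part of the crate:

use futures::prelude::*;
use tarpc::Transport;

// Forwards a single item, surfacing whatever error type the transport defines.
// Item cannot be inferred from the arguments, so callers name it explicitly,
// e.g. send_one::<_, _, Response<Resp>>(transport, item).
async fn send_one<T, SinkItem, Item>(
    mut transport: T,
    item: SinkItem,
) -> Result<(), T::TransportError>
where
    T: Transport<SinkItem, Item> + Unpin,
{
    // SinkExt::send returns the transport's own error, which the trait
    // constrains to equal TransportError, so no io::Error conversion is needed.
    transport.send(item).await
}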

View File

@@ -6,13 +6,19 @@
//! Transports backed by in-memory channels.
use crate::PollIo;
use futures::{task::*, Sink, Stream};
use pin_project::pin_project;
use std::io;
use std::pin::Pin;
use std::{error::Error, pin::Pin};
use tokio::sync::mpsc;
/// Errors that occur in the sending or receiving of messages over a channel.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
/// An error occurred sending over the channel.
#[error("an error occurred sending over the channel")]
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
}
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
/// [`Sink`].
pub fn unbounded<SinkItem, Item>() -> (
@@ -36,28 +42,33 @@ pub struct UnboundedChannel<Item, SinkItem> {
}
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
type Item = Result<Item, io::Error>;
type Item = Result<Item, ChannelError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, ChannelError>>> {
self.rx.poll_recv(cx).map(|option| option.map(Ok))
}
}
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
type Error = io::Error;
const CLOSED_MESSAGE: &str = "the channel is closed and cannot accept new items for sending";
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
type Error = ChannelError;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(if self.tx.is_closed() {
Err(io::Error::from(io::ErrorKind::NotConnected))
Err(ChannelError::Send(CLOSED_MESSAGE.into()))
} else {
Ok(())
})
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
self.tx
.send(item)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
.map_err(|_| ChannelError::Send(CLOSED_MESSAGE.into()))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@@ -65,7 +76,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// UnboundedSender can't initiate closure.
Poll::Ready(Ok(()))
}
@@ -93,52 +104,45 @@ pub struct Channel<Item, SinkItem> {
}
impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
type Item = Result<Item, io::Error>;
type Item = Result<Item, ChannelError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, ChannelError>>> {
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
}
}
impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
type Error = io::Error;
type Error = ChannelError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_ready(cx)
.map_err(convert_send_err_to_io)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
self.project()
.tx
.start_send(item)
.map_err(convert_send_err_to_io)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_flush(cx)
.map_err(convert_send_err_to_io)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_close(cx)
.map_err(convert_send_err_to_io)
}
}
fn convert_send_err_to_io(e: futures::channel::mpsc::SendError) -> io::Error {
if e.is_disconnected() {
io::Error::from(io::ErrorKind::NotConnected)
} else if e.is_full() {
io::Error::from(io::ErrorKind::WouldBlock)
} else {
io::Error::new(io::ErrorKind::Other, e)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
}
@@ -148,13 +152,23 @@ mod tests {
use crate::{
client, context,
server::{BaseChannel, Incoming},
transport,
transport::{
self,
channel::{Channel, UnboundedChannel},
},
};
use assert_matches::assert_matches;
use futures::{prelude::*, stream};
use std::io;
use tracing::trace;
#[test]
fn ensure_is_transport() {
fn is_transport<SinkItem, Item, T: crate::Transport<SinkItem, Item>>() {}
is_transport::<(), (), UnboundedChannel<(), ()>>();
is_transport::<(), (), Channel<(), ()>>();
}
#[tokio::test]
async fn integration() -> io::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
@@ -173,7 +187,7 @@ mod tests {
}),
);
let client = client::new(client::Config::default(), client_channel).spawn()?;
let client = client::new(client::Config::default(), client_channel).spawn();
let response1 = client.call(context::current(), "", "123".into()).await?;
let response2 = client.call(context::current(), "", "abc".into()).await?;

View File

@@ -45,7 +45,7 @@ async fn test_call() -> io::Result<()> {
);
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
let client = ColorProtocolClient::new(client::Config::default(), transport).spawn()?;
let client = ColorProtocolClient::new(client::Config::default(), transport).spawn();
let color = client
.get_opposite_color(context::current(), TestData::White)

View File

@@ -3,10 +3,7 @@ use futures::{
future::{join_all, ready, Ready},
prelude::*,
};
use std::{
io,
time::{Duration, SystemTime},
};
use std::time::{Duration, SystemTime};
use tarpc::{
client::{self},
context,
@@ -39,7 +36,7 @@ impl Service for Server {
}
#[tokio::test]
async fn sequential() -> io::Result<()> {
async fn sequential() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
@@ -50,7 +47,7 @@ async fn sequential() -> io::Result<()> {
.execute(Server.serve()),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
let client = ServiceClient::new(client::Config::default(), tx).spawn();
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
@@ -61,7 +58,7 @@ async fn sequential() -> io::Result<()> {
}
#[tokio::test]
async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
#[tarpc_plugins::service]
trait Loop {
async fn r#loop();
@@ -89,9 +86,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
// Set up a client that initiates a long-lived request.
// The request will complete in error when the server drops the connection.
tokio::spawn(async move {
let client = LoopClient::new(client::Config::default(), tx)
.spawn()
.unwrap();
let client = LoopClient::new(client::Config::default(), tx).spawn();
let mut ctx = context::current();
ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60);
@@ -113,7 +108,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
#[tokio::test]
async fn serde() -> io::Result<()> {
async fn serde() -> anyhow::Result<()> {
use tarpc::serde_transport;
use tokio_serde::formats::Json;
@@ -130,7 +125,7 @@ async fn serde() -> io::Result<()> {
);
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
let client = ServiceClient::new(client::Config::default(), transport).spawn()?;
let client = ServiceClient::new(client::Config::default(), transport).spawn();
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
@@ -142,7 +137,7 @@ async fn serde() -> io::Result<()> {
}
#[tokio::test]
async fn concurrent() -> io::Result<()> {
async fn concurrent() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
@@ -152,7 +147,7 @@ async fn concurrent() -> io::Result<()> {
.execute(Server.serve()),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
let client = ServiceClient::new(client::Config::default(), tx).spawn();
let req1 = client.add(context::current(), 1, 2);
let req2 = client.add(context::current(), 3, 4);
@@ -166,7 +161,7 @@ async fn concurrent() -> io::Result<()> {
}
#[tokio::test]
async fn concurrent_join() -> io::Result<()> {
async fn concurrent_join() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
@@ -176,7 +171,7 @@ async fn concurrent_join() -> io::Result<()> {
.execute(Server.serve()),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
let client = ServiceClient::new(client::Config::default(), tx).spawn();
let req1 = client.add(context::current(), 1, 2);
let req2 = client.add(context::current(), 3, 4);
@@ -191,7 +186,7 @@ async fn concurrent_join() -> io::Result<()> {
}
#[tokio::test]
async fn concurrent_join_all() -> io::Result<()> {
async fn concurrent_join_all() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
@@ -201,7 +196,7 @@ async fn concurrent_join_all() -> io::Result<()> {
.execute(Server.serve()),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
let client = ServiceClient::new(client::Config::default(), tx).spawn();
let req1 = client.add(context::current(), 1, 2);
let req2 = client.add(context::current(), 3, 4);
@@ -214,7 +209,7 @@ async fn concurrent_join_all() -> io::Result<()> {
}
#[tokio::test]
async fn counter() -> io::Result<()> {
async fn counter() -> anyhow::Result<()> {
#[tarpc::service]
trait Counter {
async fn count() -> u32;
@@ -241,7 +236,7 @@ async fn counter() -> io::Result<()> {
}
});
let client = CounterClient::new(client::Config::default(), tx).spawn()?;
let client = CounterClient::new(client::Config::default(), tx).spawn();
assert_matches!(client.count(context::current()).await, Ok(1));
assert_matches!(client.count(context::current()).await, Ok(2));