Change RPC error type from io::Error => RpcError.

Becaue tarpc is a library, not an application, it should strive to
use structured errors in its API so that users have maximal flexibility
in how they handle errors. io::Error makes that hard, because it is a
kitchen-sink error type.

RPCs in particular only have 3 classes of errors:

- The connection breaks.
- The request expires.
- The server decides not to process the request.

(Of course, RPCs themselves can have application-specific errors, but
from the perspective of the RPC library, those can be classified as
successful responsees).
This commit is contained in:
Tim Kuehn
2021-04-20 18:29:55 -07:00
parent 82c4da1743
commit d0c11a6efa
9 changed files with 68 additions and 60 deletions

View File

@@ -746,7 +746,7 @@ impl<'a> ServiceGenerator<'a> {
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = self.0.call(ctx, #request_names, request);
async move {

View File

@@ -1,4 +1,3 @@
use futures::future;
use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
@@ -14,16 +13,13 @@ pub trait PingService {
#[derive(Clone)]
struct Service;
#[tarpc::server]
impl PingService for Service {
type PingFut = future::Ready<()>;
fn ping(self, _: Context) -> Self::PingFut {
future::ready(())
}
async fn ping(self, _: Context) {}
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
async fn main() -> anyhow::Result<()> {
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
let _ = std::fs::remove_file(bind_addr);
@@ -46,5 +42,7 @@ async fn main() -> std::io::Result<()> {
PingServiceClient::new(Default::default(), transport)
.spawn()
.ping(tarpc::context::current())
.await
.await?;
Ok(())
}

View File

@@ -5,7 +5,6 @@
// https://opensource.org/licenses/MIT.
use futures::future::{self, Ready};
use std::io;
use tarpc::{
client, context,
server::{self, Channel},
@@ -35,7 +34,7 @@ impl World for HelloServer {
}
#[tokio::main]
async fn main() -> io::Result<()> {
async fn main() -> anyhow::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);

View File

@@ -8,14 +8,14 @@
mod in_flight_requests;
use crate::{context, trace, ClientMessage, Request, Response, Transport};
use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use in_flight_requests::InFlightRequests;
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
use pin_project::pin_project;
use std::{
convert::TryFrom,
error::Error,
fmt, io, mem,
fmt, mem,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
@@ -125,7 +125,7 @@ impl<Req, Resp> Channel<Req, Resp> {
mut ctx: context::Context,
request_name: &str,
request: Req,
) -> io::Result<Resp> {
) -> Result<Resp, RpcError> {
let span = Span::current();
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
tracing::warn!(
@@ -156,7 +156,7 @@ impl<Req, Resp> Channel<Req, Resp> {
response_completion,
})
.await
.map_err(|mpsc::error::SendError(_)| io::Error::from(io::ErrorKind::ConnectionReset))?;
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
response_guard.response().await
}
}
@@ -164,23 +164,44 @@ impl<Req, Resp> Channel<Req, Resp> {
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
struct ResponseGuard<'a, Resp> {
response: &'a mut oneshot::Receiver<Response<Resp>>,
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
cancellation: &'a RequestCancellation,
request_id: u64,
}
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
/// rather cross-cutting errors that can always occur.
#[derive(thiserror::Error, Debug)]
pub enum RpcError {
/// The client disconnected from the server.
#[error("the client disconnected from the server")]
Disconnected,
/// The request exceeded its deadline.
#[error("the request exceeded its deadline")]
DeadlineExceeded,
/// The server aborted request processing.
#[error("the server aborted request processing")]
Server(#[from] ServerError),
}
impl From<DeadlineExceededError> for RpcError {
fn from(_: DeadlineExceededError) -> Self {
RpcError::DeadlineExceeded
}
}
impl<Resp> ResponseGuard<'_, Resp> {
async fn response(mut self) -> io::Result<Resp> {
async fn response(mut self) -> Result<Resp, RpcError> {
let response = (&mut self.response).await;
// Cancel drop logic once a response has been received.
mem::forget(self);
match response {
Ok(resp) => Ok(resp.message?),
Ok(resp) => Ok(resp?.message?),
Err(oneshot::error::RecvError { .. }) => {
// The oneshot is Canceled when the dispatch task ends. In that case,
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Err(io::Error::from(io::ErrorKind::ConnectionReset))
Err(RpcError::Disconnected)
}
}
}
@@ -549,7 +570,7 @@ struct DispatchRequest<Req, Resp> {
pub span: Span,
pub request_id: u64,
pub request: Req,
pub response_completion: oneshot::Sender<Response<Resp>>,
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
}
/// Sends request cancellation signals.
@@ -596,7 +617,10 @@ mod tests {
RequestDispatch, ResponseGuard,
};
use crate::{
client::{in_flight_requests::InFlightRequests, Config},
client::{
in_flight_requests::{DeadlineExceededError, InFlightRequests},
Config,
},
context,
transport::{self, channel::UnboundedChannel},
ClientMessage, Response,
@@ -630,7 +654,7 @@ mod tests {
.await
.unwrap();
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
assert_matches!(rx.try_recv(), Ok(Response { request_id: 0, message: Ok(resp) }) if resp == "Resp");
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
}
#[tokio::test]
@@ -651,10 +675,10 @@ mod tests {
async fn dispatch_response_doesnt_cancel_after_complete() {
let (cancellation, mut canceled_requests) = cancellations();
let (tx, mut response) = oneshot::channel();
tx.send(Response {
tx.send(Ok(Response {
request_id: 0,
message: Ok("well done"),
})
}))
.unwrap();
// resp's drop() is run, but should not send a cancel message.
ResponseGuard {
@@ -799,8 +823,8 @@ mod tests {
async fn send_request<'a>(
channel: &'a mut Channel<String, String>,
request: &str,
response_completion: oneshot::Sender<Response<String>>,
response: &'a mut oneshot::Receiver<Response<String>>,
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
) -> ResponseGuard<'a, String> {
let request_id =
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();

View File

@@ -1,12 +1,11 @@
use crate::{
context,
util::{Compact, TimeUntil},
Response, ServerError,
Response,
};
use fnv::FnvHashMap;
use std::{
collections::hash_map,
io,
task::{Context, Poll},
};
use tokio::sync::oneshot;
@@ -29,11 +28,17 @@ impl<Resp> Default for InFlightRequests<Resp> {
}
}
/// The request exceeded its deadline.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[error("the request exceeded its deadline")]
pub struct DeadlineExceededError;
#[derive(Debug)]
struct RequestData<Resp> {
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Response<Resp>>,
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
/// The key to remove the timer for the request's deadline.
deadline_key: delay_queue::Key,
}
@@ -60,7 +65,7 @@ impl<Resp> InFlightRequests<Resp> {
request_id: u64,
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Response<Resp>>,
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
) -> Result<(), AlreadyExistsError> {
match self.request_data.entry(request_id) {
hash_map::Entry::Vacant(vacant) => {
@@ -85,7 +90,7 @@ impl<Resp> InFlightRequests<Resp> {
tracing::info!("ReceiveResponse");
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
let _ = request_data.response_completion.send(response);
let _ = request_data.response_completion.send(Ok(response));
return true;
}
@@ -124,19 +129,9 @@ impl<Resp> InFlightRequests<Resp> {
self.request_data.compact(0.1);
let _ = request_data
.response_completion
.send(Self::deadline_exceeded_error(request_id));
.send(Err(DeadlineExceededError));
}
request_id
})
}
fn deadline_exceeded_error(request_id: u64) -> Response<Resp> {
Response {
request_id,
message: Err(ServerError {
kind: io::ErrorKind::TimedOut,
detail: Some("Client dropped expired request.".to_string()),
}),
}
}
}

View File

@@ -67,6 +67,7 @@
//! your `Cargo.toml`:
//!
//! ```toml
//! anyhow = "1.0"
//! futures = "1.0"
//! tarpc = { version = "0.26", features = ["tokio1"] }
//! tokio = { version = "1.0", features = ["macros"] }
@@ -89,7 +90,6 @@
//! client, context,
//! server::{self, Incoming},
//! };
//! use std::io;
//!
//! // This is the service definition. It looks a lot like a trait definition.
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
@@ -113,7 +113,6 @@
//! # client, context,
//! # server::{self, Incoming},
//! # };
//! # use std::io;
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! # #[tarpc::service]
@@ -153,7 +152,6 @@
//! # client, context,
//! # server::{self, Channel},
//! # };
//! # use std::io;
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! # #[tarpc::service]
@@ -177,7 +175,7 @@
//! # fn main() {}
//! # #[cfg(feature = "tokio1")]
//! #[tokio::main]
//! async fn main() -> io::Result<()> {
//! async fn main() -> anyhow::Result<()> {
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
//!
//! let server = server::BaseChannel::with_defaults(server_transport);
@@ -364,8 +362,9 @@ pub struct Response<T> {
pub message: Result<T, ServerError>,
}
/// An error response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
/// An error indicating the server aborted the request early, e.g., due to request throttling.
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
#[error("{kind:?}: {detail}")]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct ServerError {
@@ -380,13 +379,7 @@ pub struct ServerError {
/// The type of error that occurred to fail the request.
pub kind: io::ErrorKind,
/// A message describing more detail about the error that occurred.
pub detail: Option<String>,
}
impl From<ServerError> for io::Error {
fn from(e: ServerError) -> io::Error {
io::Error::new(e.kind, e.detail.unwrap_or_default())
}
pub detail: String,
}
impl<T> Request<T> {

View File

@@ -66,7 +66,7 @@ where
request_id: request.id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
detail: "Server throttled the request.".into(),
}),
})?;
}

View File

@@ -170,7 +170,7 @@ mod tests {
}
#[tokio::test]
async fn integration() -> io::Result<()> {
async fn integration() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (client_channel, server_channel) = transport::channel::unbounded();

View File

@@ -1,5 +1,4 @@
use futures::prelude::*;
use std::io;
use tarpc::serde_transport;
use tarpc::{
client, context,
@@ -33,7 +32,7 @@ impl ColorProtocol for ColorServer {
}
#[tokio::test]
async fn test_call() -> io::Result<()> {
async fn test_call() -> anyhow::Result<()> {
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
let addr = transport.local_addr();
tokio::spawn(