From d0c11a6efaa6c86b6735feb36f4979e050691b68 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Tue, 20 Apr 2021 18:29:55 -0700 Subject: [PATCH] Change RPC error type from io::Error => RpcError. Becaue tarpc is a library, not an application, it should strive to use structured errors in its API so that users have maximal flexibility in how they handle errors. io::Error makes that hard, because it is a kitchen-sink error type. RPCs in particular only have 3 classes of errors: - The connection breaks. - The request expires. - The server decides not to process the request. (Of course, RPCs themselves can have application-specific errors, but from the perspective of the RPC library, those can be classified as successful responsees). --- plugins/src/lib.rs | 2 +- tarpc/examples/custom_transport.rs | 14 +++---- tarpc/examples/readme.rs | 3 +- tarpc/src/client.rs | 56 ++++++++++++++++++-------- tarpc/src/client/in_flight_requests.rs | 27 +++++-------- tarpc/src/lib.rs | 19 +++------ tarpc/src/server/throttle.rs | 2 +- tarpc/src/transport/channel.rs | 2 +- tarpc/tests/dataservice.rs | 3 +- 9 files changed, 68 insertions(+), 60 deletions(-) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 373d575..2645398 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -746,7 +746,7 @@ impl<'a> ServiceGenerator<'a> { #[allow(unused)] #( #method_attrs )* #vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*) - -> impl std::future::Future> + '_ { + -> impl std::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, #request_names, request); async move { diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 8ead447..118cae8 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -1,4 +1,3 @@ -use futures::future; use tarpc::context::Context; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; @@ -14,16 +13,13 @@ pub trait PingService { #[derive(Clone)] struct Service; +#[tarpc::server] impl PingService for Service { - type PingFut = future::Ready<()>; - - fn ping(self, _: Context) -> Self::PingFut { - future::ready(()) - } + async fn ping(self, _: Context) {} } #[tokio::main] -async fn main() -> std::io::Result<()> { +async fn main() -> anyhow::Result<()> { let bind_addr = "/tmp/tarpc_on_unix_example.sock"; let _ = std::fs::remove_file(bind_addr); @@ -46,5 +42,7 @@ async fn main() -> std::io::Result<()> { PingServiceClient::new(Default::default(), transport) .spawn() .ping(tarpc::context::current()) - .await + .await?; + + Ok(()) } diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index ab455af..b548ddf 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,7 +5,6 @@ // https://opensource.org/licenses/MIT. use futures::future::{self, Ready}; -use std::io; use tarpc::{ client, context, server::{self, Channel}, @@ -35,7 +34,7 @@ impl World for HelloServer { } #[tokio::main] -async fn main() -> io::Result<()> { +async fn main() -> anyhow::Result<()> { let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 6b169c2..2bb32a5 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -8,14 +8,14 @@ mod in_flight_requests; -use crate::{context, trace, ClientMessage, Request, Response, Transport}; +use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport}; use futures::{prelude::*, ready, stream::Fuse, task::*}; -use in_flight_requests::InFlightRequests; +use in_flight_requests::{DeadlineExceededError, InFlightRequests}; use pin_project::pin_project; use std::{ convert::TryFrom, error::Error, - fmt, io, mem, + fmt, mem, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, @@ -125,7 +125,7 @@ impl Channel { mut ctx: context::Context, request_name: &str, request: Req, - ) -> io::Result { + ) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::warn!( @@ -156,7 +156,7 @@ impl Channel { response_completion, }) .await - .map_err(|mpsc::error::SendError(_)| io::Error::from(io::ErrorKind::ConnectionReset))?; + .map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?; response_guard.response().await } } @@ -164,23 +164,44 @@ impl Channel { /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver, DeadlineExceededError>>, cancellation: &'a RequestCancellation, request_id: u64, } +/// An error that can occur in the processing of an RPC. This is not request-specific errors but +/// rather cross-cutting errors that can always occur. +#[derive(thiserror::Error, Debug)] +pub enum RpcError { + /// The client disconnected from the server. + #[error("the client disconnected from the server")] + Disconnected, + /// The request exceeded its deadline. + #[error("the request exceeded its deadline")] + DeadlineExceeded, + /// The server aborted request processing. + #[error("the server aborted request processing")] + Server(#[from] ServerError), +} + +impl From for RpcError { + fn from(_: DeadlineExceededError) -> Self { + RpcError::DeadlineExceeded + } +} + impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> io::Result { + async fn response(mut self) -> Result { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. mem::forget(self); match response { - Ok(resp) => Ok(resp.message?), + Ok(resp) => Ok(resp?.message?), Err(oneshot::error::RecvError { .. }) => { // The oneshot is Canceled when the dispatch task ends. In that case, // there's nothing listening on the other side, so there's no point in // propagating cancellation. - Err(io::Error::from(io::ErrorKind::ConnectionReset)) + Err(RpcError::Disconnected) } } } @@ -549,7 +570,7 @@ struct DispatchRequest { pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender, DeadlineExceededError>>, } /// Sends request cancellation signals. @@ -596,7 +617,10 @@ mod tests { RequestDispatch, ResponseGuard, }; use crate::{ - client::{in_flight_requests::InFlightRequests, Config}, + client::{ + in_flight_requests::{DeadlineExceededError, InFlightRequests}, + Config, + }, context, transport::{self, channel::UnboundedChannel}, ClientMessage, Response, @@ -630,7 +654,7 @@ mod tests { .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Response { request_id: 0, message: Ok(resp) }) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp"); } #[tokio::test] @@ -651,10 +675,10 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Response { + tx.send(Ok(Response { request_id: 0, message: Ok("well done"), - }) + })) .unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { @@ -799,8 +823,8 @@ mod tests { async fn send_request<'a>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender, DeadlineExceededError>>, + response: &'a mut oneshot::Receiver, DeadlineExceededError>>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index ab9941f..0758691 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,12 +1,11 @@ use crate::{ context, util::{Compact, TimeUntil}, - Response, ServerError, + Response, }; use fnv::FnvHashMap; use std::{ collections::hash_map, - io, task::{Context, Poll}, }; use tokio::sync::oneshot; @@ -29,11 +28,17 @@ impl Default for InFlightRequests { } } +/// The request exceeded its deadline. +#[derive(thiserror::Error, Debug)] +#[non_exhaustive] +#[error("the request exceeded its deadline")] +pub struct DeadlineExceededError; + #[derive(Debug)] struct RequestData { ctx: context::Context, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender, DeadlineExceededError>>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -60,7 +65,7 @@ impl InFlightRequests { request_id: u64, ctx: context::Context, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender, DeadlineExceededError>>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -85,7 +90,7 @@ impl InFlightRequests { tracing::info!("ReceiveResponse"); self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); - let _ = request_data.response_completion.send(response); + let _ = request_data.response_completion.send(Ok(response)); return true; } @@ -124,19 +129,9 @@ impl InFlightRequests { self.request_data.compact(0.1); let _ = request_data .response_completion - .send(Self::deadline_exceeded_error(request_id)); + .send(Err(DeadlineExceededError)); } request_id }) } - - fn deadline_exceeded_error(request_id: u64) -> Response { - Response { - request_id, - message: Err(ServerError { - kind: io::ErrorKind::TimedOut, - detail: Some("Client dropped expired request.".to_string()), - }), - } - } } diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 2bced52..ae562e2 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -67,6 +67,7 @@ //! your `Cargo.toml`: //! //! ```toml +//! anyhow = "1.0" //! futures = "1.0" //! tarpc = { version = "0.26", features = ["tokio1"] } //! tokio = { version = "1.0", features = ["macros"] } @@ -89,7 +90,6 @@ //! client, context, //! server::{self, Incoming}, //! }; -//! use std::io; //! //! // This is the service definition. It looks a lot like a trait definition. //! // It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -113,7 +113,6 @@ //! # client, context, //! # server::{self, Incoming}, //! # }; -//! # use std::io; //! # // This is the service definition. It looks a lot like a trait definition. //! # // It defines one RPC, hello, which takes one arg, name, and returns a String. //! # #[tarpc::service] @@ -153,7 +152,6 @@ //! # client, context, //! # server::{self, Channel}, //! # }; -//! # use std::io; //! # // This is the service definition. It looks a lot like a trait definition. //! # // It defines one RPC, hello, which takes one arg, name, and returns a String. //! # #[tarpc::service] @@ -177,7 +175,7 @@ //! # fn main() {} //! # #[cfg(feature = "tokio1")] //! #[tokio::main] -//! async fn main() -> io::Result<()> { +//! async fn main() -> anyhow::Result<()> { //! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); //! //! let server = server::BaseChannel::with_defaults(server_transport); @@ -364,8 +362,9 @@ pub struct Response { pub message: Result, } -/// An error response from a server to a client. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +/// An error indicating the server aborted the request early, e.g., due to request throttling. +#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] +#[error("{kind:?}: {detail}")] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct ServerError { @@ -380,13 +379,7 @@ pub struct ServerError { /// The type of error that occurred to fail the request. pub kind: io::ErrorKind, /// A message describing more detail about the error that occurred. - pub detail: Option, -} - -impl From for io::Error { - fn from(e: ServerError) -> io::Error { - io::Error::new(e.kind, e.detail.unwrap_or_default()) - } + pub detail: String, } impl Request { diff --git a/tarpc/src/server/throttle.rs b/tarpc/src/server/throttle.rs index 59a6c44..477a920 100644 --- a/tarpc/src/server/throttle.rs +++ b/tarpc/src/server/throttle.rs @@ -66,7 +66,7 @@ where request_id: request.id, message: Err(ServerError { kind: io::ErrorKind::WouldBlock, - detail: Some("Server throttled the request.".into()), + detail: "Server throttled the request.".into(), }), })?; } diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index d131886..f6dee17 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -170,7 +170,7 @@ mod tests { } #[tokio::test] - async fn integration() -> io::Result<()> { + async fn integration() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); let (client_channel, server_channel) = transport::channel::unbounded(); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 7e37dae..4cc52d6 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,5 +1,4 @@ use futures::prelude::*; -use std::io; use tarpc::serde_transport; use tarpc::{ client, context, @@ -33,7 +32,7 @@ impl ColorProtocol for ColorServer { } #[tokio::test] -async fn test_call() -> io::Result<()> { +async fn test_call() -> anyhow::Result<()> { let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?; let addr = transport.local_addr(); tokio::spawn(