mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-01 09:03:48 +01:00
Change RPC error type from io::Error => RpcError.
Becaue tarpc is a library, not an application, it should strive to use structured errors in its API so that users have maximal flexibility in how they handle errors. io::Error makes that hard, because it is a kitchen-sink error type. RPCs in particular only have 3 classes of errors: - The connection breaks. - The request expires. - The server decides not to process the request. (Of course, RPCs themselves can have application-specific errors, but from the perspective of the RPC library, those can be classified as successful responsees).
This commit is contained in:
@@ -746,7 +746,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
#[allow(unused)]
|
||||
#( #method_attrs )*
|
||||
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
|
||||
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
|
||||
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
|
||||
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
||||
let resp = self.0.call(ctx, #request_names, request);
|
||||
async move {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use futures::future;
|
||||
use tarpc::context::Context;
|
||||
use tarpc::serde_transport as transport;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
@@ -14,16 +13,13 @@ pub trait PingService {
|
||||
#[derive(Clone)]
|
||||
struct Service;
|
||||
|
||||
#[tarpc::server]
|
||||
impl PingService for Service {
|
||||
type PingFut = future::Ready<()>;
|
||||
|
||||
fn ping(self, _: Context) -> Self::PingFut {
|
||||
future::ready(())
|
||||
}
|
||||
async fn ping(self, _: Context) {}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
|
||||
|
||||
let _ = std::fs::remove_file(bind_addr);
|
||||
@@ -46,5 +42,7 @@ async fn main() -> std::io::Result<()> {
|
||||
PingServiceClient::new(Default::default(), transport)
|
||||
.spawn()
|
||||
.ping(tarpc::context::current())
|
||||
.await
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use futures::future::{self, Ready};
|
||||
use std::io;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Channel},
|
||||
@@ -35,7 +34,7 @@ impl World for HelloServer {
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
|
||||
@@ -8,14 +8,14 @@
|
||||
|
||||
mod in_flight_requests;
|
||||
|
||||
use crate::{context, trace, ClientMessage, Request, Response, Transport};
|
||||
use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use in_flight_requests::InFlightRequests;
|
||||
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
error::Error,
|
||||
fmt, io, mem,
|
||||
fmt, mem,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
@@ -125,7 +125,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
mut ctx: context::Context,
|
||||
request_name: &str,
|
||||
request: Req,
|
||||
) -> io::Result<Resp> {
|
||||
) -> Result<Resp, RpcError> {
|
||||
let span = Span::current();
|
||||
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
|
||||
tracing::warn!(
|
||||
@@ -156,7 +156,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
response_completion,
|
||||
})
|
||||
.await
|
||||
.map_err(|mpsc::error::SendError(_)| io::Error::from(io::ErrorKind::ConnectionReset))?;
|
||||
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
|
||||
response_guard.response().await
|
||||
}
|
||||
}
|
||||
@@ -164,23 +164,44 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
/// A server response that is completed by request dispatch when the corresponding response
|
||||
/// arrives off the wire.
|
||||
struct ResponseGuard<'a, Resp> {
|
||||
response: &'a mut oneshot::Receiver<Response<Resp>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
cancellation: &'a RequestCancellation,
|
||||
request_id: u64,
|
||||
}
|
||||
|
||||
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
|
||||
/// rather cross-cutting errors that can always occur.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum RpcError {
|
||||
/// The client disconnected from the server.
|
||||
#[error("the client disconnected from the server")]
|
||||
Disconnected,
|
||||
/// The request exceeded its deadline.
|
||||
#[error("the request exceeded its deadline")]
|
||||
DeadlineExceeded,
|
||||
/// The server aborted request processing.
|
||||
#[error("the server aborted request processing")]
|
||||
Server(#[from] ServerError),
|
||||
}
|
||||
|
||||
impl From<DeadlineExceededError> for RpcError {
|
||||
fn from(_: DeadlineExceededError) -> Self {
|
||||
RpcError::DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
impl<Resp> ResponseGuard<'_, Resp> {
|
||||
async fn response(mut self) -> io::Result<Resp> {
|
||||
async fn response(mut self) -> Result<Resp, RpcError> {
|
||||
let response = (&mut self.response).await;
|
||||
// Cancel drop logic once a response has been received.
|
||||
mem::forget(self);
|
||||
match response {
|
||||
Ok(resp) => Ok(resp.message?),
|
||||
Ok(resp) => Ok(resp?.message?),
|
||||
Err(oneshot::error::RecvError { .. }) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
||||
Err(RpcError::Disconnected)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -549,7 +570,7 @@ struct DispatchRequest<Req, Resp> {
|
||||
pub span: Span,
|
||||
pub request_id: u64,
|
||||
pub request: Req,
|
||||
pub response_completion: oneshot::Sender<Response<Resp>>,
|
||||
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
}
|
||||
|
||||
/// Sends request cancellation signals.
|
||||
@@ -596,7 +617,10 @@ mod tests {
|
||||
RequestDispatch, ResponseGuard,
|
||||
};
|
||||
use crate::{
|
||||
client::{in_flight_requests::InFlightRequests, Config},
|
||||
client::{
|
||||
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
||||
Config,
|
||||
},
|
||||
context,
|
||||
transport::{self, channel::UnboundedChannel},
|
||||
ClientMessage, Response,
|
||||
@@ -630,7 +654,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
||||
assert_matches!(rx.try_recv(), Ok(Response { request_id: 0, message: Ok(resp) }) if resp == "Resp");
|
||||
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -651,10 +675,10 @@ mod tests {
|
||||
async fn dispatch_response_doesnt_cancel_after_complete() {
|
||||
let (cancellation, mut canceled_requests) = cancellations();
|
||||
let (tx, mut response) = oneshot::channel();
|
||||
tx.send(Response {
|
||||
tx.send(Ok(Response {
|
||||
request_id: 0,
|
||||
message: Ok("well done"),
|
||||
})
|
||||
}))
|
||||
.unwrap();
|
||||
// resp's drop() is run, but should not send a cancel message.
|
||||
ResponseGuard {
|
||||
@@ -799,8 +823,8 @@ mod tests {
|
||||
async fn send_request<'a>(
|
||||
channel: &'a mut Channel<String, String>,
|
||||
request: &str,
|
||||
response_completion: oneshot::Sender<Response<String>>,
|
||||
response: &'a mut oneshot::Receiver<Response<String>>,
|
||||
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
|
||||
) -> ResponseGuard<'a, String> {
|
||||
let request_id =
|
||||
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
use crate::{
|
||||
context,
|
||||
util::{Compact, TimeUntil},
|
||||
Response, ServerError,
|
||||
Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
io,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::sync::oneshot;
|
||||
@@ -29,11 +28,17 @@ impl<Resp> Default for InFlightRequests<Resp> {
|
||||
}
|
||||
}
|
||||
|
||||
/// The request exceeded its deadline.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[error("the request exceeded its deadline")]
|
||||
pub struct DeadlineExceededError;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RequestData<Resp> {
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
}
|
||||
@@ -60,7 +65,7 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
) -> Result<(), AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
@@ -85,7 +90,7 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
tracing::info!("ReceiveResponse");
|
||||
self.request_data.compact(0.1);
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
let _ = request_data.response_completion.send(response);
|
||||
let _ = request_data.response_completion.send(Ok(response));
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -124,19 +129,9 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
self.request_data.compact(0.1);
|
||||
let _ = request_data
|
||||
.response_completion
|
||||
.send(Self::deadline_exceeded_error(request_id));
|
||||
.send(Err(DeadlineExceededError));
|
||||
}
|
||||
request_id
|
||||
})
|
||||
}
|
||||
|
||||
fn deadline_exceeded_error(request_id: u64) -> Response<Resp> {
|
||||
Response {
|
||||
request_id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::TimedOut,
|
||||
detail: Some("Client dropped expired request.".to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,6 +67,7 @@
|
||||
//! your `Cargo.toml`:
|
||||
//!
|
||||
//! ```toml
|
||||
//! anyhow = "1.0"
|
||||
//! futures = "1.0"
|
||||
//! tarpc = { version = "0.26", features = ["tokio1"] }
|
||||
//! tokio = { version = "1.0", features = ["macros"] }
|
||||
@@ -89,7 +90,6 @@
|
||||
//! client, context,
|
||||
//! server::{self, Incoming},
|
||||
//! };
|
||||
//! use std::io;
|
||||
//!
|
||||
//! // This is the service definition. It looks a lot like a trait definition.
|
||||
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
@@ -113,7 +113,6 @@
|
||||
//! # client, context,
|
||||
//! # server::{self, Incoming},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
//! # #[tarpc::service]
|
||||
@@ -153,7 +152,6 @@
|
||||
//! # client, context,
|
||||
//! # server::{self, Channel},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
//! # #[tarpc::service]
|
||||
@@ -177,7 +175,7 @@
|
||||
//! # fn main() {}
|
||||
//! # #[cfg(feature = "tokio1")]
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> io::Result<()> {
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
//!
|
||||
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||
@@ -364,8 +362,9 @@ pub struct Response<T> {
|
||||
pub message: Result<T, ServerError>,
|
||||
}
|
||||
|
||||
/// An error response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
/// An error indicating the server aborted the request early, e.g., due to request throttling.
|
||||
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[error("{kind:?}: {detail}")]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ServerError {
|
||||
@@ -380,13 +379,7 @@ pub struct ServerError {
|
||||
/// The type of error that occurred to fail the request.
|
||||
pub kind: io::ErrorKind,
|
||||
/// A message describing more detail about the error that occurred.
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ServerError> for io::Error {
|
||||
fn from(e: ServerError) -> io::Error {
|
||||
io::Error::new(e.kind, e.detail.unwrap_or_default())
|
||||
}
|
||||
pub detail: String,
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
|
||||
@@ -66,7 +66,7 @@ where
|
||||
request_id: request.id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::WouldBlock,
|
||||
detail: Some("Server throttled the request.".into()),
|
||||
detail: "Server throttled the request.".into(),
|
||||
}),
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -170,7 +170,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration() -> io::Result<()> {
|
||||
async fn integration() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use futures::prelude::*;
|
||||
use std::io;
|
||||
use tarpc::serde_transport;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
@@ -33,7 +32,7 @@ impl ColorProtocol for ColorServer {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_call() -> io::Result<()> {
|
||||
async fn test_call() -> anyhow::Result<()> {
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
tokio::spawn(
|
||||
|
||||
Reference in New Issue
Block a user