From 7c5afa97bb3d3d964eab857e8a9ef3ec113bf6e7 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Mon, 7 Nov 2022 18:37:58 -0800 Subject: [PATCH] Add request hooks to the Serve trait. This allows plugging in horizontal functionality, such as authorization, throttling, or latency recording, that should run before and/or after execution of every request, regardless of the request type. The tracing example is updated to show off both client stubs as well as server hooks. As part of this change, there were some changes to the Serve trait: 1. Serve's output type is now a Result.. Serve previously did not allow returning ServerErrors, which prevented using Serve for horizontal functionality like throttling or auth. Now, Serve's output type is Result, making Serve a more natural integration point for horizontal capabilities. 2. Serve's generic Request type changed to an associated type. The primary benefit of the generic type is that it allows one type to impl a trait multiple times (for example, u64 impls TryFrom, TryFrom, etc.). In the case of Serve impls, while it is theoretically possible to contrive a type that could serve multiple request types, in practice I don't expect that to be needed. Most users will use the Serve impl generated by #[tarpc::service], which only ever serves one type of request. --- plugins/src/lib.rs | 10 +- tarpc/examples/tracing.rs | 116 ++++- tarpc/src/server.rs | 421 ++++++++++++++++-- tarpc/src/server/incoming.rs | 2 +- tarpc/src/server/request_hook.rs | 22 + tarpc/src/server/request_hook/after.rs | 89 ++++ tarpc/src/server/request_hook/before.rs | 84 ++++ .../server/request_hook/before_and_after.rs | 70 +++ tarpc/src/server/tokio.rs | 22 +- tarpc/src/transport/channel.rs | 22 +- 10 files changed, 781 insertions(+), 77 deletions(-) create mode 100644 tarpc/src/server/request_hook.rs create mode 100644 tarpc/src/server/request_hook/after.rs create mode 100644 tarpc/src/server/request_hook/before.rs create mode 100644 tarpc/src/server/request_hook/before_and_after.rs diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index d30363e..efab161 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -548,9 +548,10 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - impl tarpc::server::Serve<#request_ident> for #server_ident + impl tarpc::server::Serve for #server_ident where S: #service_ident { + type Req = #request_ident; type Resp = #response_ident; type Fut = #response_fut_ident; @@ -670,10 +671,10 @@ impl<'a> ServiceGenerator<'a> { quote! { impl std::future::Future for #response_fut_ident { - type Output = #response_ident; + type Output = Result<#response_ident, tarpc::ServerError>; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) - -> std::task::Poll<#response_ident> + -> std::task::Poll> { unsafe { match std::pin::Pin::get_unchecked_mut(self) { @@ -681,7 +682,8 @@ impl<'a> ServiceGenerator<'a> { #response_fut_ident::#camel_case_idents(resp) => std::pin::Pin::new_unchecked(resp) .poll(cx) - .map(#response_ident::#camel_case_idents), + .map(#response_ident::#camel_case_idents) + .map(Ok), )* } } diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 2756146..589c16f 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -4,13 +4,32 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::{add::Add as AddService, double::Double as DoubleService}; -use futures::{future, prelude::*}; -use tarpc::{ - client, context, - server::{incoming::Incoming, BaseChannel}, - tokio_serde::formats::Json, +#![feature(type_alias_impl_trait)] + +use crate::{ + add::{Add as AddService, AddStub}, + double::Double as DoubleService, }; +use futures::{future, prelude::*}; +use std::{ + io, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; +use tarpc::{ + client::{ + self, + stub::{load_balance, retry}, + RpcError, + }, + context, serde_transport, + server::{incoming::Incoming, BaseChannel, Serve}, + tokio_serde::formats::Json, + ClientMessage, Response, ServerError, Transport, +}; +use tokio::net::TcpStream; use tracing_subscriber::prelude::*; pub mod add { @@ -40,12 +59,16 @@ impl AddService for AddServer { } #[derive(Clone)] -struct DoubleServer { - add_client: add::AddClient, +struct DoubleServer { + add_client: add::AddClient, } #[tarpc::server] -impl DoubleService for DoubleServer { +impl DoubleService for DoubleServer +where + Stub: AddStub + Clone + Send + Sync + 'static, + for<'a> Stub::RespFut<'a>: Send, +{ async fn double(self, _: context::Context, x: i32) -> Result { self.add_client .add(context::current(), x, x) @@ -70,22 +93,79 @@ fn init_tracing(service_name: &str) -> anyhow::Result<()> { Ok(()) } +async fn listen_on_random_port() -> anyhow::Result<( + impl Stream>>, + std::net::SocketAddr, +)> +where + Item: for<'de> serde::Deserialize<'de>, + SinkItem: serde::Serialize, +{ + let listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) + .await? + .filter_map(|r| future::ready(r.ok())) + .take(1); + let addr = listener.get_ref().get_ref().local_addr(); + Ok((listener, addr)) +} + +fn make_stub( + backends: [impl Transport>, Response> + Send + Sync + 'static; N], +) -> retry::Retry< + impl Fn(&Result, u32) -> bool + Clone, + load_balance::RoundRobin, Resp>>, +> +where + Req: Send + Sync + 'static, + Resp: Send + Sync + 'static, +{ + let stub = load_balance::RoundRobin::new( + backends + .into_iter() + .map(|transport| tarpc::client::new(client::Config::default(), transport).spawn()) + .collect(), + ); + let stub = retry::Retry::new(stub, |resp, attempts| { + if let Err(e) = resp { + tracing::warn!("Got an error: {e:?}"); + attempts < 3 + } else { + false + } + }); + stub +} + #[tokio::main] async fn main() -> anyhow::Result<()> { init_tracing("tarpc_tracing_example")?; - let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) - .await? - .filter_map(|r| future::ready(r.ok())); - let addr = add_listener.get_ref().local_addr(); - let add_server = add_listener + let (add_listener1, addr1) = listen_on_random_port().await?; + let (add_listener2, addr2) = listen_on_random_port().await?; + let something_bad_happened = Arc::new(AtomicBool::new(false)); + let server = AddServer.serve().before(move |_: &mut _, _: &_| { + let something_bad_happened = something_bad_happened.clone(); + async move { + if something_bad_happened.fetch_xor(true, Ordering::Relaxed) { + Err(ServerError::new( + io::ErrorKind::NotFound, + "Gamma Ray!".into(), + )) + } else { + Ok(()) + } + } + }); + let add_server = add_listener1 + .chain(add_listener2) .map(BaseChannel::with_defaults) - .take(1) - .execute(AddServer.serve()); + .execute(server); tokio::spawn(add_server); - let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn(); + let add_client = add::AddClient::from(make_stub([ + tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, + tarpc::serde_transport::tcp::connect(addr2, Json::default).await?, + ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index b44724d..70f28d9 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -21,10 +21,11 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, mem, pin::Pin, sync::Arc}; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc}; use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; +pub mod request_hook; #[cfg(test)] mod testing; @@ -39,6 +40,10 @@ pub mod incoming; #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] pub mod tokio; +use request_hook::{ + AfterRequest, AfterRequestHook, BeforeAndAfterRequestHook, BeforeRequest, BeforeRequestHook, +}; + /// Settings that control the behavior of [channels](Channel). #[derive(Clone, Debug)] pub struct Config { @@ -67,32 +72,212 @@ impl Config { } /// Equivalent to a `FnOnce(Req) -> impl Future`. -pub trait Serve { +pub trait Serve { + /// Type of request. + type Req; + /// Type of response. type Resp; /// Type of response future. - type Fut: Future; + type Fut: Future>; + + /// Responds to a single request. + fn serve(self, ctx: context::Context, req: Self::Req) -> Self::Fut; /// Extracts a method name from the request. - fn method(&self, _request: &Req) -> Option<&'static str> { + fn method(&self, _request: &Self::Req) -> Option<&'static str> { None } - /// Responds to a single request. - fn serve(self, ctx: context::Context, req: Req) -> Self::Fut; + /// Runs a hook before execution of the request. + /// + /// If the hook returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// The hook can also modify the request context. This could be used, for example, to enforce a + /// maximum deadline on all requests. + /// + /// Any type that implements [`BeforeRequest`] can be used as the hook. Types that implement + /// `FnMut(&mut Context, &RequestType) -> impl Future>` can + /// also be used. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{context, ServerError, server::{Serve, serve}}; + /// use std::io; + /// + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) + /// .before(|_ctx: &mut context::Context, req: &i32| { + /// future::ready( + /// if *req == 1 { + /// Err(ServerError::new( + /// io::ErrorKind::Other, + /// format!("I don't like {req}"))) + /// } else { + /// Ok(()) + /// }) + /// }); + /// let response = serve.serve(context::current(), 1); + /// assert!(block_on(response).is_err()); + /// ``` + fn before(self, hook: Hook) -> BeforeRequestHook + where + Hook: BeforeRequest, + Self: Sized, + { + BeforeRequestHook::new(self, hook) + } + + /// Runs a hook after completion of a request. + /// + /// The hook can modify the request context and the response. + /// + /// Any type that implements [`AfterRequest`] can be used as the hook. Types that implement + /// `FnMut(&mut Context, &mut Result) -> impl Future` + /// can also be used. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{context, ServerError, server::{Serve, serve}}; + /// use std::io; + /// + /// let serve = serve( + /// |_ctx, i| async move { + /// if i == 1 { + /// Err(ServerError::new( + /// io::ErrorKind::Other, + /// format!("{i} is the loneliest number"))) + /// } else { + /// Ok(i + 1) + /// } + /// }) + /// .after(|_ctx: &mut context::Context, resp: &mut Result| { + /// if let Err(e) = resp { + /// eprintln!("server error: {e:?}"); + /// } + /// future::ready(()) + /// }); + /// + /// let response = serve.serve(context::current(), 1); + /// assert!(block_on(response).is_err()); + /// ``` + fn after(self, hook: Hook) -> AfterRequestHook + where + Hook: AfterRequest, + Self: Sized, + { + AfterRequestHook::new(self, hook) + } + + /// Runs a hook before and after execution of the request. + /// + /// If the hook returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// The hook can also modify the request context and the response. This could be used, for + /// example, to enforce a maximum deadline on all requests. + /// + /// # Example + /// + /// ```rust + /// use futures::{executor::block_on, future}; + /// use tarpc::{ + /// context, ServerError, server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest}} + /// }; + /// use std::{io, time::Instant}; + /// + /// struct PrintLatency(Instant); + /// + /// impl BeforeRequest for PrintLatency { + /// type Fut<'a> = future::Ready> where Self: 'a, Req: 'a; + /// + /// fn before<'a>(&'a mut self, _: &'a mut context::Context, _: &'a Req) -> Self::Fut<'a> { + /// self.0 = Instant::now(); + /// future::ready(Ok(())) + /// } + /// } + /// + /// impl AfterRequest for PrintLatency { + /// type Fut<'a> = future::Ready<()> where Self:'a, Resp:'a; + /// + /// fn after<'a>( + /// &'a mut self, + /// _: &'a mut context::Context, + /// _: &'a mut Result, + /// ) -> Self::Fut<'a> { + /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); + /// future::ready(()) + /// } + /// } + /// + /// let serve = serve(|_ctx, i| async move { + /// Ok(i + 1) + /// }).before_and_after(PrintLatency(Instant::now())); + /// let response = serve.serve(context::current(), 1); + /// assert!(block_on(response).is_ok()); + /// ``` + fn before_and_after( + self, + hook: Hook, + ) -> BeforeAndAfterRequestHook + where + Hook: BeforeRequest + AfterRequest, + Self: Sized, + { + BeforeAndAfterRequestHook::new(self, hook) + } } -impl Serve for F +/// A Serve wrapper around a Fn. +#[derive(Debug)] +pub struct ServeFn { + f: F, + data: PhantomData Resp>, +} + +impl Clone for ServeFn +where + F: Clone, +{ + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + data: PhantomData, + } + } +} + +impl Copy for ServeFn where F: Copy {} + +/// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. +pub fn serve(f: F) -> ServeFn where F: FnOnce(context::Context, Req) -> Fut, - Fut: Future, + Fut: Future>, { + ServeFn { + f, + data: PhantomData, + } +} + +impl Serve for ServeFn +where + F: FnOnce(context::Context, Req) -> Fut, + Fut: Future>, +{ + type Req = Req; type Resp = Resp; type Fut = Fut; fn serve(self, ctx: context::Context, req: Req) -> Self::Fut { - self(ctx, req) + (self.f)(ctx, req) } } @@ -120,7 +305,7 @@ pub struct BaseChannel { /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. - ghost: PhantomData<(Req, Resp)>, + ghost: PhantomData<(fn() -> Req, fn(Resp))>, } impl BaseChannel @@ -307,6 +492,34 @@ where /// This is a terminal operation. After calling `requests`, the channel cannot be retrieved, /// and the only way to complete requests is via [`Requests::execute`] or /// [`InFlightRequest::execute`]. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{ + /// context, + /// client::{self, NewClient}, + /// server::{self, BaseChannel, Channel, serve}, + /// transport, + /// }; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let server = BaseChannel::new(server::Config::default(), rx); + /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); + /// tokio::spawn(dispatch); + /// + /// let mut requests = server.requests(); + /// tokio::spawn(async move { + /// while let Some(Ok(request)) = requests.next().await { + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// } + /// }); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` fn requests(self) -> Requests where Self: Sized, @@ -323,12 +536,28 @@ where /// Runs the channel until completion by executing all requests using the given service /// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's /// default executor. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// let channel = BaseChannel::new(server::Config::default(), rx); + /// tokio::spawn(channel.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` #[cfg(feature = "tokio1")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] fn execute(self, serve: S) -> self::tokio::TokioChannelExecutor, S> where Self: Sized, - S: Serve + Send + 'static, + S: Serve + Send + 'static, S::Fut: Send, Self::Req: Send + 'static, Self::Resp: Send + 'static, @@ -690,29 +919,6 @@ impl InFlightRequest { &self.request } - /// Respond without executing a service function. Useful for early aborts (e.g. for throttling). - pub async fn respond(self, response: Result) { - let Self { - response_tx, - response_guard, - request: Request { id: request_id, .. }, - span, - .. - } = self; - let _entered = span.enter(); - tracing::info!("CompleteRequest"); - let response = Response { - request_id, - message: response, - }; - let _ = response_tx.send(response).await; - tracing::info!("BufferResponse"); - // Request processing has completed, meaning either the channel canceled the request or - // a request was sent back to the channel. Either way, the channel will clean up the - // request data, so the request does not need to be canceled. - mem::forget(response_guard); - } - /// Returns a [future](Future) that executes the request using the given [service /// function](Serve). The service function's output is automatically sent back to the [Channel] /// that yielded this request. The request will be executed in the scope of this request's @@ -727,9 +933,39 @@ impl InFlightRequest { /// /// If the returned Future is dropped before completion, a cancellation message will be sent to /// the Channel to clean up associated request state. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{ + /// context, + /// client::{self, NewClient}, + /// server::{self, BaseChannel, Channel, serve}, + /// transport, + /// }; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let server = BaseChannel::new(server::Config::default(), rx); + /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); + /// tokio::spawn(dispatch); + /// + /// tokio::spawn(async move { + /// let mut requests = server.requests(); + /// while let Some(Ok(in_flight_request)) = requests.next().await { + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// } + /// + /// }); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` + /// pub async fn execute(self, serve: S) where - S: Serve, + S: Serve, { let Self { response_tx, @@ -747,11 +983,11 @@ impl InFlightRequest { span.record("otel.name", method.unwrap_or("")); let _ = Abortable::new( async move { - let response = serve.serve(context, message).await; + let message = serve.serve(context, message).await; tracing::info!("CompleteRequest"); let response = Response { request_id, - message: Ok(response), + message, }; let _ = response_tx.send(response).await; tracing::info!("BufferResponse"); @@ -795,11 +1031,14 @@ where #[cfg(test)] mod tests { - use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests}; + use super::{ + in_flight_requests::AlreadyExistsError, serve, AfterRequest, BaseChannel, BeforeRequest, + Channel, Config, Requests, Serve, + }; use crate::{ context, trace, transport::channel::{self, UnboundedChannel}, - ClientMessage, Request, Response, + ClientMessage, Request, Response, ServerError, }; use assert_matches::assert_matches; use futures::{ @@ -808,7 +1047,12 @@ mod tests { Future, }; use futures_test::task::noop_context; - use std::{pin::Pin, task::Poll}; + use std::{ + io, + pin::Pin, + task::Poll, + time::{Duration, Instant, SystemTime}, + }; fn test_channel() -> ( Pin, Response>>>>, @@ -869,6 +1113,101 @@ mod tests { Abortable::new(pending(), abort_registration) } + #[tokio::test] + async fn test_serve() { + let serve = serve(|_, i| async move { Ok(i) }); + assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); + } + + #[tokio::test] + async fn serve_before_mutates_context() -> anyhow::Result<()> { + struct SetDeadline(SystemTime); + type SetDeadlineFut<'a, Req: 'a> = impl Future> + 'a; + impl BeforeRequest for SetDeadline { + type Fut<'a> = SetDeadlineFut<'a, Req> where Self: 'a, Req: 'a; + fn before<'a>( + &'a mut self, + ctx: &'a mut context::Context, + _: &'a Req, + ) -> Self::Fut<'a> { + async move { + ctx.deadline = self.0; + Ok(()) + } + } + } + + let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37); + let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83); + + let serve = serve(move |ctx: context::Context, i| async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + }); + let deadline_hook = serve.before(SetDeadline(some_time)); + let mut ctx = context::current(); + ctx.deadline = some_other_time; + deadline_hook.serve(ctx, 7).await?; + Ok(()) + } + + #[tokio::test] + async fn serve_before_and_after() -> anyhow::Result<()> { + let _ = tracing_subscriber::fmt::try_init(); + + struct PrintLatency { + start: Instant, + } + impl PrintLatency { + fn new() -> Self { + Self { + start: Instant::now(), + } + } + } + type StartFut<'a, Req: 'a> = impl Future> + 'a; + type EndFut<'a, Resp: 'a> = impl Future + 'a; + impl BeforeRequest for PrintLatency { + type Fut<'a> = StartFut<'a, Req> where Self: 'a, Req: 'a; + fn before<'a>(&'a mut self, _: &'a mut context::Context, _: &'a Req) -> Self::Fut<'a> { + async move { + self.start = Instant::now(); + Ok(()) + } + } + } + impl AfterRequest for PrintLatency { + type Fut<'a> = EndFut<'a, Resp> where Self: 'a, Resp: 'a; + fn after<'a>( + &'a mut self, + _: &'a mut context::Context, + _: &'a mut Result, + ) -> Self::Fut<'a> { + async move { + tracing::info!("Elapsed: {:?}", self.start.elapsed()); + } + } + } + + let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + serve + .before_and_after(PrintLatency::new()) + .serve(context::current(), 7) + .await?; + Ok(()) + } + + #[tokio::test] + async fn serve_before_error_aborts_request() -> anyhow::Result<()> { + let serve = serve(|_, _| async { panic!("Shouldn't get here") }); + let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { + Err(ServerError::new(io::ErrorKind::Other, "oops".into())) + }); + let resp: Result = deadline_hook.serve(context::current(), 7).await; + assert_matches!(resp, Err(_)); + Ok(()) + } + #[tokio::test] async fn base_channel_start_send_duplicate_request_returns_error() { let (mut channel, _tx) = test_channel::<(), ()>(); @@ -1069,7 +1408,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {:?}", result), }; - request.execute(|_, _| async {}).await; + request.execute(serve(|_, _| async { Ok(()) })).await; assert!(requests .as_mut() .channel_pin_mut() diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 445fc3e..931e876 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -35,7 +35,7 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] fn execute(self, serve: S) -> TokioServerExecutor where - S: Serve, + S: Serve, { TokioServerExecutor::new(self, serve) } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs new file mode 100644 index 0000000..ef23d73 --- /dev/null +++ b/tarpc/src/server/request_hook.rs @@ -0,0 +1,22 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Hooks for horizontal functionality that can run either before or after a request is executed. + +/// A request hook that runs before a request is executed. +mod before; + +/// A request hook that runs after a request is completed. +mod after; + +/// A request hook that runs both before a request is executed and after it is completed. +mod before_and_after; + +pub use { + after::{AfterRequest, AfterRequestHook}, + before::{BeforeRequest, BeforeRequestHook}, + before_and_after::BeforeAndAfterRequestHook, +}; diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs new file mode 100644 index 0000000..a3803ba --- /dev/null +++ b/tarpc/src/server/request_hook/after.rs @@ -0,0 +1,89 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs after request execution. + +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; + +/// A hook that runs after request execution. +pub trait AfterRequest { + /// The type of future returned by the hook. + type Fut<'a>: Future + where + Self: 'a, + Resp: 'a; + + /// The function that is called after request execution. + /// + /// The hook can modify the request context and the response. + fn after<'a>( + &'a mut self, + ctx: &'a mut context::Context, + resp: &'a mut Result, + ) -> Self::Fut<'a>; +} + +impl AfterRequest for F +where + F: FnMut(&mut context::Context, &mut Result) -> Fut, + Fut: Future, +{ + type Fut<'a> = Fut where Self: 'a, Resp: 'a; + + fn after<'a>( + &'a mut self, + ctx: &'a mut context::Context, + resp: &'a mut Result, + ) -> Self::Fut<'a> { + self(ctx, resp) + } +} + +/// A Service function that runs a hook after request execution. +pub struct AfterRequestHook { + serve: Serv, + hook: Hook, +} + +impl AfterRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { serve, hook } + } +} + +impl Clone for AfterRequestHook { + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + } + } +} + +impl Serve for AfterRequestHook +where + Serv: Serve, + Hook: AfterRequest, +{ + type Req = Serv::Req; + type Resp = Serv::Resp; + type Fut = AfterRequestHookFut; + + fn serve(self, mut ctx: context::Context, req: Serv::Req) -> Self::Fut { + async move { + let AfterRequestHook { + serve, mut hook, .. + } = self; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp + } + } +} + +type AfterRequestHookFut> = + impl Future>; diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs new file mode 100644 index 0000000..38ad54d --- /dev/null +++ b/tarpc/src/server/request_hook/before.rs @@ -0,0 +1,84 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs before request execution. + +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; + +/// A hook that runs before request execution. +pub trait BeforeRequest { + /// The type of future returned by the hook. + type Fut<'a>: Future> + where + Self: 'a, + Req: 'a; + + /// The function that is called before request execution. + /// + /// If this function returns an error, the request will not be executed and the error will be + /// returned instead. + /// + /// This function can also modify the request context. This could be used, for example, to + /// enforce a maximum deadline on all requests. + fn before<'a>(&'a mut self, ctx: &'a mut context::Context, req: &'a Req) -> Self::Fut<'a>; +} + +impl BeforeRequest for F +where + F: FnMut(&mut context::Context, &Req) -> Fut, + Fut: Future>, +{ + type Fut<'a> = Fut where Self: 'a, Req: 'a; + + fn before<'a>(&'a mut self, ctx: &'a mut context::Context, req: &'a Req) -> Self::Fut<'a> { + self(ctx, req) + } +} + +/// A Service function that runs a hook before request execution. +pub struct BeforeRequestHook { + serve: Serv, + hook: Hook, +} + +impl BeforeRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { serve, hook } + } +} + +impl Clone for BeforeRequestHook { + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + } + } +} + +impl Serve for BeforeRequestHook +where + Serv: Serve, + Hook: BeforeRequest, +{ + type Req = Serv::Req; + type Resp = Serv::Resp; + type Fut = BeforeRequestHookFut; + + fn serve(self, mut ctx: context::Context, req: Self::Req) -> Self::Fut { + let BeforeRequestHook { + serve, mut hook, .. + } = self; + async move { + hook.before(&mut ctx, &req).await?; + serve.serve(ctx, req).await + } + } +} + +type BeforeRequestHookFut> = + impl Future>; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs new file mode 100644 index 0000000..ca42460 --- /dev/null +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -0,0 +1,70 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +//! Provides a hook that runs both before and after request execution. + +use super::{after::AfterRequest, before::BeforeRequest}; +use crate::{context, server::Serve, ServerError}; +use futures::prelude::*; +use std::marker::PhantomData; + +/// A Service function that runs a hook both before and after request execution. +pub struct BeforeAndAfterRequestHook { + serve: Serv, + hook: Hook, + fns: PhantomData<(fn(Req), fn(Resp))>, +} + +impl BeforeAndAfterRequestHook { + pub(crate) fn new(serve: Serv, hook: Hook) -> Self { + Self { + serve, + hook, + fns: PhantomData, + } + } +} + +impl Clone + for BeforeAndAfterRequestHook +{ + fn clone(&self) -> Self { + Self { + serve: self.serve.clone(), + hook: self.hook.clone(), + fns: PhantomData, + } + } +} + +impl Serve for BeforeAndAfterRequestHook +where + Serv: Serve, + Hook: BeforeRequest + AfterRequest, +{ + type Req = Req; + type Resp = Resp; + type Fut = BeforeAndAfterRequestHookFut; + + fn serve(self, mut ctx: context::Context, req: Req) -> Self::Fut { + async move { + let BeforeAndAfterRequestHook { + serve, mut hook, .. + } = self; + hook.before(&mut ctx, &req).await?; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp + } + } +} + +type BeforeAndAfterRequestHookFut< + Req, + Resp, + Serv: Serve, + Hook: BeforeRequest + AfterRequest, +> = impl Future>; diff --git a/tarpc/src/server/tokio.rs b/tarpc/src/server/tokio.rs index a44e846..e9ad842 100644 --- a/tarpc/src/server/tokio.rs +++ b/tarpc/src/server/tokio.rs @@ -55,9 +55,25 @@ where { /// Executes all requests using the given service function. Requests are handled concurrently /// by [spawning](::tokio::spawn) each handler on tokio's default executor. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// tokio::spawn(requests.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` pub fn execute(self, serve: S) -> TokioChannelExecutor where - S: Serve + Send + 'static, + S: Serve + Send + 'static, { TokioChannelExecutor { inner: self, serve } } @@ -69,7 +85,7 @@ where C: Channel + Send + 'static, C::Req: Send + 'static, C::Resp: Send + 'static, - Se: Serve + Send + 'static + Clone, + Se: Serve + Send + 'static + Clone, Se::Fut: Send, { type Output = (); @@ -88,7 +104,7 @@ where C: Channel + 'static, C::Req: Send + 'static, C::Resp: Send + 'static, - S: Serve + Send + 'static + Clone, + S: Serve + Send + 'static + Clone, S::Fut: Send, { type Output = (); diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 529ae8f..7f3035d 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -150,12 +150,14 @@ impl Sink for Channel { #[cfg(feature = "tokio1")] mod tests { use crate::{ - client, context, - server::{incoming::Incoming, BaseChannel}, + client::{self, RpcError}, + context, + server::{incoming::Incoming, serve, BaseChannel}, transport::{ self, channel::{Channel, UnboundedChannel}, }, + ServerError, }; use assert_matches::assert_matches; use futures::{prelude::*, stream}; @@ -177,25 +179,25 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(|_ctx, request: String| { - future::ready(request.parse::().map_err(|_| { - io::Error::new( + .execute(serve(|_ctx, request: String| async move { + request.parse::().map_err(|_| { + ServerError::new( io::ErrorKind::InvalidInput, format!("{request:?} is not an int"), ) - })) - }), + }) + })), ); let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "", "123".into()).await?; - let response2 = client.call(context::current(), "", "abc".into()).await?; + let response1 = client.call(context::current(), "", "123".into()).await; + let response2 = client.call(context::current(), "", "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); assert_matches!(response1, Ok(123)); - assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput); + assert_matches!(response2, Err(RpcError::Server(e)) if e.kind == io::ErrorKind::InvalidInput); Ok(()) }