From 324df5cd15505692827e69ef4722a7a06662a27d Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Sat, 12 Nov 2022 17:24:40 -0800 Subject: [PATCH] Add back the Client trait, renamed Stub. Also adds a Client stub trait alias for each generated service. Now that generic associated types are stable, it's almost possible to define a trait for Channel that works with async fns on stable. `impl trait in type aliases` is still necessary (and unstable), but we're getting closer. As a proof of concept, three more implementations of Stub are implemented; 1. A load balancer that round-robins requests between different stubs. 2. A load balancer that selects a stub based on a request hash, so that the same requests go to the same stubs. 3. A stub that retries requests based on a configurable policy. The "serde/rc" feature is added to the "full" feature because the Retry stub wraps the request in an Arc, so that the request is reusable for multiple calls. Server implementors commonly need to operate generically across all services or request types. For example, a server throttler may want to return errors telling clients to back off, which is not specific to any one service. --- plugins/src/lib.rs | 36 ++- tarpc/Cargo.toml | 2 +- tarpc/src/client.rs | 1 + tarpc/src/client/stub.rs | 56 ++++ tarpc/src/client/stub/load_balance.rs | 305 ++++++++++++++++++ tarpc/src/client/stub/mock.rs | 54 ++++ tarpc/src/client/stub/retry.rs | 75 +++++ tarpc/src/lib.rs | 16 + tarpc/src/server.rs | 32 +- .../tarpc_server_missing_async.stderr | 16 +- 10 files changed, 581 insertions(+), 12 deletions(-) create mode 100644 tarpc/src/client/stub.rs create mode 100644 tarpc/src/client/stub/load_balance.rs create mode 100644 tarpc/src/client/stub/mock.rs create mode 100644 tarpc/src/client/stub/retry.rs diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 1b83c32..d30363e 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -276,6 +276,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { ServiceGenerator { response_fut_name, service_ident: ident, + client_stub_ident: &format_ident!("{}Stub", ident), server_ident: &format_ident!("Serve{}", ident), response_fut_ident: &Ident::new(response_fut_name, ident.span()), client_ident: &format_ident!("{}Client", ident), @@ -432,6 +433,7 @@ fn verify_types_were_provided( // the client stub. struct ServiceGenerator<'a> { service_ident: &'a Ident, + client_stub_ident: &'a Ident, server_ident: &'a Ident, response_fut_ident: &'a Ident, response_fut_name: &'a str, @@ -461,6 +463,9 @@ impl<'a> ServiceGenerator<'a> { future_types, return_types, service_ident, + client_stub_ident, + request_ident, + response_ident, server_ident, .. } = self; @@ -490,6 +495,7 @@ impl<'a> ServiceGenerator<'a> { }, ); + let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { #( #attrs )* #vis trait #service_ident: Sized { @@ -501,6 +507,15 @@ impl<'a> ServiceGenerator<'a> { #server_ident { service: self } } } + + #[doc = #stub_doc] + #vis trait #client_stub_ident: tarpc::client::stub::Stub { + } + + impl #client_stub_ident for S + where S: tarpc::client::stub::Stub + { + } } } @@ -689,7 +704,9 @@ impl<'a> ServiceGenerator<'a> { #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. All request methods return /// [Futures](std::future::Future). - #vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>); + #vis struct #client_ident< + Stub = tarpc::client::Channel<#request_ident, #response_ident> + >(Stub); } } @@ -719,6 +736,17 @@ impl<'a> ServiceGenerator<'a> { dispatch: new_client.dispatch, } } + } + + impl From for #client_ident + where Stub: tarpc::client::stub::Stub< + Req = #request_ident, + Resp = #response_ident> + { + /// Returns a new client stub that sends requests over the given transport. + fn from(stub: Stub) -> Self { + #client_ident(stub) + } } } @@ -741,7 +769,11 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - impl #client_ident { + impl #client_ident + where Stub: tarpc::client::stub::Stub< + Req = #request_ident, + Resp = #response_ident> + { #( #[allow(unused)] #( #method_attrs )* diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 97ac952..c6f8064 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -19,7 +19,7 @@ description = "An RPC framework for Rust with a focus on ease of use." [features] default = [] -serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"] +serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"] tokio1 = ["tokio/rt"] serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"] serde-transport-json = ["tokio-serde/json"] diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 0180ba4..9b54334 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -7,6 +7,7 @@ //! Provides a client that connects to a server and sends multiplexed requests. mod in_flight_requests; +pub mod stub; use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs new file mode 100644 index 0000000..a8b72a2 --- /dev/null +++ b/tarpc/src/client/stub.rs @@ -0,0 +1,56 @@ +//! Provides a Stub trait, implemented by types that can call remote services. + +use crate::{ + client::{Channel, RpcError}, + context, +}; +use futures::prelude::*; + +pub mod load_balance; +pub mod retry; + +#[cfg(test)] +mod mock; + +/// A connection to a remote service. +/// Calls the service with requests of type `Req` and receives responses of type `Resp`. +pub trait Stub { + /// The service request type. + type Req; + + /// The service response type. + type Resp; + + /// The type of the future returned by `Stub::call`. + type RespFut<'a>: Future> + where + Self: 'a, + Self::Req: 'a, + Self::Resp: 'a; + + /// Calls a remote service. + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a>; +} + +impl Stub for Channel { + type Req = Req; + type Resp = Resp; + type RespFut<'a> = RespFut<'a, Req, Resp> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } +} + +type RespFut<'a, Req: 'a, Resp: 'a> = impl Future> + 'a; diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs new file mode 100644 index 0000000..c9005a4 --- /dev/null +++ b/tarpc/src/client/stub/load_balance.rs @@ -0,0 +1,305 @@ +//! Provides load-balancing [Stubs](crate::client::stub::Stub). + +pub use consistent_hash::ConsistentHash; +pub use round_robin::RoundRobin; + +/// Provides a stub that load-balances with a simple round-robin strategy. +mod round_robin { + use crate::{ + client::{stub, RpcError}, + context, + }; + use cycle::AtomicCycle; + use futures::prelude::*; + + impl stub::Stub for RoundRobin + where + Stub: stub::Stub, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } + } + + type RespFut<'a, Stub: stub::Stub + 'a> = + impl Future> + 'a; + + /// A Stub that load-balances across backing stubs by round robin. + #[derive(Clone, Debug)] + pub struct RoundRobin { + stubs: AtomicCycle, + } + + impl RoundRobin + where + Stub: stub::Stub, + { + /// Returns a new RoundRobin stub. + pub fn new(stubs: Vec) -> Self { + Self { + stubs: AtomicCycle::new(stubs), + } + } + + async fn call( + &self, + ctx: context::Context, + request_name: &'static str, + request: Stub::Req, + ) -> Result { + let next = self.stubs.next(); + next.call(ctx, request_name, request).await + } + } + + mod cycle { + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + /// Cycles endlessly and atomically over a collection of elements of type T. + #[derive(Clone, Debug)] + pub struct AtomicCycle(Arc>); + + #[derive(Debug)] + struct State { + elements: Vec, + next: AtomicUsize, + } + + impl AtomicCycle { + pub fn new(elements: Vec) -> Self { + Self(Arc::new(State { + elements, + next: Default::default(), + })) + } + + pub fn next(&self) -> &T { + self.0.next() + } + } + + impl State { + pub fn next(&self) -> &T { + let next = self.next.fetch_add(1, Ordering::Relaxed); + &self.elements[next % self.elements.len()] + } + } + + #[test] + fn test_cycle() { + let cycle = AtomicCycle::new(vec![1, 2, 3]); + assert_eq!(cycle.next(), &1); + assert_eq!(cycle.next(), &2); + assert_eq!(cycle.next(), &3); + assert_eq!(cycle.next(), &1); + } + } +} + +/// Provides a stub that load-balances with a consistent hashing strategy. +/// +/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use +/// the same stub. +mod consistent_hash { + use crate::{ + client::{stub, RpcError}, + context, + }; + use futures::prelude::*; + use std::{ + collections::hash_map::RandomState, + hash::{BuildHasher, Hash, Hasher}, + num::TryFromIntError, + }; + + impl stub::Stub for ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } + } + + type RespFut<'a, Stub: stub::Stub + 'a> = + impl Future> + 'a; + + /// A Stub that load-balances across backing stubs by round robin. + #[derive(Clone, Debug)] + pub struct ConsistentHash { + stubs: Vec, + stubs_len: u64, + hasher: S, + } + + impl ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + { + /// Returns a new RoundRobin stub. + /// Returns an err if the length of `stubs` overflows a u64. + pub fn new(stubs: Vec) -> Result { + Ok(Self { + stubs_len: stubs.len().try_into()?, + stubs, + hasher: RandomState::new(), + }) + } + } + + impl ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + S: BuildHasher, + { + /// Returns a new RoundRobin stub. + /// Returns an err if the length of `stubs` overflows a u64. + pub fn with_hasher(stubs: Vec, hasher: S) -> Result { + Ok(Self { + stubs_len: stubs.len().try_into()?, + stubs, + hasher, + }) + } + + async fn call( + &self, + ctx: context::Context, + request_name: &'static str, + request: Stub::Req, + ) -> Result { + let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect( + "invariant broken: stubs_len is not larger than a usize, \ + so the hash modulo stubs_len should always fit in a usize", + ); + let next = &self.stubs[index]; + next.call(ctx, request_name, request).await + } + + fn hash_request(&self, req: &Stub::Req) -> u64 { + let mut hasher = self.hasher.build_hasher(); + req.hash(&mut hasher); + hasher.finish() + } + } + + #[cfg(test)] + mod tests { + use super::ConsistentHash; + use crate::{client::stub::mock::Mock, context}; + use std::{ + collections::HashMap, + hash::{BuildHasher, Hash, Hasher}, + rc::Rc, + }; + + #[tokio::test] + async fn test() -> anyhow::Result<()> { + let stub = ConsistentHash::with_hasher( + vec![ + // For easier reading of the assertions made in this test, each Mock's response + // value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 % + // 3 = 1, etc. + Mock::new([('a', 3), ('b', 3), ('c', 3)]), + Mock::new([('a', 1), ('b', 1), ('c', 1)]), + Mock::new([('a', 2), ('b', 2), ('c', 2)]), + ], + FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]), + )?; + + for _ in 0..2 { + let resp = stub.call(context::current(), "", 'a').await?; + assert_eq!(resp, 1); + + let resp = stub.call(context::current(), "", 'b').await?; + assert_eq!(resp, 2); + + let resp = stub.call(context::current(), "", 'c').await?; + assert_eq!(resp, 3); + } + + Ok(()) + } + + struct HashRecorder(Vec); + impl Hasher for HashRecorder { + fn write(&mut self, bytes: &[u8]) { + self.0 = Vec::from(bytes); + } + fn finish(&self) -> u64 { + 0 + } + } + + struct FakeHasherBuilder { + recorded_hashes: Rc, u64>>, + } + + struct FakeHasher { + recorded_hashes: Rc, u64>>, + output: u64, + } + + impl BuildHasher for FakeHasherBuilder { + type Hasher = FakeHasher; + + fn build_hasher(&self) -> Self::Hasher { + FakeHasher { + recorded_hashes: self.recorded_hashes.clone(), + output: 0, + } + } + } + + impl FakeHasherBuilder { + fn new(fake_hashes: [(T, u64); N]) -> Self { + let mut recorded_hashes = HashMap::new(); + for (to_hash, fake_hash) in fake_hashes { + let mut recorder = HashRecorder(vec![]); + to_hash.hash(&mut recorder); + recorded_hashes.insert(recorder.0, fake_hash); + } + Self { + recorded_hashes: Rc::new(recorded_hashes), + } + } + } + + impl Hasher for FakeHasher { + fn write(&mut self, bytes: &[u8]) { + if let Some(hash) = self.recorded_hashes.get(bytes) { + self.output = *hash; + } + } + fn finish(&self) -> u64 { + self.output + } + } + } +} diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs new file mode 100644 index 0000000..99a5442 --- /dev/null +++ b/tarpc/src/client/stub/mock.rs @@ -0,0 +1,54 @@ +use crate::{ + client::{stub::Stub, RpcError}, + context, ServerError, +}; +use futures::future; +use std::{collections::HashMap, hash::Hash, io}; + +/// A mock stub that returns user-specified responses. +pub struct Mock { + responses: HashMap, +} + +impl Mock +where + Req: Eq + Hash, +{ + /// Returns a new mock, mocking the specified (request, response) pairs. + pub fn new(responses: [(Req, Resp); N]) -> Self { + Self { + responses: HashMap::from(responses), + } + } +} + +impl Stub for Mock +where + Req: Eq + Hash, + Resp: Clone, +{ + type Req = Req; + type Resp = Resp; + type RespFut<'a> = future::Ready> + where Self: 'a; + + fn call<'a>( + &'a self, + _: context::Context, + _: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + future::ready( + self.responses + .get(&request) + .cloned() + .map(Ok) + .unwrap_or_else(|| { + Err(RpcError::Server(ServerError { + kind: io::ErrorKind::NotFound, + detail: "mock (request, response) entry not found".into(), + })) + }), + ) + } +} diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs new file mode 100644 index 0000000..46ad096 --- /dev/null +++ b/tarpc/src/client/stub/retry.rs @@ -0,0 +1,75 @@ +//! Provides a stub that retries requests based on response contents.. + +use crate::{ + client::{stub, RpcError}, + context, +}; +use futures::prelude::*; +use std::sync::Arc; + +impl stub::Stub for Retry +where + Stub: stub::Stub>, + F: Fn(&Result, u32) -> bool, +{ + type Req = Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub, Self::Req, F> + where Self: 'a, + Self::Req: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } +} + +type RespFut<'a, Stub: stub::Stub + 'a, Req: 'a, F: 'a> = + impl Future> + 'a; + +/// A Stub that retries requests based on response contents. +/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled. +#[derive(Clone, Debug)] +pub struct Retry { + should_retry: F, + stub: Stub, +} + +impl Retry +where + Stub: stub::Stub>, + F: Fn(&Result, u32) -> bool, +{ + /// Creates a new Retry stub that delegates calls to the underlying `stub`. + pub fn new(stub: Stub, should_retry: F) -> Self { + Self { stub, should_retry } + } + + async fn call<'a, 'b>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Req, + ) -> Result + where + Req: 'b, + { + let request = Arc::new(request); + for i in 1.. { + let result = self + .stub + .call(ctx, request_name, Arc::clone(&request)) + .await; + if (self.should_retry)(&result, i) { + tracing::trace!("Retrying on attempt {i}"); + continue; + } + return result; + } + unreachable!("Wow, that was a lot of attempts!"); + } +} diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 418cedd..b47d13b 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -200,6 +200,15 @@ //! //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. + +// For async_fn_in_trait +#![allow(incomplete_features)] +#![feature( + iter_intersperse, + type_alias_impl_trait, + async_fn_in_trait, + return_position_impl_trait_in_trait +)] #![deny(missing_docs)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -407,6 +416,13 @@ where Close(#[source] E), } +impl ServerError { + /// Returns a new server error with `kind` and `detail`. + pub fn new(kind: io::ErrorKind, detail: String) -> ServerError { + Self { kind, detail } + } +} + impl Request { /// Returns the deadline for this request. pub fn deadline(&self) -> &SystemTime { diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 9f04b27..b44724d 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -9,7 +9,7 @@ use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, - trace, ChannelError, ClientMessage, Request, Response, Transport, + trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport, }; use ::tokio::sync::mpsc; use futures::{ @@ -21,7 +21,7 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc}; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, mem, pin::Pin, sync::Arc}; use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; @@ -565,6 +565,10 @@ where }| { // The response guard becomes active once in an InFlightRequest. response_guard.cancel = true; + { + let _entered = span.enter(); + tracing::info!("BeginRequest"); + } InFlightRequest { request, abort_registration, @@ -686,6 +690,29 @@ impl InFlightRequest { &self.request } + /// Respond without executing a service function. Useful for early aborts (e.g. for throttling). + pub async fn respond(self, response: Result) { + let Self { + response_tx, + response_guard, + request: Request { id: request_id, .. }, + span, + .. + } = self; + let _entered = span.enter(); + tracing::info!("CompleteRequest"); + let response = Response { + request_id, + message: response, + }; + let _ = response_tx.send(response).await; + tracing::info!("BufferResponse"); + // Request processing has completed, meaning either the channel canceled the request or + // a request was sent back to the channel. Either way, the channel will clean up the + // request data, so the request does not need to be canceled. + mem::forget(response_guard); + } + /// Returns a [future](Future) that executes the request using the given [service /// function](Serve). The service function's output is automatically sent back to the [Channel] /// that yielded this request. The request will be executed in the scope of this request's @@ -720,7 +747,6 @@ impl InFlightRequest { span.record("otel.name", method.unwrap_or("")); let _ = Abortable::new( async move { - tracing::info!("BeginRequest"); let response = serve.serve(context, message).await; tracing::info!("CompleteRequest"); let response = Response { diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr index 28106e6..d96cda8 100644 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr +++ b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr @@ -1,11 +1,15 @@ error: not all trait items implemented, missing: `HelloFut` - --> $DIR/tarpc_server_missing_async.rs:9:1 - | -9 | impl World for HelloServer { - | ^^^^ + --> tests/compile_fail/tarpc_server_missing_async.rs:9:1 + | +9 | / impl World for HelloServer { +10 | | fn hello(name: String) -> String { +11 | | format!("Hello, {name}!", name) +12 | | } +13 | | } + | |_^ error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async - --> $DIR/tarpc_server_missing_async.rs:10:5 + --> tests/compile_fail/tarpc_server_missing_async.rs:10:5 | 10 | fn hello(name: String) -> String { - | ^^ + | ^^^^^^^^