diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 1b83c32..d30363e 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -276,6 +276,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { ServiceGenerator { response_fut_name, service_ident: ident, + client_stub_ident: &format_ident!("{}Stub", ident), server_ident: &format_ident!("Serve{}", ident), response_fut_ident: &Ident::new(response_fut_name, ident.span()), client_ident: &format_ident!("{}Client", ident), @@ -432,6 +433,7 @@ fn verify_types_were_provided( // the client stub. struct ServiceGenerator<'a> { service_ident: &'a Ident, + client_stub_ident: &'a Ident, server_ident: &'a Ident, response_fut_ident: &'a Ident, response_fut_name: &'a str, @@ -461,6 +463,9 @@ impl<'a> ServiceGenerator<'a> { future_types, return_types, service_ident, + client_stub_ident, + request_ident, + response_ident, server_ident, .. } = self; @@ -490,6 +495,7 @@ impl<'a> ServiceGenerator<'a> { }, ); + let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { #( #attrs )* #vis trait #service_ident: Sized { @@ -501,6 +507,15 @@ impl<'a> ServiceGenerator<'a> { #server_ident { service: self } } } + + #[doc = #stub_doc] + #vis trait #client_stub_ident: tarpc::client::stub::Stub { + } + + impl #client_stub_ident for S + where S: tarpc::client::stub::Stub + { + } } } @@ -689,7 +704,9 @@ impl<'a> ServiceGenerator<'a> { #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. All request methods return /// [Futures](std::future::Future). - #vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>); + #vis struct #client_ident< + Stub = tarpc::client::Channel<#request_ident, #response_ident> + >(Stub); } } @@ -719,6 +736,17 @@ impl<'a> ServiceGenerator<'a> { dispatch: new_client.dispatch, } } + } + + impl From for #client_ident + where Stub: tarpc::client::stub::Stub< + Req = #request_ident, + Resp = #response_ident> + { + /// Returns a new client stub that sends requests over the given transport. + fn from(stub: Stub) -> Self { + #client_ident(stub) + } } } @@ -741,7 +769,11 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - impl #client_ident { + impl #client_ident + where Stub: tarpc::client::stub::Stub< + Req = #request_ident, + Resp = #response_ident> + { #( #[allow(unused)] #( #method_attrs )* diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 97ac952..c6f8064 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -19,7 +19,7 @@ description = "An RPC framework for Rust with a focus on ease of use." [features] default = [] -serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"] +serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"] tokio1 = ["tokio/rt"] serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"] serde-transport-json = ["tokio-serde/json"] diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 0180ba4..9b54334 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -7,6 +7,7 @@ //! Provides a client that connects to a server and sends multiplexed requests. mod in_flight_requests; +pub mod stub; use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs new file mode 100644 index 0000000..a8b72a2 --- /dev/null +++ b/tarpc/src/client/stub.rs @@ -0,0 +1,56 @@ +//! Provides a Stub trait, implemented by types that can call remote services. + +use crate::{ + client::{Channel, RpcError}, + context, +}; +use futures::prelude::*; + +pub mod load_balance; +pub mod retry; + +#[cfg(test)] +mod mock; + +/// A connection to a remote service. +/// Calls the service with requests of type `Req` and receives responses of type `Resp`. +pub trait Stub { + /// The service request type. + type Req; + + /// The service response type. + type Resp; + + /// The type of the future returned by `Stub::call`. + type RespFut<'a>: Future> + where + Self: 'a, + Self::Req: 'a, + Self::Resp: 'a; + + /// Calls a remote service. + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a>; +} + +impl Stub for Channel { + type Req = Req; + type Resp = Resp; + type RespFut<'a> = RespFut<'a, Req, Resp> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } +} + +type RespFut<'a, Req: 'a, Resp: 'a> = impl Future> + 'a; diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs new file mode 100644 index 0000000..c9005a4 --- /dev/null +++ b/tarpc/src/client/stub/load_balance.rs @@ -0,0 +1,305 @@ +//! Provides load-balancing [Stubs](crate::client::stub::Stub). + +pub use consistent_hash::ConsistentHash; +pub use round_robin::RoundRobin; + +/// Provides a stub that load-balances with a simple round-robin strategy. +mod round_robin { + use crate::{ + client::{stub, RpcError}, + context, + }; + use cycle::AtomicCycle; + use futures::prelude::*; + + impl stub::Stub for RoundRobin + where + Stub: stub::Stub, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } + } + + type RespFut<'a, Stub: stub::Stub + 'a> = + impl Future> + 'a; + + /// A Stub that load-balances across backing stubs by round robin. + #[derive(Clone, Debug)] + pub struct RoundRobin { + stubs: AtomicCycle, + } + + impl RoundRobin + where + Stub: stub::Stub, + { + /// Returns a new RoundRobin stub. + pub fn new(stubs: Vec) -> Self { + Self { + stubs: AtomicCycle::new(stubs), + } + } + + async fn call( + &self, + ctx: context::Context, + request_name: &'static str, + request: Stub::Req, + ) -> Result { + let next = self.stubs.next(); + next.call(ctx, request_name, request).await + } + } + + mod cycle { + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + /// Cycles endlessly and atomically over a collection of elements of type T. + #[derive(Clone, Debug)] + pub struct AtomicCycle(Arc>); + + #[derive(Debug)] + struct State { + elements: Vec, + next: AtomicUsize, + } + + impl AtomicCycle { + pub fn new(elements: Vec) -> Self { + Self(Arc::new(State { + elements, + next: Default::default(), + })) + } + + pub fn next(&self) -> &T { + self.0.next() + } + } + + impl State { + pub fn next(&self) -> &T { + let next = self.next.fetch_add(1, Ordering::Relaxed); + &self.elements[next % self.elements.len()] + } + } + + #[test] + fn test_cycle() { + let cycle = AtomicCycle::new(vec![1, 2, 3]); + assert_eq!(cycle.next(), &1); + assert_eq!(cycle.next(), &2); + assert_eq!(cycle.next(), &3); + assert_eq!(cycle.next(), &1); + } + } +} + +/// Provides a stub that load-balances with a consistent hashing strategy. +/// +/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use +/// the same stub. +mod consistent_hash { + use crate::{ + client::{stub, RpcError}, + context, + }; + use futures::prelude::*; + use std::{ + collections::hash_map::RandomState, + hash::{BuildHasher, Hash, Hasher}, + num::TryFromIntError, + }; + + impl stub::Stub for ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub> + where Self: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } + } + + type RespFut<'a, Stub: stub::Stub + 'a> = + impl Future> + 'a; + + /// A Stub that load-balances across backing stubs by round robin. + #[derive(Clone, Debug)] + pub struct ConsistentHash { + stubs: Vec, + stubs_len: u64, + hasher: S, + } + + impl ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + { + /// Returns a new RoundRobin stub. + /// Returns an err if the length of `stubs` overflows a u64. + pub fn new(stubs: Vec) -> Result { + Ok(Self { + stubs_len: stubs.len().try_into()?, + stubs, + hasher: RandomState::new(), + }) + } + } + + impl ConsistentHash + where + Stub: stub::Stub, + Stub::Req: Hash, + S: BuildHasher, + { + /// Returns a new RoundRobin stub. + /// Returns an err if the length of `stubs` overflows a u64. + pub fn with_hasher(stubs: Vec, hasher: S) -> Result { + Ok(Self { + stubs_len: stubs.len().try_into()?, + stubs, + hasher, + }) + } + + async fn call( + &self, + ctx: context::Context, + request_name: &'static str, + request: Stub::Req, + ) -> Result { + let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect( + "invariant broken: stubs_len is not larger than a usize, \ + so the hash modulo stubs_len should always fit in a usize", + ); + let next = &self.stubs[index]; + next.call(ctx, request_name, request).await + } + + fn hash_request(&self, req: &Stub::Req) -> u64 { + let mut hasher = self.hasher.build_hasher(); + req.hash(&mut hasher); + hasher.finish() + } + } + + #[cfg(test)] + mod tests { + use super::ConsistentHash; + use crate::{client::stub::mock::Mock, context}; + use std::{ + collections::HashMap, + hash::{BuildHasher, Hash, Hasher}, + rc::Rc, + }; + + #[tokio::test] + async fn test() -> anyhow::Result<()> { + let stub = ConsistentHash::with_hasher( + vec![ + // For easier reading of the assertions made in this test, each Mock's response + // value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 % + // 3 = 1, etc. + Mock::new([('a', 3), ('b', 3), ('c', 3)]), + Mock::new([('a', 1), ('b', 1), ('c', 1)]), + Mock::new([('a', 2), ('b', 2), ('c', 2)]), + ], + FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]), + )?; + + for _ in 0..2 { + let resp = stub.call(context::current(), "", 'a').await?; + assert_eq!(resp, 1); + + let resp = stub.call(context::current(), "", 'b').await?; + assert_eq!(resp, 2); + + let resp = stub.call(context::current(), "", 'c').await?; + assert_eq!(resp, 3); + } + + Ok(()) + } + + struct HashRecorder(Vec); + impl Hasher for HashRecorder { + fn write(&mut self, bytes: &[u8]) { + self.0 = Vec::from(bytes); + } + fn finish(&self) -> u64 { + 0 + } + } + + struct FakeHasherBuilder { + recorded_hashes: Rc, u64>>, + } + + struct FakeHasher { + recorded_hashes: Rc, u64>>, + output: u64, + } + + impl BuildHasher for FakeHasherBuilder { + type Hasher = FakeHasher; + + fn build_hasher(&self) -> Self::Hasher { + FakeHasher { + recorded_hashes: self.recorded_hashes.clone(), + output: 0, + } + } + } + + impl FakeHasherBuilder { + fn new(fake_hashes: [(T, u64); N]) -> Self { + let mut recorded_hashes = HashMap::new(); + for (to_hash, fake_hash) in fake_hashes { + let mut recorder = HashRecorder(vec![]); + to_hash.hash(&mut recorder); + recorded_hashes.insert(recorder.0, fake_hash); + } + Self { + recorded_hashes: Rc::new(recorded_hashes), + } + } + } + + impl Hasher for FakeHasher { + fn write(&mut self, bytes: &[u8]) { + if let Some(hash) = self.recorded_hashes.get(bytes) { + self.output = *hash; + } + } + fn finish(&self) -> u64 { + self.output + } + } + } +} diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs new file mode 100644 index 0000000..99a5442 --- /dev/null +++ b/tarpc/src/client/stub/mock.rs @@ -0,0 +1,54 @@ +use crate::{ + client::{stub::Stub, RpcError}, + context, ServerError, +}; +use futures::future; +use std::{collections::HashMap, hash::Hash, io}; + +/// A mock stub that returns user-specified responses. +pub struct Mock { + responses: HashMap, +} + +impl Mock +where + Req: Eq + Hash, +{ + /// Returns a new mock, mocking the specified (request, response) pairs. + pub fn new(responses: [(Req, Resp); N]) -> Self { + Self { + responses: HashMap::from(responses), + } + } +} + +impl Stub for Mock +where + Req: Eq + Hash, + Resp: Clone, +{ + type Req = Req; + type Resp = Resp; + type RespFut<'a> = future::Ready> + where Self: 'a; + + fn call<'a>( + &'a self, + _: context::Context, + _: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + future::ready( + self.responses + .get(&request) + .cloned() + .map(Ok) + .unwrap_or_else(|| { + Err(RpcError::Server(ServerError { + kind: io::ErrorKind::NotFound, + detail: "mock (request, response) entry not found".into(), + })) + }), + ) + } +} diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs new file mode 100644 index 0000000..46ad096 --- /dev/null +++ b/tarpc/src/client/stub/retry.rs @@ -0,0 +1,75 @@ +//! Provides a stub that retries requests based on response contents.. + +use crate::{ + client::{stub, RpcError}, + context, +}; +use futures::prelude::*; +use std::sync::Arc; + +impl stub::Stub for Retry +where + Stub: stub::Stub>, + F: Fn(&Result, u32) -> bool, +{ + type Req = Req; + type Resp = Stub::Resp; + type RespFut<'a> = RespFut<'a, Stub, Self::Req, F> + where Self: 'a, + Self::Req: 'a; + + fn call<'a>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Self::Req, + ) -> Self::RespFut<'a> { + Self::call(self, ctx, request_name, request) + } +} + +type RespFut<'a, Stub: stub::Stub + 'a, Req: 'a, F: 'a> = + impl Future> + 'a; + +/// A Stub that retries requests based on response contents. +/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled. +#[derive(Clone, Debug)] +pub struct Retry { + should_retry: F, + stub: Stub, +} + +impl Retry +where + Stub: stub::Stub>, + F: Fn(&Result, u32) -> bool, +{ + /// Creates a new Retry stub that delegates calls to the underlying `stub`. + pub fn new(stub: Stub, should_retry: F) -> Self { + Self { stub, should_retry } + } + + async fn call<'a, 'b>( + &'a self, + ctx: context::Context, + request_name: &'static str, + request: Req, + ) -> Result + where + Req: 'b, + { + let request = Arc::new(request); + for i in 1.. { + let result = self + .stub + .call(ctx, request_name, Arc::clone(&request)) + .await; + if (self.should_retry)(&result, i) { + tracing::trace!("Retrying on attempt {i}"); + continue; + } + return result; + } + unreachable!("Wow, that was a lot of attempts!"); + } +} diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 418cedd..b47d13b 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -200,6 +200,15 @@ //! //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. + +// For async_fn_in_trait +#![allow(incomplete_features)] +#![feature( + iter_intersperse, + type_alias_impl_trait, + async_fn_in_trait, + return_position_impl_trait_in_trait +)] #![deny(missing_docs)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -407,6 +416,13 @@ where Close(#[source] E), } +impl ServerError { + /// Returns a new server error with `kind` and `detail`. + pub fn new(kind: io::ErrorKind, detail: String) -> ServerError { + Self { kind, detail } + } +} + impl Request { /// Returns the deadline for this request. pub fn deadline(&self) -> &SystemTime { diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 9f04b27..b44724d 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -9,7 +9,7 @@ use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, - trace, ChannelError, ClientMessage, Request, Response, Transport, + trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport, }; use ::tokio::sync::mpsc; use futures::{ @@ -21,7 +21,7 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc}; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, mem, pin::Pin, sync::Arc}; use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; @@ -565,6 +565,10 @@ where }| { // The response guard becomes active once in an InFlightRequest. response_guard.cancel = true; + { + let _entered = span.enter(); + tracing::info!("BeginRequest"); + } InFlightRequest { request, abort_registration, @@ -686,6 +690,29 @@ impl InFlightRequest { &self.request } + /// Respond without executing a service function. Useful for early aborts (e.g. for throttling). + pub async fn respond(self, response: Result) { + let Self { + response_tx, + response_guard, + request: Request { id: request_id, .. }, + span, + .. + } = self; + let _entered = span.enter(); + tracing::info!("CompleteRequest"); + let response = Response { + request_id, + message: response, + }; + let _ = response_tx.send(response).await; + tracing::info!("BufferResponse"); + // Request processing has completed, meaning either the channel canceled the request or + // a request was sent back to the channel. Either way, the channel will clean up the + // request data, so the request does not need to be canceled. + mem::forget(response_guard); + } + /// Returns a [future](Future) that executes the request using the given [service /// function](Serve). The service function's output is automatically sent back to the [Channel] /// that yielded this request. The request will be executed in the scope of this request's @@ -720,7 +747,6 @@ impl InFlightRequest { span.record("otel.name", method.unwrap_or("")); let _ = Abortable::new( async move { - tracing::info!("BeginRequest"); let response = serve.serve(context, message).await; tracing::info!("CompleteRequest"); let response = Response { diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr index 28106e6..d96cda8 100644 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr +++ b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr @@ -1,11 +1,15 @@ error: not all trait items implemented, missing: `HelloFut` - --> $DIR/tarpc_server_missing_async.rs:9:1 - | -9 | impl World for HelloServer { - | ^^^^ + --> tests/compile_fail/tarpc_server_missing_async.rs:9:1 + | +9 | / impl World for HelloServer { +10 | | fn hello(name: String) -> String { +11 | | format!("Hello, {name}!", name) +12 | | } +13 | | } + | |_^ error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async - --> $DIR/tarpc_server_missing_async.rs:10:5 + --> tests/compile_fail/tarpc_server_missing_async.rs:10:5 | 10 | fn hello(name: String) -> String { - | ^^ + | ^^^^^^^^