mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-07 03:56:48 +01:00
Add back the Client trait, renamed Stub.
Also adds a Client stub trait alias for each generated service. Now that generic associated types are stable, it's almost possible to define a trait for Channel that works with async fns on stable. `impl trait in type aliases` is still necessary (and unstable), but we're getting closer. As a proof of concept, three more implementations of Stub are implemented; 1. A load balancer that round-robins requests between different stubs. 2. A load balancer that selects a stub based on a request hash, so that the same requests go to the same stubs. 3. A stub that retries requests based on a configurable policy. The "serde/rc" feature is added to the "full" feature because the Retry stub wraps the request in an Arc, so that the request is reusable for multiple calls. Server implementors commonly need to operate generically across all services or request types. For example, a server throttler may want to return errors telling clients to back off, which is not specific to any one service.
This commit is contained in:
@@ -276,6 +276,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
ServiceGenerator {
|
||||
response_fut_name,
|
||||
service_ident: ident,
|
||||
client_stub_ident: &format_ident!("{}Stub", ident),
|
||||
server_ident: &format_ident!("Serve{}", ident),
|
||||
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
|
||||
client_ident: &format_ident!("{}Client", ident),
|
||||
@@ -432,6 +433,7 @@ fn verify_types_were_provided(
|
||||
// the client stub.
|
||||
struct ServiceGenerator<'a> {
|
||||
service_ident: &'a Ident,
|
||||
client_stub_ident: &'a Ident,
|
||||
server_ident: &'a Ident,
|
||||
response_fut_ident: &'a Ident,
|
||||
response_fut_name: &'a str,
|
||||
@@ -461,6 +463,9 @@ impl<'a> ServiceGenerator<'a> {
|
||||
future_types,
|
||||
return_types,
|
||||
service_ident,
|
||||
client_stub_ident,
|
||||
request_ident,
|
||||
response_ident,
|
||||
server_ident,
|
||||
..
|
||||
} = self;
|
||||
@@ -490,6 +495,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
},
|
||||
);
|
||||
|
||||
let stub_doc = format!("The stub trait for service [`{service_ident}`].");
|
||||
quote! {
|
||||
#( #attrs )*
|
||||
#vis trait #service_ident: Sized {
|
||||
@@ -501,6 +507,15 @@ impl<'a> ServiceGenerator<'a> {
|
||||
#server_ident { service: self }
|
||||
}
|
||||
}
|
||||
|
||||
#[doc = #stub_doc]
|
||||
#vis trait #client_stub_ident: tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
|
||||
}
|
||||
|
||||
impl<S> #client_stub_ident for S
|
||||
where S: tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
|
||||
{
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -689,7 +704,9 @@ impl<'a> ServiceGenerator<'a> {
|
||||
#[derive(Clone, Debug)]
|
||||
/// The client stub that makes RPC calls to the server. All request methods return
|
||||
/// [Futures](std::future::Future).
|
||||
#vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>);
|
||||
#vis struct #client_ident<
|
||||
Stub = tarpc::client::Channel<#request_ident, #response_ident>
|
||||
>(Stub);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -719,6 +736,17 @@ impl<'a> ServiceGenerator<'a> {
|
||||
dispatch: new_client.dispatch,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Stub> From<Stub> for #client_ident<Stub>
|
||||
where Stub: tarpc::client::stub::Stub<
|
||||
Req = #request_ident,
|
||||
Resp = #response_ident>
|
||||
{
|
||||
/// Returns a new client stub that sends requests over the given transport.
|
||||
fn from(stub: Stub) -> Self {
|
||||
#client_ident(stub)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
@@ -741,7 +769,11 @@ impl<'a> ServiceGenerator<'a> {
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl #client_ident {
|
||||
impl<Stub> #client_ident<Stub>
|
||||
where Stub: tarpc::client::stub::Stub<
|
||||
Req = #request_ident,
|
||||
Resp = #response_ident>
|
||||
{
|
||||
#(
|
||||
#[allow(unused)]
|
||||
#( #method_attrs )*
|
||||
|
||||
@@ -19,7 +19,7 @@ description = "An RPC framework for Rust with a focus on ease of use."
|
||||
[features]
|
||||
default = []
|
||||
|
||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"]
|
||||
tokio1 = ["tokio/rt"]
|
||||
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||
serde-transport-json = ["tokio-serde/json"]
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
//! Provides a client that connects to a server and sends multiplexed requests.
|
||||
|
||||
mod in_flight_requests;
|
||||
pub mod stub;
|
||||
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
|
||||
56
tarpc/src/client/stub.rs
Normal file
56
tarpc/src/client/stub.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
//! Provides a Stub trait, implemented by types that can call remote services.
|
||||
|
||||
use crate::{
|
||||
client::{Channel, RpcError},
|
||||
context,
|
||||
};
|
||||
use futures::prelude::*;
|
||||
|
||||
pub mod load_balance;
|
||||
pub mod retry;
|
||||
|
||||
#[cfg(test)]
|
||||
mod mock;
|
||||
|
||||
/// A connection to a remote service.
|
||||
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
|
||||
pub trait Stub {
|
||||
/// The service request type.
|
||||
type Req;
|
||||
|
||||
/// The service response type.
|
||||
type Resp;
|
||||
|
||||
/// The type of the future returned by `Stub::call`.
|
||||
type RespFut<'a>: Future<Output = Result<Self::Resp, RpcError>>
|
||||
where
|
||||
Self: 'a,
|
||||
Self::Req: 'a,
|
||||
Self::Resp: 'a;
|
||||
|
||||
/// Calls a remote service.
|
||||
fn call<'a>(
|
||||
&'a self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Self::Req,
|
||||
) -> Self::RespFut<'a>;
|
||||
}
|
||||
|
||||
impl<Req, Resp> Stub for Channel<Req, Resp> {
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
type RespFut<'a> = RespFut<'a, Req, Resp>
|
||||
where Self: 'a;
|
||||
|
||||
fn call<'a>(
|
||||
&'a self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Req,
|
||||
) -> Self::RespFut<'a> {
|
||||
Self::call(self, ctx, request_name, request)
|
||||
}
|
||||
}
|
||||
|
||||
type RespFut<'a, Req: 'a, Resp: 'a> = impl Future<Output = Result<Resp, RpcError>> + 'a;
|
||||
305
tarpc/src/client/stub/load_balance.rs
Normal file
305
tarpc/src/client/stub/load_balance.rs
Normal file
@@ -0,0 +1,305 @@
|
||||
//! Provides load-balancing [Stubs](crate::client::stub::Stub).
|
||||
|
||||
pub use consistent_hash::ConsistentHash;
|
||||
pub use round_robin::RoundRobin;
|
||||
|
||||
/// Provides a stub that load-balances with a simple round-robin strategy.
|
||||
mod round_robin {
|
||||
use crate::{
|
||||
client::{stub, RpcError},
|
||||
context,
|
||||
};
|
||||
use cycle::AtomicCycle;
|
||||
use futures::prelude::*;
|
||||
|
||||
impl<Stub> stub::Stub for RoundRobin<Stub>
|
||||
where
|
||||
Stub: stub::Stub,
|
||||
{
|
||||
type Req = Stub::Req;
|
||||
type Resp = Stub::Resp;
|
||||
type RespFut<'a> = RespFut<'a, Stub>
|
||||
where Self: 'a;
|
||||
|
||||
fn call<'a>(
|
||||
&'a self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Self::Req,
|
||||
) -> Self::RespFut<'a> {
|
||||
Self::call(self, ctx, request_name, request)
|
||||
}
|
||||
}
|
||||
|
||||
type RespFut<'a, Stub: stub::Stub + 'a> =
|
||||
impl Future<Output = Result<Stub::Resp, RpcError>> + 'a;
|
||||
|
||||
/// A Stub that load-balances across backing stubs by round robin.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RoundRobin<Stub> {
|
||||
stubs: AtomicCycle<Stub>,
|
||||
}
|
||||
|
||||
impl<Stub> RoundRobin<Stub>
|
||||
where
|
||||
Stub: stub::Stub,
|
||||
{
|
||||
/// Returns a new RoundRobin stub.
|
||||
pub fn new(stubs: Vec<Stub>) -> Self {
|
||||
Self {
|
||||
stubs: AtomicCycle::new(stubs),
|
||||
}
|
||||
}
|
||||
|
||||
async fn call(
|
||||
&self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Stub::Req,
|
||||
) -> Result<Stub::Resp, RpcError> {
|
||||
let next = self.stubs.next();
|
||||
next.call(ctx, request_name, request).await
|
||||
}
|
||||
}
|
||||
|
||||
mod cycle {
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
/// Cycles endlessly and atomically over a collection of elements of type T.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AtomicCycle<T>(Arc<State<T>>);
|
||||
|
||||
#[derive(Debug)]
|
||||
struct State<T> {
|
||||
elements: Vec<T>,
|
||||
next: AtomicUsize,
|
||||
}
|
||||
|
||||
impl<T> AtomicCycle<T> {
|
||||
pub fn new(elements: Vec<T>) -> Self {
|
||||
Self(Arc::new(State {
|
||||
elements,
|
||||
next: Default::default(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn next(&self) -> &T {
|
||||
self.0.next()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> State<T> {
|
||||
pub fn next(&self) -> &T {
|
||||
let next = self.next.fetch_add(1, Ordering::Relaxed);
|
||||
&self.elements[next % self.elements.len()]
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cycle() {
|
||||
let cycle = AtomicCycle::new(vec![1, 2, 3]);
|
||||
assert_eq!(cycle.next(), &1);
|
||||
assert_eq!(cycle.next(), &2);
|
||||
assert_eq!(cycle.next(), &3);
|
||||
assert_eq!(cycle.next(), &1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides a stub that load-balances with a consistent hashing strategy.
|
||||
///
|
||||
/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use
|
||||
/// the same stub.
|
||||
mod consistent_hash {
|
||||
use crate::{
|
||||
client::{stub, RpcError},
|
||||
context,
|
||||
};
|
||||
use futures::prelude::*;
|
||||
use std::{
|
||||
collections::hash_map::RandomState,
|
||||
hash::{BuildHasher, Hash, Hasher},
|
||||
num::TryFromIntError,
|
||||
};
|
||||
|
||||
impl<Stub> stub::Stub for ConsistentHash<Stub>
|
||||
where
|
||||
Stub: stub::Stub,
|
||||
Stub::Req: Hash,
|
||||
{
|
||||
type Req = Stub::Req;
|
||||
type Resp = Stub::Resp;
|
||||
type RespFut<'a> = RespFut<'a, Stub>
|
||||
where Self: 'a;
|
||||
|
||||
fn call<'a>(
|
||||
&'a self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Self::Req,
|
||||
) -> Self::RespFut<'a> {
|
||||
Self::call(self, ctx, request_name, request)
|
||||
}
|
||||
}
|
||||
|
||||
type RespFut<'a, Stub: stub::Stub + 'a> =
|
||||
impl Future<Output = Result<Stub::Resp, RpcError>> + 'a;
|
||||
|
||||
/// A Stub that load-balances across backing stubs by round robin.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ConsistentHash<Stub, S = RandomState> {
|
||||
stubs: Vec<Stub>,
|
||||
stubs_len: u64,
|
||||
hasher: S,
|
||||
}
|
||||
|
||||
impl<Stub> ConsistentHash<Stub, RandomState>
|
||||
where
|
||||
Stub: stub::Stub,
|
||||
Stub::Req: Hash,
|
||||
{
|
||||
/// Returns a new RoundRobin stub.
|
||||
/// Returns an err if the length of `stubs` overflows a u64.
|
||||
pub fn new(stubs: Vec<Stub>) -> Result<Self, TryFromIntError> {
|
||||
Ok(Self {
|
||||
stubs_len: stubs.len().try_into()?,
|
||||
stubs,
|
||||
hasher: RandomState::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<Stub, S> ConsistentHash<Stub, S>
|
||||
where
|
||||
Stub: stub::Stub,
|
||||
Stub::Req: Hash,
|
||||
S: BuildHasher,
|
||||
{
|
||||
/// Returns a new RoundRobin stub.
|
||||
/// Returns an err if the length of `stubs` overflows a u64.
|
||||
pub fn with_hasher(stubs: Vec<Stub>, hasher: S) -> Result<Self, TryFromIntError> {
|
||||
Ok(Self {
|
||||
stubs_len: stubs.len().try_into()?,
|
||||
stubs,
|
||||
hasher,
|
||||
})
|
||||
}
|
||||
|
||||
async fn call(
|
||||
&self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Stub::Req,
|
||||
) -> Result<Stub::Resp, RpcError> {
|
||||
let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
|
||||
"invariant broken: stubs_len is not larger than a usize, \
|
||||
so the hash modulo stubs_len should always fit in a usize",
|
||||
);
|
||||
let next = &self.stubs[index];
|
||||
next.call(ctx, request_name, request).await
|
||||
}
|
||||
|
||||
fn hash_request(&self, req: &Stub::Req) -> u64 {
|
||||
let mut hasher = self.hasher.build_hasher();
|
||||
req.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ConsistentHash;
|
||||
use crate::{client::stub::mock::Mock, context};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
hash::{BuildHasher, Hash, Hasher},
|
||||
rc::Rc,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test() -> anyhow::Result<()> {
|
||||
let stub = ConsistentHash::with_hasher(
|
||||
vec![
|
||||
// For easier reading of the assertions made in this test, each Mock's response
|
||||
// value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 %
|
||||
// 3 = 1, etc.
|
||||
Mock::new([('a', 3), ('b', 3), ('c', 3)]),
|
||||
Mock::new([('a', 1), ('b', 1), ('c', 1)]),
|
||||
Mock::new([('a', 2), ('b', 2), ('c', 2)]),
|
||||
],
|
||||
FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]),
|
||||
)?;
|
||||
|
||||
for _ in 0..2 {
|
||||
let resp = stub.call(context::current(), "", 'a').await?;
|
||||
assert_eq!(resp, 1);
|
||||
|
||||
let resp = stub.call(context::current(), "", 'b').await?;
|
||||
assert_eq!(resp, 2);
|
||||
|
||||
let resp = stub.call(context::current(), "", 'c').await?;
|
||||
assert_eq!(resp, 3);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct HashRecorder(Vec<u8>);
|
||||
impl Hasher for HashRecorder {
|
||||
fn write(&mut self, bytes: &[u8]) {
|
||||
self.0 = Vec::from(bytes);
|
||||
}
|
||||
fn finish(&self) -> u64 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
struct FakeHasherBuilder {
|
||||
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
|
||||
}
|
||||
|
||||
struct FakeHasher {
|
||||
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
|
||||
output: u64,
|
||||
}
|
||||
|
||||
impl BuildHasher for FakeHasherBuilder {
|
||||
type Hasher = FakeHasher;
|
||||
|
||||
fn build_hasher(&self) -> Self::Hasher {
|
||||
FakeHasher {
|
||||
recorded_hashes: self.recorded_hashes.clone(),
|
||||
output: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeHasherBuilder {
|
||||
fn new<T: Hash, const N: usize>(fake_hashes: [(T, u64); N]) -> Self {
|
||||
let mut recorded_hashes = HashMap::new();
|
||||
for (to_hash, fake_hash) in fake_hashes {
|
||||
let mut recorder = HashRecorder(vec![]);
|
||||
to_hash.hash(&mut recorder);
|
||||
recorded_hashes.insert(recorder.0, fake_hash);
|
||||
}
|
||||
Self {
|
||||
recorded_hashes: Rc::new(recorded_hashes),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Hasher for FakeHasher {
|
||||
fn write(&mut self, bytes: &[u8]) {
|
||||
if let Some(hash) = self.recorded_hashes.get(bytes) {
|
||||
self.output = *hash;
|
||||
}
|
||||
}
|
||||
fn finish(&self) -> u64 {
|
||||
self.output
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
54
tarpc/src/client/stub/mock.rs
Normal file
54
tarpc/src/client/stub/mock.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use crate::{
|
||||
client::{stub::Stub, RpcError},
|
||||
context, ServerError,
|
||||
};
|
||||
use futures::future;
|
||||
use std::{collections::HashMap, hash::Hash, io};
|
||||
|
||||
/// A mock stub that returns user-specified responses.
|
||||
pub struct Mock<Req, Resp> {
|
||||
responses: HashMap<Req, Resp>,
|
||||
}
|
||||
|
||||
impl<Req, Resp> Mock<Req, Resp>
|
||||
where
|
||||
Req: Eq + Hash,
|
||||
{
|
||||
/// Returns a new mock, mocking the specified (request, response) pairs.
|
||||
pub fn new<const N: usize>(responses: [(Req, Resp); N]) -> Self {
|
||||
Self {
|
||||
responses: HashMap::from(responses),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Stub for Mock<Req, Resp>
|
||||
where
|
||||
Req: Eq + Hash,
|
||||
Resp: Clone,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
type RespFut<'a> = future::Ready<Result<Resp, RpcError>>
|
||||
where Self: 'a;
|
||||
|
||||
fn call<'a>(
|
||||
&'a self,
|
||||
_: context::Context,
|
||||
_: &'static str,
|
||||
request: Self::Req,
|
||||
) -> Self::RespFut<'a> {
|
||||
future::ready(
|
||||
self.responses
|
||||
.get(&request)
|
||||
.cloned()
|
||||
.map(Ok)
|
||||
.unwrap_or_else(|| {
|
||||
Err(RpcError::Server(ServerError {
|
||||
kind: io::ErrorKind::NotFound,
|
||||
detail: "mock (request, response) entry not found".into(),
|
||||
}))
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
75
tarpc/src/client/stub/retry.rs
Normal file
75
tarpc/src/client/stub/retry.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
//! Provides a stub that retries requests based on response contents..
|
||||
|
||||
use crate::{
|
||||
client::{stub, RpcError},
|
||||
context,
|
||||
};
|
||||
use futures::prelude::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
impl<Stub, Req, F> stub::Stub for Retry<F, Stub>
|
||||
where
|
||||
Stub: stub::Stub<Req = Arc<Req>>,
|
||||
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Stub::Resp;
|
||||
type RespFut<'a> = RespFut<'a, Stub, Self::Req, F>
|
||||
where Self: 'a,
|
||||
Self::Req: 'a;
|
||||
|
||||
fn call<'a>(
|
||||
&'a self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Self::Req,
|
||||
) -> Self::RespFut<'a> {
|
||||
Self::call(self, ctx, request_name, request)
|
||||
}
|
||||
}
|
||||
|
||||
type RespFut<'a, Stub: stub::Stub + 'a, Req: 'a, F: 'a> =
|
||||
impl Future<Output = Result<Stub::Resp, RpcError>> + 'a;
|
||||
|
||||
/// A Stub that retries requests based on response contents.
|
||||
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Retry<F, Stub> {
|
||||
should_retry: F,
|
||||
stub: Stub,
|
||||
}
|
||||
|
||||
impl<Stub, Req, F> Retry<F, Stub>
|
||||
where
|
||||
Stub: stub::Stub<Req = Arc<Req>>,
|
||||
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
|
||||
{
|
||||
/// Creates a new Retry stub that delegates calls to the underlying `stub`.
|
||||
pub fn new(stub: Stub, should_retry: F) -> Self {
|
||||
Self { stub, should_retry }
|
||||
}
|
||||
|
||||
async fn call<'a, 'b>(
|
||||
&'a self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Req,
|
||||
) -> Result<Stub::Resp, RpcError>
|
||||
where
|
||||
Req: 'b,
|
||||
{
|
||||
let request = Arc::new(request);
|
||||
for i in 1.. {
|
||||
let result = self
|
||||
.stub
|
||||
.call(ctx, request_name, Arc::clone(&request))
|
||||
.await;
|
||||
if (self.should_retry)(&result, i) {
|
||||
tracing::trace!("Retrying on attempt {i}");
|
||||
continue;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
unreachable!("Wow, that was a lot of attempts!");
|
||||
}
|
||||
}
|
||||
@@ -200,6 +200,15 @@
|
||||
//!
|
||||
//! Use `cargo doc` as you normally would to see the documentation created for all
|
||||
//! items expanded by a `service!` invocation.
|
||||
|
||||
// For async_fn_in_trait
|
||||
#![allow(incomplete_features)]
|
||||
#![feature(
|
||||
iter_intersperse,
|
||||
type_alias_impl_trait,
|
||||
async_fn_in_trait,
|
||||
return_position_impl_trait_in_trait
|
||||
)]
|
||||
#![deny(missing_docs)]
|
||||
#![allow(clippy::type_complexity)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
@@ -407,6 +416,13 @@ where
|
||||
Close(#[source] E),
|
||||
}
|
||||
|
||||
impl ServerError {
|
||||
/// Returns a new server error with `kind` and `detail`.
|
||||
pub fn new(kind: io::ErrorKind, detail: String) -> ServerError {
|
||||
Self { kind, detail }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
/// Returns the deadline for this request.
|
||||
pub fn deadline(&self) -> &SystemTime {
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
context::{self, SpanExt},
|
||||
trace, ChannelError, ClientMessage, Request, Response, Transport,
|
||||
trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport,
|
||||
};
|
||||
use ::tokio::sync::mpsc;
|
||||
use futures::{
|
||||
@@ -21,7 +21,7 @@ use futures::{
|
||||
};
|
||||
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
|
||||
use pin_project::pin_project;
|
||||
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc};
|
||||
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, mem, pin::Pin, sync::Arc};
|
||||
use tracing::{info_span, instrument::Instrument, Span};
|
||||
|
||||
mod in_flight_requests;
|
||||
@@ -565,6 +565,10 @@ where
|
||||
}| {
|
||||
// The response guard becomes active once in an InFlightRequest.
|
||||
response_guard.cancel = true;
|
||||
{
|
||||
let _entered = span.enter();
|
||||
tracing::info!("BeginRequest");
|
||||
}
|
||||
InFlightRequest {
|
||||
request,
|
||||
abort_registration,
|
||||
@@ -686,6 +690,29 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
&self.request
|
||||
}
|
||||
|
||||
/// Respond without executing a service function. Useful for early aborts (e.g. for throttling).
|
||||
pub async fn respond(self, response: Result<Res, ServerError>) {
|
||||
let Self {
|
||||
response_tx,
|
||||
response_guard,
|
||||
request: Request { id: request_id, .. },
|
||||
span,
|
||||
..
|
||||
} = self;
|
||||
let _entered = span.enter();
|
||||
tracing::info!("CompleteRequest");
|
||||
let response = Response {
|
||||
request_id,
|
||||
message: response,
|
||||
};
|
||||
let _ = response_tx.send(response).await;
|
||||
tracing::info!("BufferResponse");
|
||||
// Request processing has completed, meaning either the channel canceled the request or
|
||||
// a request was sent back to the channel. Either way, the channel will clean up the
|
||||
// request data, so the request does not need to be canceled.
|
||||
mem::forget(response_guard);
|
||||
}
|
||||
|
||||
/// Returns a [future](Future) that executes the request using the given [service
|
||||
/// function](Serve). The service function's output is automatically sent back to the [Channel]
|
||||
/// that yielded this request. The request will be executed in the scope of this request's
|
||||
@@ -720,7 +747,6 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
span.record("otel.name", method.unwrap_or(""));
|
||||
let _ = Abortable::new(
|
||||
async move {
|
||||
tracing::info!("BeginRequest");
|
||||
let response = serve.serve(context, message).await;
|
||||
tracing::info!("CompleteRequest");
|
||||
let response = Response {
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
error: not all trait items implemented, missing: `HelloFut`
|
||||
--> $DIR/tarpc_server_missing_async.rs:9:1
|
||||
|
|
||||
9 | impl World for HelloServer {
|
||||
| ^^^^
|
||||
--> tests/compile_fail/tarpc_server_missing_async.rs:9:1
|
||||
|
|
||||
9 | / impl World for HelloServer {
|
||||
10 | | fn hello(name: String) -> String {
|
||||
11 | | format!("Hello, {name}!", name)
|
||||
12 | | }
|
||||
13 | | }
|
||||
| |_^
|
||||
|
||||
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
|
||||
--> $DIR/tarpc_server_missing_async.rs:10:5
|
||||
--> tests/compile_fail/tarpc_server_missing_async.rs:10:5
|
||||
|
|
||||
10 | fn hello(name: String) -> String {
|
||||
| ^^
|
||||
| ^^^^^^^^
|
||||
|
||||
Reference in New Issue
Block a user