Rewrite traits to use async-fn-in-trait.

- Stub
- BeforeRequest
- AfterRequest

Also removed the last remaining usage of an unstable feature,
iter_intersperse.
This commit is contained in:
Tim Kuehn
2023-11-06 12:43:48 -08:00
committed by Tim
parent 84932df9b4
commit 6cf18a1caf
22 changed files with 97 additions and 251 deletions

View File

@@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use std::env;
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};

View File

@@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use clap::Parser;
use futures::{future, prelude::*};
use rand::{

View File

@@ -1,6 +1,3 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
// these need to be out here rather than inside the function so that the
// assert_type_eq macro can pick them up.
#[tarpc::service]

View File

@@ -1,6 +1,3 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use tarpc::context;
#[test]

View File

@@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};

View File

@@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use futures::prelude::*;
use tarpc::context::Context;
use tarpc::serde_transport as transport;

View File

@@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
/// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher"
/// port. Because both publishers and subscribers initiate their connections to the PubSub
/// server, the server requires no prior knowledge of either publishers or subscribers.

View File

@@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use futures::prelude::*;
use tarpc::{
client, context,

View File

@@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use futures::prelude::*;
use rustls_pemfile::certs;
use std::io::{BufReader, Cursor};

View File

@@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait, type_alias_impl_trait)]
use crate::{
add::{Add as AddService, AddStub},
double::Double as DoubleService,
@@ -69,7 +66,6 @@ struct DoubleServer<Stub> {
impl<Stub> DoubleService for DoubleServer<Stub>
where
Stub: AddStub + Clone + Send + Sync + 'static,
for<'a> Stub::RespFut<'a>: Send,
{
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
self.add_client

View File

@@ -4,7 +4,6 @@ use crate::{
client::{Channel, RpcError},
context,
};
use futures::prelude::*;
pub mod load_balance;
pub mod retry;
@@ -14,6 +13,7 @@ mod mock;
/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
#[allow(async_fn_in_trait)]
pub trait Stub {
/// The service request type.
type Req;
@@ -21,36 +21,25 @@ pub trait Stub {
/// The service response type.
type Resp;
/// The type of the future returned by `Stub::call`.
type RespFut<'a>: Future<Output = Result<Self::Resp, RpcError>>
where
Self: 'a,
Self::Req: 'a,
Self::Resp: 'a;
/// Calls a remote service.
fn call<'a>(
&'a self,
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Self::RespFut<'a>;
) -> Result<Self::Resp, RpcError>;
}
impl<Req, Resp> Stub for Channel<Req, Resp> {
type Req = Req;
type Resp = Resp;
type RespFut<'a> = RespFut<'a, Req, Resp>
where Self: 'a;
fn call<'a>(
&'a self,
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Req,
) -> Self::RespFut<'a> {
Self::call(self, ctx, request_name, request)
) -> Result<Self::Resp, RpcError> {
Self::call(self, ctx, request_name, request).await
}
}
type RespFut<'a, Req: 'a, Resp: 'a> = impl Future<Output = Result<Resp, RpcError>> + 'a;

View File

@@ -10,7 +10,6 @@ mod round_robin {
context,
};
use cycle::AtomicCycle;
use futures::prelude::*;
impl<Stub> stub::Stub for RoundRobin<Stub>
where
@@ -18,22 +17,18 @@ mod round_robin {
{
type Req = Stub::Req;
type Resp = Stub::Resp;
type RespFut<'a> = RespFut<'a, Stub>
where Self: 'a;
fn call<'a>(
&'a self,
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Self::RespFut<'a> {
Self::call(self, ctx, request_name, request)
) -> Result<Stub::Resp, RpcError> {
let next = self.stubs.next();
next.call(ctx, request_name, request).await
}
}
type RespFut<'a, Stub: stub::Stub + 'a> =
impl Future<Output = Result<Stub::Resp, RpcError>> + 'a;
/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct RoundRobin<Stub> {
@@ -50,16 +45,6 @@ mod round_robin {
stubs: AtomicCycle::new(stubs),
}
}
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Stub::Req,
) -> Result<Stub::Resp, RpcError> {
let next = self.stubs.next();
next.call(ctx, request_name, request).await
}
}
mod cycle {
@@ -118,36 +103,36 @@ mod consistent_hash {
client::{stub, RpcError},
context,
};
use futures::prelude::*;
use std::{
collections::hash_map::RandomState,
hash::{BuildHasher, Hash, Hasher},
num::TryFromIntError,
};
impl<Stub> stub::Stub for ConsistentHash<Stub>
impl<Stub, S> stub::Stub for ConsistentHash<Stub, S>
where
Stub: stub::Stub,
Stub::Req: Hash,
S: BuildHasher,
{
type Req = Stub::Req;
type Resp = Stub::Resp;
type RespFut<'a> = RespFut<'a, Stub>
where Self: 'a;
fn call<'a>(
&'a self,
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Self::RespFut<'a> {
Self::call(self, ctx, request_name, request)
) -> Result<Stub::Resp, RpcError> {
let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
"invariant broken: stubs_len is not larger than a usize, \
so the hash modulo stubs_len should always fit in a usize",
);
let next = &self.stubs[index];
next.call(ctx, request_name, request).await
}
}
type RespFut<'a, Stub: stub::Stub + 'a> =
impl Future<Output = Result<Stub::Resp, RpcError>> + 'a;
/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct ConsistentHash<Stub, S = RandomState> {
@@ -188,20 +173,6 @@ mod consistent_hash {
})
}
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Stub::Req,
) -> Result<Stub::Resp, RpcError> {
let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
"invariant broken: stubs_len is not larger than a usize, \
so the hash modulo stubs_len should always fit in a usize",
);
let next = &self.stubs[index];
next.call(ctx, request_name, request).await
}
fn hash_request(&self, req: &Stub::Req) -> u64 {
let mut hasher = self.hasher.build_hasher();
req.hash(&mut hasher);
@@ -212,7 +183,10 @@ mod consistent_hash {
#[cfg(test)]
mod tests {
use super::ConsistentHash;
use crate::{client::stub::mock::Mock, context};
use crate::{
client::stub::{mock::Mock, Stub},
context,
};
use std::{
collections::HashMap,
hash::{BuildHasher, Hash, Hasher},
@@ -221,7 +195,7 @@ mod consistent_hash {
#[tokio::test]
async fn test() -> anyhow::Result<()> {
let stub = ConsistentHash::with_hasher(
let stub = ConsistentHash::<_, FakeHasherBuilder>::with_hasher(
vec![
// For easier reading of the assertions made in this test, each Mock's response
// value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 %

View File

@@ -2,7 +2,6 @@ use crate::{
client::{stub::Stub, RpcError},
context, ServerError,
};
use futures::future;
use std::{collections::HashMap, hash::Hash, io};
/// A mock stub that returns user-specified responses.
@@ -29,26 +28,22 @@ where
{
type Req = Req;
type Resp = Resp;
type RespFut<'a> = future::Ready<Result<Resp, RpcError>>
where Self: 'a;
fn call<'a>(
&'a self,
async fn call(
&self,
_: context::Context,
_: &'static str,
request: Self::Req,
) -> Self::RespFut<'a> {
future::ready(
self.responses
.get(&request)
.cloned()
.map(Ok)
.unwrap_or_else(|| {
Err(RpcError::Server(ServerError {
kind: io::ErrorKind::NotFound,
detail: "mock (request, response) entry not found".into(),
}))
}),
)
) -> Result<Resp, RpcError> {
self.responses
.get(&request)
.cloned()
.map(Ok)
.unwrap_or_else(|| {
Err(RpcError::Server(ServerError {
kind: io::ErrorKind::NotFound,
detail: "mock (request, response) entry not found".into(),
}))
})
}
}

View File

@@ -4,7 +4,6 @@ use crate::{
client::{stub, RpcError},
context,
};
use futures::prelude::*;
use std::sync::Arc;
impl<Stub, Req, F> stub::Stub for Retry<F, Stub>
@@ -14,23 +13,29 @@ where
{
type Req = Req;
type Resp = Stub::Resp;
type RespFut<'a> = RespFut<'a, Stub, Self::Req, F>
where Self: 'a,
Self::Req: 'a;
fn call<'a>(
&'a self,
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Self::RespFut<'a> {
Self::call(self, ctx, request_name, request)
) -> Result<Stub::Resp, RpcError> {
let request = Arc::new(request);
for i in 1.. {
let result = self
.stub
.call(ctx, request_name, Arc::clone(&request))
.await;
if (self.should_retry)(&result, i) {
tracing::trace!("Retrying on attempt {i}");
continue;
}
return result;
}
unreachable!("Wow, that was a lot of attempts!");
}
}
type RespFut<'a, Stub: stub::Stub + 'a, Req: 'a, F: 'a> =
impl Future<Output = Result<Stub::Resp, RpcError>> + 'a;
/// A Stub that retries requests based on response contents.
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
#[derive(Clone, Debug)]
@@ -48,28 +53,4 @@ where
pub fn new(stub: Stub, should_retry: F) -> Self {
Self { stub, should_retry }
}
async fn call<'a, 'b>(
&'a self,
ctx: context::Context,
request_name: &'static str,
request: Req,
) -> Result<Stub::Resp, RpcError>
where
Req: 'b,
{
let request = Arc::new(request);
for i in 1.. {
let result = self
.stub
.call(ctx, request_name, Arc::clone(&request))
.await;
if (self.should_retry)(&result, i) {
tracing::trace!("Retrying on attempt {i}");
continue;
}
return result;
}
unreachable!("Wow, that was a lot of attempts!");
}
}

View File

@@ -80,8 +80,6 @@
//! First, let's set up the dependencies and service definition.
//!
//! ```rust
//! #![allow(incomplete_features)]
//! #![feature(async_fn_in_trait)]
//! # extern crate futures;
//!
//! use futures::{
@@ -106,8 +104,6 @@
//! implement it for our Server struct.
//!
//! ```rust
//! # #![allow(incomplete_features)]
//! # #![feature(async_fn_in_trait)]
//! # extern crate futures;
//! # use futures::{
//! # future::{self, Ready},
@@ -143,8 +139,6 @@
//! available behind the `tcp` feature.
//!
//! ```rust
//! # #![allow(incomplete_features)]
//! # #![feature(async_fn_in_trait)]
//! # extern crate futures;
//! # use futures::{
//! # future::{self, Ready},
@@ -206,14 +200,6 @@
//! Use `cargo doc` as you normally would to see the documentation created for all
//! items expanded by a `service!` invocation.
// For async_fn_in_trait
#![allow(incomplete_features)]
#![feature(
iter_intersperse,
type_alias_impl_trait,
async_fn_in_trait,
return_position_impl_trait_in_trait
)]
#![deny(missing_docs)]
#![allow(clippy::type_complexity)]
#![cfg_attr(docsrs, feature(doc_cfg))]
@@ -239,7 +225,6 @@ pub use tarpc_plugins::derive_serde;
/// Rpc methods are specified, mirroring trait syntax:
///
/// ```
/// #![feature(async_fn_in_trait)]
/// #[tarpc::service]
/// trait Service {
/// /// Say hello

View File

@@ -67,6 +67,7 @@ impl Config {
}
/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
#[allow(async_fn_in_trait)]
pub trait Serve {
/// Type of request.
type Req;
@@ -186,24 +187,19 @@ pub trait Serve {
/// struct PrintLatency(Instant);
///
/// impl<Req> BeforeRequest<Req> for PrintLatency {
/// type Fut<'a> = future::Ready<Result<(), ServerError>> where Self: 'a, Req: 'a;
///
/// fn before<'a>(&'a mut self, _: &'a mut context::Context, _: &'a Req) -> Self::Fut<'a> {
/// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
/// self.0 = Instant::now();
/// future::ready(Ok(()))
/// Ok(())
/// }
/// }
///
/// impl<Resp> AfterRequest<Resp> for PrintLatency {
/// type Fut<'a> = future::Ready<()> where Self:'a, Resp:'a;
///
/// fn after<'a>(
/// &'a mut self,
/// _: &'a mut context::Context,
/// _: &'a mut Result<Resp, ServerError>,
/// ) -> Self::Fut<'a> {
/// async fn after(
/// &mut self,
/// _: &mut context::Context,
/// _: &mut Result<Resp, ServerError>,
/// ) {
/// tracing::info!("Elapsed: {:?}", self.0.elapsed());
/// future::ready(())
/// }
/// }
///
@@ -1052,8 +1048,8 @@ impl<Req, Res> InFlightRequest<Req, Res> {
fn print_err(e: &(dyn Error + 'static)) -> String {
anyhow::Chain::new(e)
.map(|e| e.to_string())
.intersperse(": ".into())
.collect::<String>()
.collect::<Vec<_>>()
.join(": ")
}
impl<C> Stream for Requests<C>
@@ -1191,18 +1187,14 @@ mod tests {
#[tokio::test]
async fn serve_before_mutates_context() -> anyhow::Result<()> {
struct SetDeadline(SystemTime);
type SetDeadlineFut<'a, Req: 'a> = impl Future<Output = Result<(), ServerError>> + 'a;
impl<Req> BeforeRequest<Req> for SetDeadline {
type Fut<'a> = SetDeadlineFut<'a, Req> where Self: 'a, Req: 'a;
fn before<'a>(
&'a mut self,
ctx: &'a mut context::Context,
_: &'a Req,
) -> Self::Fut<'a> {
async move {
ctx.deadline = self.0;
Ok(())
}
async fn before(
&mut self,
ctx: &mut context::Context,
_: &Req,
) -> Result<(), ServerError> {
ctx.deadline = self.0;
Ok(())
}
}
@@ -1234,27 +1226,19 @@ mod tests {
}
}
}
type StartFut<'a, Req: 'a> = impl Future<Output = Result<(), ServerError>> + 'a;
type EndFut<'a, Resp: 'a> = impl Future<Output = ()> + 'a;
impl<Req> BeforeRequest<Req> for PrintLatency {
type Fut<'a> = StartFut<'a, Req> where Self: 'a, Req: 'a;
fn before<'a>(&'a mut self, _: &'a mut context::Context, _: &'a Req) -> Self::Fut<'a> {
async move {
self.start = Instant::now();
Ok(())
}
async fn before(
&mut self,
_: &mut context::Context,
_: &Req,
) -> Result<(), ServerError> {
self.start = Instant::now();
Ok(())
}
}
impl<Resp> AfterRequest<Resp> for PrintLatency {
type Fut<'a> = EndFut<'a, Resp> where Self: 'a, Resp: 'a;
fn after<'a>(
&'a mut self,
_: &'a mut context::Context,
_: &'a mut Result<Resp, ServerError>,
) -> Self::Fut<'a> {
async move {
tracing::info!("Elapsed: {:?}", self.start.elapsed());
}
async fn after(&mut self, _: &mut context::Context, _: &mut Result<Resp, ServerError>) {
tracing::info!("Elapsed: {:?}", self.start.elapsed());
}
}

View File

@@ -10,21 +10,12 @@ use crate::{context, server::Serve, ServerError};
use futures::prelude::*;
/// A hook that runs after request execution.
#[allow(async_fn_in_trait)]
pub trait AfterRequest<Resp> {
/// The type of future returned by the hook.
type Fut<'a>: Future<Output = ()>
where
Self: 'a,
Resp: 'a;
/// The function that is called after request execution.
///
/// The hook can modify the request context and the response.
fn after<'a>(
&'a mut self,
ctx: &'a mut context::Context,
resp: &'a mut Result<Resp, ServerError>,
) -> Self::Fut<'a>;
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>);
}
impl<F, Fut, Resp> AfterRequest<Resp> for F
@@ -32,14 +23,8 @@ where
F: FnMut(&mut context::Context, &mut Result<Resp, ServerError>) -> Fut,
Fut: Future<Output = ()>,
{
type Fut<'a> = Fut where Self: 'a, Resp: 'a;
fn after<'a>(
&'a mut self,
ctx: &'a mut context::Context,
resp: &'a mut Result<Resp, ServerError>,
) -> Self::Fut<'a> {
self(ctx, resp)
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>) {
self(ctx, resp).await
}
}

View File

@@ -10,13 +10,8 @@ use crate::{context, server::Serve, ServerError};
use futures::prelude::*;
/// A hook that runs before request execution.
#[allow(async_fn_in_trait)]
pub trait BeforeRequest<Req> {
/// The type of future returned by the hook.
type Fut<'a>: Future<Output = Result<(), ServerError>>
where
Self: 'a,
Req: 'a;
/// The function that is called before request execution.
///
/// If this function returns an error, the request will not be executed and the error will be
@@ -24,7 +19,7 @@ pub trait BeforeRequest<Req> {
///
/// This function can also modify the request context. This could be used, for example, to
/// enforce a maximum deadline on all requests.
fn before<'a>(&'a mut self, ctx: &'a mut context::Context, req: &'a Req) -> Self::Fut<'a>;
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>;
}
impl<F, Fut, Req> BeforeRequest<Req> for F
@@ -32,10 +27,8 @@ where
F: FnMut(&mut context::Context, &Req) -> Fut,
Fut: Future<Output = Result<(), ServerError>>,
{
type Fut<'a> = Fut where Self: 'a, Req: 'a;
fn before<'a>(&'a mut self, ctx: &'a mut context::Context, req: &'a Req) -> Self::Fut<'a> {
self(ctx, req)
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
self(ctx, req).await
}
}

View File

@@ -1,6 +1,3 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use tarpc::client;
#[tarpc::service]

View File

@@ -1,15 +1,15 @@
error: unused `RequestDispatch` that must be used
--> tests/compile_fail/must_use_request_dispatch.rs:16:9
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
|
16 | WorldClient::new(client::Config::default(), client_transport).dispatch;
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/must_use_request_dispatch.rs:14:12
--> tests/compile_fail/must_use_request_dispatch.rs:11:12
|
14 | #[deny(unused_must_use)]
11 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^
help: use `let _ = ...` to ignore the resulting value
|
16 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch;
13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch;
| +++++++

View File

@@ -1,6 +1,3 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use futures::prelude::*;
use tarpc::serde_transport;
use tarpc::{

View File

@@ -1,6 +1,3 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use assert_matches::assert_matches;
use futures::{
future::{join_all, ready},