mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d27f341bde | ||
|
|
2264ebecfc | ||
|
|
3207affb4a | ||
|
|
0602afd50c | ||
|
|
4343e12217 | ||
|
|
7fda862fb8 | ||
|
|
aa7b875b1a | ||
|
|
54d6e0e3b6 | ||
|
|
bea3b442aa | ||
|
|
954a2502e7 | ||
|
|
e3f34917c5 | ||
|
|
f65dd05949 | ||
|
|
240c436b34 | ||
|
|
c9803688cc | ||
|
|
4987094483 | ||
|
|
ff55080193 | ||
|
|
258193c932 | ||
|
|
67823ef5de | ||
|
|
a671457243 | ||
|
|
cf654549da | ||
|
|
6a01e32a2d | ||
|
|
e6597fab03 | ||
|
|
ebd245a93d | ||
|
|
3ebc3b5845 | ||
|
|
0e5973109d | ||
|
|
5f02d7383a | ||
|
|
2bae148529 | ||
|
|
42a2e03aab | ||
|
|
b566d0c646 | ||
|
|
b359f16767 | ||
|
|
f8681ab134 | ||
|
|
7e521768ab | ||
|
|
e9b1e7d101 | ||
|
|
f0322fb892 | ||
|
|
617daebb88 | ||
|
|
a11d4fff58 | ||
|
|
bf42a04d83 | ||
|
|
06528d6953 | ||
|
|
9f00395746 | ||
|
|
e0674cd57f | ||
|
|
7e49bd9ee7 | ||
|
|
8a1baa9c4e | ||
|
|
31c713d188 |
40
README.md
40
README.md
@@ -1,9 +1,21 @@
|
||||
[](https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22)
|
||||
[](https://crates.io/crates/tarpc)
|
||||
[](https://discordapp.com/channels/647529123996237854)
|
||||
[![Crates.io][crates-badge]][crates-url]
|
||||
[![MIT licensed][mit-badge]][mit-url]
|
||||
[![Build status][gh-actions-badge]][gh-actions-url]
|
||||
[![Discord chat][discord-badge]][discord-url]
|
||||
|
||||
[crates-badge]: https://img.shields.io/crates/v/tarpc.svg
|
||||
[crates-url]: https://crates.io/crates/tarpc
|
||||
[mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg
|
||||
[mit-url]: LICENSE
|
||||
[gh-actions-badge]: https://github.com/google/tarpc/workflows/Continuous%20integration/badge.svg
|
||||
[gh-actions-url]: https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22
|
||||
[discord-badge]: https://img.shields.io/discord/647529123996237854.svg?logo=discord&style=flat-square
|
||||
[discord-url]: https://discord.gg/gXwpdSt
|
||||
|
||||
# tarpc
|
||||
|
||||
<!-- cargo-sync-readme start -->
|
||||
|
||||
*Disclaimer*: This is not an official Google product.
|
||||
|
||||
tarpc is an RPC framework for rust with a focus on ease of use. Defining a
|
||||
@@ -12,7 +24,7 @@ writing a server is taken care of for you.
|
||||
|
||||
[Documentation](https://docs.rs/crate/tarpc/)
|
||||
|
||||
### What is an RPC framework?
|
||||
## What is an RPC framework?
|
||||
"RPC" stands for "Remote Procedure Call," a function call where the work of
|
||||
producing the return value is being done somewhere else. When an rpc function is
|
||||
invoked, behind the scenes the function contacts some other process somewhere
|
||||
@@ -30,7 +42,7 @@ process, and no context switching between different languages.
|
||||
Some other features of tarpc:
|
||||
- Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
|
||||
used as a transport to connect the client and server.
|
||||
- `Send` optional: if the transport doesn't require it, neither does tarpc!
|
||||
- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
|
||||
- Cascading cancellation: dropping a request will send a cancellation message to the server.
|
||||
The server will cease any unfinished work on the request, subsequently cancelling any of its
|
||||
own requests, repeating for the entire chain of transitive dependencies.
|
||||
@@ -43,18 +55,18 @@ Some other features of tarpc:
|
||||
responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
|
||||
be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
|
||||
|
||||
### Usage
|
||||
## Usage
|
||||
Add to your `Cargo.toml` dependencies:
|
||||
|
||||
```toml
|
||||
tarpc = { version = "0.18.0", features = ["full"] }
|
||||
tarpc = "0.22.0"
|
||||
```
|
||||
|
||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
These generated types make it easy and ergonomic to write servers with less boilerplate.
|
||||
Simply implement the generated service trait, and you're off to the races!
|
||||
|
||||
### Example
|
||||
## Example
|
||||
|
||||
For this example, in addition to tarpc, also add two other dependencies to
|
||||
your `Cargo.toml`:
|
||||
@@ -71,6 +83,7 @@ For a more real-world example, see [example-service](example-service).
|
||||
First, let's set up the dependencies and service definition.
|
||||
|
||||
```rust
|
||||
|
||||
use futures::{
|
||||
future::{self, Ready},
|
||||
prelude::*,
|
||||
@@ -112,10 +125,9 @@ impl World for HelloServer {
|
||||
```
|
||||
|
||||
Lastly let's write our `main` that will start the server. While this example uses an
|
||||
[in-process
|
||||
channel](https://docs.rs/tarpc/0.18.0/tarpc/transport/channel/struct.UnboundedChannel.html),
|
||||
tarpc also ships bincode and JSON
|
||||
tokio-net based TCP transports that are generic over all serializable types.
|
||||
[in-process channel](rpc::transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||
behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
|
||||
available behind the `tcp` feature.
|
||||
|
||||
```rust
|
||||
#[tokio::main]
|
||||
@@ -145,9 +157,11 @@ async fn main() -> io::Result<()> {
|
||||
}
|
||||
```
|
||||
|
||||
### Service Documentation
|
||||
## Service Documentation
|
||||
|
||||
Use `cargo doc` as you normally would to see the documentation created for all
|
||||
items expanded by a `service!` invocation.
|
||||
|
||||
<!-- cargo-sync-readme end -->
|
||||
|
||||
License: MIT
|
||||
|
||||
100
RELEASES.md
100
RELEASES.md
@@ -1,3 +1,101 @@
|
||||
## 0.22.0 (2020-08-02)
|
||||
|
||||
This release adds some flexibility and consistency to `serde_transport`, with one new feature and
|
||||
one small breaking change.
|
||||
|
||||
### New Features
|
||||
|
||||
`serde_transport::tcp` now exposes framing configuration on `connect()` and `listen()`. This is
|
||||
useful if, for instance, you want to send requests or responses that are larger than the maximum
|
||||
payload allowed by default:
|
||||
|
||||
```rust
|
||||
let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
|
||||
transport.config_mut().max_frame_length(4294967296);
|
||||
let mut client = MyClient::new(client::Config::default(), transport.await?).spawn()?;
|
||||
```
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
The codec argument to `serde_transport::tcp::connect` changed from a Codec to impl Fn() -> Codec,
|
||||
to be consistent with `serde_transport::tcp::listen`. While only one Codec is needed, more than one
|
||||
person has been tripped up by the inconsistency between `connect` and `listen`. Unfortunately, the
|
||||
compiler errors are not much help in this case, so it was decided to simply do the more intuitive
|
||||
thing so that the compiler doesn't need to step in in the first place.
|
||||
|
||||
|
||||
## 0.21.1 (2020-08-02)
|
||||
|
||||
### New Features
|
||||
|
||||
#### #[tarpc::server] diagnostics
|
||||
|
||||
When a service impl uses #[tarpc::server], only `async fn`s are re-written. This can lead to
|
||||
confusing compiler errors about missing associated types:
|
||||
|
||||
```
|
||||
error: not all trait items implemented, missing: `HelloFut`
|
||||
--> $DIR/tarpc_server_missing_async.rs:9:1
|
||||
|
|
||||
9 | impl World for HelloServer {
|
||||
| ^^^^
|
||||
```
|
||||
|
||||
The proc macro now provides better diagnostics for this case:
|
||||
|
||||
```
|
||||
error: not all trait items implemented, missing: `HelloFut`
|
||||
--> $DIR/tarpc_server_missing_async.rs:9:1
|
||||
|
|
||||
9 | impl World for HelloServer {
|
||||
| ^^^^
|
||||
|
||||
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
|
||||
--> $DIR/tarpc_server_missing_async.rs:10:5
|
||||
|
|
||||
10 | fn hello(name: String) -> String {
|
||||
| ^^
|
||||
```
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
#### Fixed client hanging when server shuts down
|
||||
|
||||
Previously, clients would ignore when the read half of the transport was closed, continuing to
|
||||
write requests. This didn't make much sense, because without the ability to receive responses,
|
||||
clients have no way to know if requests were actually processed by the server. It basically just
|
||||
led to clients that would hang for a few seconds before shutting down. This has now been
|
||||
corrected: clients will immediately shut down when the read-half of the transport is closed.
|
||||
|
||||
#### More docs.rs documentation
|
||||
|
||||
Previously, docs.rs only documented items enabled by default, notably leaving out documentation
|
||||
for tokio and serde features. This has now been corrected: docs.rs should have documentation
|
||||
for all optional features.
|
||||
|
||||
## 0.21.0 (2020-06-26)
|
||||
|
||||
### New Features
|
||||
|
||||
A new proc macro, `#[tarpc::server]` was added! This enables service impls to elide the boilerplate
|
||||
of specifying associated types for each RPC. With the ubiquity of async-await, most code won't have
|
||||
nameable futures and will just be boxing the return type anyway. This macro does that for you.
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- Enums had _non_exhaustive fields replaced with the #[non_exhaustive] attribute.
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- https://github.com/google/tarpc/issues/304
|
||||
|
||||
A race condition in code that limits number of connections per client caused occasional panics.
|
||||
|
||||
- https://github.com/google/tarpc/pull/295
|
||||
|
||||
Made request timeouts account for time spent in the outbound buffer. Previously, a large outbound
|
||||
queue would lead to requests not timing out correctly.
|
||||
|
||||
## 0.20.0 (2019-12-11)
|
||||
|
||||
### Breaking Changes
|
||||
@@ -10,7 +108,7 @@
|
||||
|
||||
## 0.13.0 (2018-10-16)
|
||||
|
||||
### Breaking Changes
|
||||
### Breaking Changes
|
||||
|
||||
Version 0.13 marks a significant departure from previous versions of tarpc. The
|
||||
API has changed significantly. The tokio-proto crate has been torn out and
|
||||
|
||||
@@ -16,9 +16,10 @@ description = "An example server built on tarpc."
|
||||
clap = "2.0"
|
||||
futures = "0.3"
|
||||
serde = { version = "1.0" }
|
||||
tarpc = { version = "0.20", path = "../tarpc", features = ["full"] }
|
||||
tarpc = { version = "0.22", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "0.2", features = ["full"] }
|
||||
tokio-serde = { version = "0.6", features = ["json"] }
|
||||
tokio-util = { version = "0.3", features = ["codec"] }
|
||||
env_logger = "0.6"
|
||||
|
||||
[lib]
|
||||
|
||||
@@ -11,6 +11,8 @@ use tokio_serde::formats::Json;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
env_logger::init();
|
||||
|
||||
let flags = App::new("Hello Client")
|
||||
.version("0.1")
|
||||
.author("Tim <tikue@google.com>")
|
||||
@@ -41,11 +43,13 @@ async fn main() -> io::Result<()> {
|
||||
|
||||
let name = flags.value_of("name").unwrap().into();
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default()).await?;
|
||||
let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
|
||||
transport.config_mut().max_frame_length(4294967296);
|
||||
|
||||
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
||||
// config and any Transport as input.
|
||||
let mut client = service::WorldClient::new(client::Config::default(), transport).spawn()?;
|
||||
let mut client =
|
||||
service::WorldClient::new(client::Config::default(), transport.await?).spawn()?;
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
|
||||
@@ -5,10 +5,7 @@
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use clap::{App, Arg};
|
||||
use futures::{
|
||||
future::{self, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use futures::{future, prelude::*};
|
||||
use service::World;
|
||||
use std::{
|
||||
io,
|
||||
@@ -25,17 +22,10 @@ use tokio_serde::formats::Json;
|
||||
#[derive(Clone)]
|
||||
struct HelloServer(SocketAddr);
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
||||
// an associated type representing the future output by the fn.
|
||||
|
||||
type HelloFut = Ready<String>;
|
||||
|
||||
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
||||
future::ready(format!(
|
||||
"Hello, {}! You are connected from {:?}.",
|
||||
name, self.0
|
||||
))
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,8 +57,9 @@ async fn main() -> io::Result<()> {
|
||||
|
||||
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
||||
// to start up a serde-powered json serialization strategy over TCP.
|
||||
tarpc::serde_transport::tcp::listen(&server_addr, Json::default)
|
||||
.await?
|
||||
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
|
||||
listener.config_mut().max_frame_length(4294967296);
|
||||
listener
|
||||
// Ignore accept errors.
|
||||
.filter_map(|r| future::ready(r.ok()))
|
||||
.map(server::BaseChannel::with_defaults)
|
||||
|
||||
@@ -93,7 +93,7 @@ diff=""
|
||||
for file in $(git diff --name-only --cached);
|
||||
do
|
||||
if [ ${file: -3} == ".rs" ]; then
|
||||
diff="$diff$(cargo fmt -- --skip-children --write-mode=diff $file)"
|
||||
diff="$diff$(cargo fmt -- --unstable-features --skip-children --check $file)"
|
||||
fi
|
||||
done
|
||||
if grep --quiet "^[-+]" <<< "$diff"; then
|
||||
|
||||
@@ -89,12 +89,12 @@ if [ "$?" == 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
try_run "Building ... " cargo build --color=always
|
||||
try_run "Testing ... " cargo test --color=always
|
||||
try_run "Testing with all features enabled ... " cargo test --all-features --color=always
|
||||
for EXAMPLE in $(cargo run --example 2>&1 | grep ' ' | awk '{print $1}')
|
||||
try_run "Building ... " cargo +stable build --color=always
|
||||
try_run "Testing ... " cargo +stable test --color=always
|
||||
try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
|
||||
for EXAMPLE in $(cargo +stable run --example 2>&1 | grep ' ' | awk '{print $1}')
|
||||
do
|
||||
try_run "Running example \"$EXAMPLE\" ... " cargo run --example $EXAMPLE
|
||||
try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
|
||||
done
|
||||
|
||||
fi
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-plugins"
|
||||
version = "0.7.0"
|
||||
version = "0.8.0"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
@@ -30,3 +30,4 @@ proc-macro = true
|
||||
futures = "0.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tarpc = { path = "../tarpc" }
|
||||
assert-type-eq = "0.1.0"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
144
plugins/tests/server.rs
Normal file
144
plugins/tests/server.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
use assert_type_eq::assert_type_eq;
|
||||
use futures::Future;
|
||||
use std::pin::Pin;
|
||||
use tarpc::context;
|
||||
|
||||
// these need to be out here rather than inside the function so that the
|
||||
// assert_type_eq macro can pick them up.
|
||||
#[tarpc::service]
|
||||
trait Foo {
|
||||
async fn two_part(s: String, i: i32) -> (String, i32);
|
||||
async fn bar(s: String) -> String;
|
||||
async fn baz();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn type_generation_works() {
|
||||
#[tarpc::server]
|
||||
impl Foo for () {
|
||||
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
|
||||
(s, i)
|
||||
}
|
||||
|
||||
async fn bar(self, _: context::Context, s: String) -> String {
|
||||
s
|
||||
}
|
||||
|
||||
async fn baz(self, _: context::Context) {}
|
||||
}
|
||||
|
||||
// the assert_type_eq macro can only be used once per block.
|
||||
{
|
||||
assert_type_eq!(
|
||||
<() as Foo>::TwoPartFut,
|
||||
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
|
||||
);
|
||||
}
|
||||
{
|
||||
assert_type_eq!(
|
||||
<() as Foo>::BarFut,
|
||||
Pin<Box<dyn Future<Output = String> + Send>>
|
||||
);
|
||||
}
|
||||
{
|
||||
assert_type_eq!(
|
||||
<() as Foo>::BazFut,
|
||||
Pin<Box<dyn Future<Output = ()> + Send>>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[test]
|
||||
fn raw_idents_work() {
|
||||
type r#yield = String;
|
||||
|
||||
#[tarpc::service]
|
||||
trait r#trait {
|
||||
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
|
||||
async fn r#fn(r#impl: r#yield) -> r#yield;
|
||||
async fn r#async();
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl r#trait for () {
|
||||
async fn r#await(
|
||||
self,
|
||||
_: context::Context,
|
||||
r#struct: r#yield,
|
||||
r#enum: i32,
|
||||
) -> (r#yield, i32) {
|
||||
(r#struct, r#enum)
|
||||
}
|
||||
|
||||
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
|
||||
r#impl
|
||||
}
|
||||
|
||||
async fn r#async(self, _: context::Context) {}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn syntax() {
|
||||
#[tarpc::service]
|
||||
trait Syntax {
|
||||
#[deny(warnings)]
|
||||
#[allow(non_snake_case)]
|
||||
async fn TestCamelCaseDoesntConflict();
|
||||
async fn hello() -> String;
|
||||
#[doc = "attr"]
|
||||
async fn attr(s: String) -> String;
|
||||
async fn no_args_no_return();
|
||||
async fn no_args() -> ();
|
||||
async fn one_arg(one: String) -> i32;
|
||||
async fn two_args_no_return(one: String, two: u64);
|
||||
async fn two_args(one: String, two: u64) -> String;
|
||||
async fn no_args_ret_error() -> i32;
|
||||
async fn one_arg_ret_error(one: String) -> String;
|
||||
async fn no_arg_implicit_return_error();
|
||||
#[doc = "attr"]
|
||||
async fn one_arg_implicit_return_error(one: String);
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl Syntax for () {
|
||||
#[deny(warnings)]
|
||||
#[allow(non_snake_case)]
|
||||
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
|
||||
|
||||
async fn hello(self, _: context::Context) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn attr(self, _: context::Context, _s: String) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn no_args_no_return(self, _: context::Context) {}
|
||||
|
||||
async fn no_args(self, _: context::Context) -> () {}
|
||||
|
||||
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
|
||||
|
||||
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn no_args_ret_error(self, _: context::Context) -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn no_arg_implicit_return_error(self, _: context::Context) {}
|
||||
|
||||
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,38 @@ fn att_service_trait() {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[test]
|
||||
fn raw_idents() {
|
||||
use futures::future::{ready, Ready};
|
||||
|
||||
type r#yield = String;
|
||||
|
||||
#[tarpc::service]
|
||||
trait r#trait {
|
||||
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
|
||||
async fn r#fn(r#impl: r#yield) -> r#yield;
|
||||
async fn r#async();
|
||||
}
|
||||
|
||||
impl r#trait for () {
|
||||
type AwaitFut = Ready<(r#yield, i32)>;
|
||||
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
|
||||
ready((r#struct, r#enum))
|
||||
}
|
||||
|
||||
type FnFut = Ready<r#yield>;
|
||||
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
|
||||
ready(r#impl)
|
||||
}
|
||||
|
||||
type AsyncFut = Ready<()>;
|
||||
fn r#async(self, _: context::Context) -> Self::AsyncFut {
|
||||
ready(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn syntax() {
|
||||
#[tarpc::service]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc"
|
||||
version = "0.20.0"
|
||||
version = "0.22.0"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
@@ -26,29 +26,38 @@ full = ["serde1", "tokio1", "serde-transport", "tcp"]
|
||||
travis-ci = { repository = "google/tarpc" }
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
fnv = "1.0"
|
||||
futures = "0.3"
|
||||
humantime = "1.0"
|
||||
log = "0.4"
|
||||
pin-project = "0.4"
|
||||
raii-counter = "0.2"
|
||||
pin-project = "0.4.17"
|
||||
rand = "0.7"
|
||||
tokio = { version = "0.2", features = ["time"] }
|
||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||
tokio-util = { optional = true, version = "0.2" }
|
||||
tarpc-plugins = { path = "../plugins", version = "0.7" }
|
||||
static_assertions = "1.1.0"
|
||||
tarpc-plugins = { path = "../plugins", version = "0.8" }
|
||||
tokio-util = { optional = true, version = "0.3" }
|
||||
tokio-serde = { optional = true, version = "0.6" }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_matches = "1.0"
|
||||
bincode = "1.3"
|
||||
bytes = { version = "0.5", features = ["serde"] }
|
||||
env_logger = "0.6"
|
||||
flate2 = "1.0.16"
|
||||
futures = "0.3"
|
||||
humantime = "1.0"
|
||||
log = "0.4"
|
||||
pin-utils = "0.1.0-alpha"
|
||||
serde_bytes = "0.11"
|
||||
tokio = { version = "0.2", features = ["full"] }
|
||||
tokio-serde = { version = "0.6", features = ["json"] }
|
||||
tokio-serde = { version = "0.6", features = ["json", "bincode"] }
|
||||
trybuild = "1.0"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
|
||||
[[example]]
|
||||
name = "server_calling_server"
|
||||
@@ -61,4 +70,3 @@ required-features = ["full"]
|
||||
[[example]]
|
||||
name = "pubsub"
|
||||
required-features = ["full"]
|
||||
|
||||
|
||||
130
tarpc/examples/compression.rs
Normal file
130
tarpc/examples/compression.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_bytes::ByteBuf;
|
||||
use std::{io, io::Read, io::Write};
|
||||
use tarpc::{
|
||||
client, context,
|
||||
serde_transport::tcp,
|
||||
server::{BaseChannel, Channel},
|
||||
};
|
||||
use tokio_serde::formats::Bincode;
|
||||
|
||||
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
|
||||
pub enum CompressionAlgorithm {
|
||||
Deflate,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub enum CompressedMessage<T> {
|
||||
Uncompressed(T),
|
||||
Compressed {
|
||||
algorithm: CompressionAlgorithm,
|
||||
payload: ByteBuf,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
enum CompressionType {
|
||||
Uncompressed,
|
||||
Compressed,
|
||||
}
|
||||
|
||||
async fn compress<T>(message: T) -> io::Result<CompressedMessage<T>>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let message = serialize(message)?;
|
||||
let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
|
||||
encoder.write_all(&message).unwrap();
|
||||
let compressed = encoder.finish()?;
|
||||
Ok(CompressedMessage::Compressed {
|
||||
algorithm: CompressionAlgorithm::Deflate,
|
||||
payload: ByteBuf::from(compressed),
|
||||
})
|
||||
}
|
||||
|
||||
async fn decompress<T>(message: CompressedMessage<T>) -> io::Result<T>
|
||||
where
|
||||
for<'a> T: Deserialize<'a>,
|
||||
{
|
||||
match message {
|
||||
CompressedMessage::Compressed { algorithm, payload } => {
|
||||
if algorithm != CompressionAlgorithm::Deflate {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Compression algorithm {:?} not supported", algorithm),
|
||||
));
|
||||
}
|
||||
let mut deflater = DeflateDecoder::new(payload.as_slice());
|
||||
let mut payload = ByteBuf::new();
|
||||
deflater.read_to_end(&mut payload)?;
|
||||
let message = deserialize(payload)?;
|
||||
Ok(message)
|
||||
}
|
||||
CompressedMessage::Uncompressed(message) => Ok(message),
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize<T: Serialize>(t: T) -> io::Result<ByteBuf> {
|
||||
bincode::serialize(&t)
|
||||
.map(ByteBuf::from)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
|
||||
fn deserialize<D>(message: ByteBuf) -> io::Result<D>
|
||||
where
|
||||
for<'a> D: Deserialize<'a>,
|
||||
{
|
||||
bincode::deserialize(message.as_ref()).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
|
||||
fn add_compression<In, Out>(
|
||||
transport: impl Stream<Item = io::Result<CompressedMessage<In>>>
|
||||
+ Sink<CompressedMessage<Out>, Error = io::Error>,
|
||||
) -> impl Stream<Item = io::Result<In>> + Sink<Out, Error = io::Error>
|
||||
where
|
||||
Out: Serialize,
|
||||
for<'a> In: Deserialize<'a>,
|
||||
{
|
||||
transport.with(compress).and_then(decompress)
|
||||
}
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait World {
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct HelloServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hey, {}!", name)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
|
||||
let addr = incoming.local_addr();
|
||||
tokio::spawn(async move {
|
||||
let transport = incoming.next().await.unwrap().unwrap();
|
||||
BaseChannel::with_defaults(add_compression(transport))
|
||||
.respond_with(HelloServer.serve())
|
||||
.execute()
|
||||
.await;
|
||||
});
|
||||
|
||||
let transport = tcp::connect(addr, Bincode::default).await?;
|
||||
let mut client =
|
||||
WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?;
|
||||
|
||||
println!(
|
||||
"{}",
|
||||
client.hello(context::current(), "friend".into()).await?
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
@@ -4,192 +4,341 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
/// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher"
|
||||
/// port. Because both publishers and subscribers initiate their connections to the PubSub
|
||||
/// server, the server requires no prior knowledge of either publishers or subscribers.
|
||||
///
|
||||
/// - Subscribers connect to the server on the server's "subscriber" port. Once a connection is
|
||||
/// established, the server acts as the client of the Subscriber service, initially requesting
|
||||
/// the topics the subscriber is interested in, and subsequently sending topical messages to the
|
||||
/// subscriber.
|
||||
///
|
||||
/// - Publishers connect to the server on the "publisher" port and, once connected, they send
|
||||
/// topical messages via Publisher service to the server. The server then broadcasts each
|
||||
/// messages to all clients subscribed to the topic of that message.
|
||||
///
|
||||
/// Subscriber Publisher PubSub Server
|
||||
/// T1 | | |
|
||||
/// T2 |-----Connect------------------------------------------------------>|
|
||||
/// T3 | | |
|
||||
/// T2 |<-------------------------------------------------------Topics-----|
|
||||
/// T2 |-----(OK) Topics-------------------------------------------------->|
|
||||
/// T3 | | |
|
||||
/// T4 | |-----Connect-------------------->|
|
||||
/// T5 | | |
|
||||
/// T6 | |-----Publish-------------------->|
|
||||
/// T7 | | |
|
||||
/// T8 |<------------------------------------------------------Receive-----|
|
||||
/// T9 |-----(OK) Receive------------------------------------------------->|
|
||||
/// T10 | | |
|
||||
/// T11 | |<--------------(OK) Publish------|
|
||||
use anyhow::anyhow;
|
||||
use futures::{
|
||||
future::{self, Ready},
|
||||
channel::oneshot,
|
||||
future::{self, AbortHandle},
|
||||
prelude::*,
|
||||
Future,
|
||||
};
|
||||
use log::info;
|
||||
use publisher::Publisher as _;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
sync::{Arc, Mutex},
|
||||
time::Duration,
|
||||
sync::{Arc, Mutex, RwLock},
|
||||
};
|
||||
use subscriber::Subscriber as _;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Handler},
|
||||
serde_transport::tcp,
|
||||
server::{self, Channel},
|
||||
};
|
||||
use tokio::net::ToSocketAddrs;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
pub mod subscriber {
|
||||
#[tarpc::service]
|
||||
pub trait Subscriber {
|
||||
async fn receive(message: String);
|
||||
async fn topics() -> Vec<String>;
|
||||
async fn receive(topic: String, message: String);
|
||||
}
|
||||
}
|
||||
|
||||
pub mod publisher {
|
||||
use std::net::SocketAddr;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait Publisher {
|
||||
async fn broadcast(message: String);
|
||||
async fn subscribe(id: u32, address: SocketAddr) -> Result<(), String>;
|
||||
async fn unsubscribe(id: u32);
|
||||
async fn publish(topic: String, message: String);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Subscriber {
|
||||
id: u32,
|
||||
local_addr: SocketAddr,
|
||||
topics: Vec<String>,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl subscriber::Subscriber for Subscriber {
|
||||
type ReceiveFut = Ready<()>;
|
||||
async fn topics(self, _: context::Context) -> Vec<String> {
|
||||
self.topics.clone()
|
||||
}
|
||||
|
||||
fn receive(self, _: context::Context, message: String) -> Self::ReceiveFut {
|
||||
eprintln!("{} received message: {}", self.id, message);
|
||||
future::ready(())
|
||||
async fn receive(self, _: context::Context, topic: String, message: String) {
|
||||
info!(
|
||||
"[{}] received message on topic '{}': {}",
|
||||
self.local_addr, topic, message
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
struct SubscriberHandle(AbortHandle);
|
||||
|
||||
impl Drop for SubscriberHandle {
|
||||
fn drop(&mut self) {
|
||||
self.0.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl Subscriber {
|
||||
async fn listen(id: u32, config: server::Config) -> io::Result<SocketAddr> {
|
||||
let incoming = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let addr = incoming.get_ref().local_addr();
|
||||
tokio::spawn(
|
||||
server::new(config)
|
||||
.incoming(incoming)
|
||||
.take(1)
|
||||
.respond_with(Subscriber { id }.serve()),
|
||||
);
|
||||
Ok(addr)
|
||||
async fn connect(
|
||||
publisher_addr: impl ToSocketAddrs,
|
||||
topics: Vec<String>,
|
||||
) -> anyhow::Result<SubscriberHandle> {
|
||||
let publisher = tcp::connect(publisher_addr, Json::default).await?;
|
||||
let local_addr = publisher.local_addr()?;
|
||||
let mut handler = server::BaseChannel::with_defaults(publisher)
|
||||
.respond_with(Subscriber { local_addr, topics }.serve());
|
||||
// The first request is for the topics being subscriibed to.
|
||||
match handler.next().await {
|
||||
Some(init_topics) => init_topics?.await,
|
||||
None => {
|
||||
return Err(anyhow!(
|
||||
"[{}] Server never initialized the subscriber.",
|
||||
local_addr
|
||||
))
|
||||
}
|
||||
};
|
||||
let (handler, abort_handle) = future::abortable(handler.execute());
|
||||
tokio::spawn(async move {
|
||||
match handler.await {
|
||||
Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr),
|
||||
}
|
||||
});
|
||||
Ok(SubscriberHandle(abort_handle))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Subscription {
|
||||
subscriber: subscriber::SubscriberClient,
|
||||
topics: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Publisher {
|
||||
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
|
||||
clients: Arc<Mutex<HashMap<SocketAddr, Subscription>>>,
|
||||
subscriptions: Arc<RwLock<HashMap<String, HashMap<SocketAddr, subscriber::SubscriberClient>>>>,
|
||||
}
|
||||
|
||||
struct PublisherAddrs {
|
||||
publisher: SocketAddr,
|
||||
subscriptions: SocketAddr,
|
||||
}
|
||||
|
||||
impl Publisher {
|
||||
fn new() -> Publisher {
|
||||
Publisher {
|
||||
clients: Arc::new(Mutex::new(HashMap::new())),
|
||||
async fn start(self) -> io::Result<PublisherAddrs> {
|
||||
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
|
||||
|
||||
let publisher_addrs = PublisherAddrs {
|
||||
publisher: connecting_publishers.local_addr(),
|
||||
subscriptions: self.clone().start_subscription_manager().await?,
|
||||
};
|
||||
|
||||
info!("[{}] listening for publishers.", publisher_addrs.publisher);
|
||||
tokio::spawn(async move {
|
||||
// Because this is just an example, we know there will only be one publisher. In more
|
||||
// realistic code, this would be a loop to continually accept new publisher
|
||||
// connections.
|
||||
let publisher = connecting_publishers.next().await.unwrap().unwrap();
|
||||
info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
|
||||
|
||||
server::BaseChannel::with_defaults(publisher)
|
||||
.respond_with(self.serve())
|
||||
.execute()
|
||||
.await
|
||||
});
|
||||
|
||||
Ok(publisher_addrs)
|
||||
}
|
||||
|
||||
async fn start_subscription_manager(mut self) -> io::Result<SocketAddr> {
|
||||
let mut connecting_subscribers = tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
|
||||
info!("[{}] listening for subscribers.", new_subscriber_addr);
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(conn) = connecting_subscribers.next().await {
|
||||
let subscriber_addr = conn.peer_addr().unwrap();
|
||||
|
||||
let tarpc::client::NewClient {
|
||||
client: subscriber,
|
||||
dispatch,
|
||||
} = subscriber::SubscriberClient::new(client::Config::default(), conn);
|
||||
let (ready_tx, ready) = oneshot::channel();
|
||||
self.clone()
|
||||
.start_subscriber_gc(subscriber_addr, dispatch, ready);
|
||||
|
||||
// Populate the topics
|
||||
self.initialize_subscription(subscriber_addr, subscriber)
|
||||
.await;
|
||||
|
||||
// Signal that initialization is done.
|
||||
ready_tx.send(()).unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
Ok(new_subscriber_addr)
|
||||
}
|
||||
|
||||
async fn initialize_subscription(
|
||||
&mut self,
|
||||
subscriber_addr: SocketAddr,
|
||||
mut subscriber: subscriber::SubscriberClient,
|
||||
) {
|
||||
// Populate the topics
|
||||
if let Ok(topics) = subscriber.topics(context::current()).await {
|
||||
self.clients.lock().unwrap().insert(
|
||||
subscriber_addr,
|
||||
Subscription {
|
||||
subscriber: subscriber.clone(),
|
||||
topics: topics.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
info!("[{}] subscribed to topics: {:?}", subscriber_addr, topics);
|
||||
let mut subscriptions = self.subscriptions.write().unwrap();
|
||||
for topic in topics {
|
||||
subscriptions
|
||||
.entry(topic)
|
||||
.or_insert_with(HashMap::new)
|
||||
.insert(subscriber_addr, subscriber.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn start_subscriber_gc(
|
||||
self,
|
||||
subscriber_addr: SocketAddr,
|
||||
client_dispatch: impl Future<Output = anyhow::Result<()>> + Send + 'static,
|
||||
subscriber_ready: oneshot::Receiver<()>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = client_dispatch.await {
|
||||
info!(
|
||||
"[{}] subscriber connection broken: {:?}",
|
||||
subscriber_addr, e
|
||||
)
|
||||
}
|
||||
// Don't clean up the subscriber until initialization is done.
|
||||
let _ = subscriber_ready.await;
|
||||
if let Some(subscription) = self.clients.lock().unwrap().remove(&subscriber_addr) {
|
||||
info!(
|
||||
"[{} unsubscribing from topics: {:?}",
|
||||
subscriber_addr, subscription.topics
|
||||
);
|
||||
let mut subscriptions = self.subscriptions.write().unwrap();
|
||||
for topic in subscription.topics {
|
||||
let subscribers = subscriptions.get_mut(&topic).unwrap();
|
||||
subscribers.remove(&subscriber_addr);
|
||||
if subscribers.is_empty() {
|
||||
subscriptions.remove(&topic);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl publisher::Publisher for Publisher {
|
||||
type BroadcastFut = Pin<Box<dyn Future<Output = ()> + Send>>;
|
||||
|
||||
fn broadcast(self, _: context::Context, message: String) -> Self::BroadcastFut {
|
||||
async fn broadcast(
|
||||
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
|
||||
message: String,
|
||||
) {
|
||||
let mut clients = clients.lock().unwrap().clone();
|
||||
for client in clients.values_mut() {
|
||||
// Ignore failing subscribers. In a real pubsub,
|
||||
// you'd want to continually retry until subscribers
|
||||
// ack.
|
||||
let _ = client.receive(context::current(), message.clone()).await;
|
||||
async fn publish(self, _: context::Context, topic: String, message: String) {
|
||||
info!("received message to publish.");
|
||||
let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) {
|
||||
None => return,
|
||||
Some(subscriptions) => subscriptions.clone(),
|
||||
};
|
||||
let mut publications = Vec::new();
|
||||
for client in subscribers.values_mut() {
|
||||
publications.push(client.receive(context::current(), topic.clone(), message.clone()));
|
||||
}
|
||||
// Ignore failing subscribers. In a real pubsub, you'd want to continually retry until
|
||||
// subscribers ack. Of course, a lot would be different in a real pubsub :)
|
||||
for response in future::join_all(publications).await {
|
||||
if let Err(e) = response {
|
||||
info!("failed to broadcast to subscriber: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
broadcast(self.clients.clone(), message).boxed()
|
||||
}
|
||||
|
||||
type SubscribeFut = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
|
||||
|
||||
fn subscribe(self, _: context::Context, id: u32, addr: SocketAddr) -> Self::SubscribeFut {
|
||||
async fn subscribe(
|
||||
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
|
||||
id: u32,
|
||||
addr: SocketAddr,
|
||||
) -> io::Result<()> {
|
||||
let conn = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
|
||||
let subscriber =
|
||||
subscriber::SubscriberClient::new(client::Config::default(), conn).spawn()?;
|
||||
eprintln!("Subscribing {}.", id);
|
||||
clients.lock().unwrap().insert(id, subscriber);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
subscribe(Arc::clone(&self.clients), id, addr)
|
||||
.map_err(|e| e.to_string())
|
||||
.boxed()
|
||||
}
|
||||
|
||||
type UnsubscribeFut = Pin<Box<dyn Future<Output = ()> + Send>>;
|
||||
|
||||
fn unsubscribe(self, _: context::Context, id: u32) -> Self::UnsubscribeFut {
|
||||
eprintln!("Unsubscribing {}", id);
|
||||
let mut clients = self.clients.lock().unwrap();
|
||||
if clients.remove(&id).is_none() {
|
||||
eprintln!(
|
||||
"Client {} not found. Existings clients: {:?}",
|
||||
id, &*clients
|
||||
);
|
||||
}
|
||||
future::ready(()).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
env_logger::init();
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let publisher_addr = transport.get_ref().local_addr();
|
||||
tokio::spawn(
|
||||
transport
|
||||
.take(1)
|
||||
.map(server::BaseChannel::with_defaults)
|
||||
.respond_with(Publisher::new().serve()),
|
||||
);
|
||||
|
||||
let subscriber1 = Subscriber::listen(0, server::Config::default()).await?;
|
||||
let subscriber2 = Subscriber::listen(1, server::Config::default()).await?;
|
||||
|
||||
let publisher_conn = tarpc::serde_transport::tcp::connect(publisher_addr, Json::default());
|
||||
let publisher_conn = publisher_conn.await?;
|
||||
let mut publisher =
|
||||
publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?;
|
||||
|
||||
if let Err(e) = publisher
|
||||
.subscribe(context::current(), 0, subscriber1)
|
||||
.await?
|
||||
{
|
||||
eprintln!("Couldn't subscribe subscriber 0: {}", e);
|
||||
}
|
||||
if let Err(e) = publisher
|
||||
.subscribe(context::current(), 1, subscriber2)
|
||||
.await?
|
||||
{
|
||||
eprintln!("Couldn't subscribe subscriber 1: {}", e);
|
||||
let clients = Arc::new(Mutex::new(HashMap::new()));
|
||||
let addrs = Publisher {
|
||||
clients,
|
||||
subscriptions: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
.start()
|
||||
.await?;
|
||||
|
||||
println!("Broadcasting...");
|
||||
publisher
|
||||
.broadcast(context::current(), "hello to all".to_string())
|
||||
.await?;
|
||||
publisher.unsubscribe(context::current(), 1).await?;
|
||||
publisher
|
||||
.broadcast(context::current(), "hi again".to_string())
|
||||
.await?;
|
||||
drop(publisher);
|
||||
let _subscriber0 = Subscriber::connect(
|
||||
addrs.subscriptions,
|
||||
vec!["calculus".into(), "cool shorts".into()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
tokio::time::delay_for(Duration::from_millis(100)).await;
|
||||
println!("Done.");
|
||||
let _subscriber1 = Subscriber::connect(
|
||||
addrs.subscriptions,
|
||||
vec!["cool shorts".into(), "history".into()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut publisher = publisher::PublisherClient::new(
|
||||
client::Config::default(),
|
||||
tcp::connect(addrs.publisher, Json::default).await?,
|
||||
)
|
||||
.spawn()?;
|
||||
|
||||
publisher
|
||||
.publish(context::current(), "calculus".into(), "sqrt(2)".into())
|
||||
.await?;
|
||||
|
||||
publisher
|
||||
.publish(
|
||||
context::current(),
|
||||
"cool shorts".into(),
|
||||
"hello to all".into(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
publisher
|
||||
.publish(context::current(), "history".into(), "napoleon".to_string())
|
||||
.await?;
|
||||
|
||||
drop(_subscriber0);
|
||||
|
||||
publisher
|
||||
.publish(
|
||||
context::current(),
|
||||
"cool shorts".into(),
|
||||
"hello to who?".into(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!("done.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ async fn main() -> io::Result<()> {
|
||||
};
|
||||
tokio::spawn(server);
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
|
||||
let transport = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
|
||||
// WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that
|
||||
// takes a config and any Transport as input.
|
||||
|
||||
@@ -5,11 +5,8 @@
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::{add::Add as AddService, double::Double as DoubleService};
|
||||
use futures::{
|
||||
future::{self, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use std::{io, pin::Pin};
|
||||
use futures::{future, prelude::*};
|
||||
use std::io;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{Handler, Server},
|
||||
@@ -35,11 +32,10 @@ pub mod double {
|
||||
#[derive(Clone)]
|
||||
struct AddServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl AddService for AddServer {
|
||||
type AddFut = Ready<i32>;
|
||||
|
||||
fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
|
||||
future::ready(x + y)
|
||||
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
|
||||
x + y
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,18 +44,13 @@ struct DoubleServer {
|
||||
add_client: add::AddClient,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl DoubleService for DoubleServer {
|
||||
type DoubleFut = Pin<Box<dyn Future<Output = Result<i32, String>> + Send>>;
|
||||
|
||||
fn double(self, _: context::Context, x: i32) -> Self::DoubleFut {
|
||||
async fn double(mut client: add::AddClient, x: i32) -> Result<i32, String> {
|
||||
client
|
||||
.add(context::current(), x, x)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
double(self.add_client.clone(), x).boxed()
|
||||
async fn double(mut self, _: context::Context, x: i32) -> Result<i32, String> {
|
||||
self.add_client
|
||||
.add(context::current(), x, x)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +68,7 @@ async fn main() -> io::Result<()> {
|
||||
.respond_with(AddServer.serve());
|
||||
tokio::spawn(add_server);
|
||||
|
||||
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
|
||||
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?;
|
||||
|
||||
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
@@ -90,7 +81,7 @@ async fn main() -> io::Result<()> {
|
||||
.respond_with(DoubleServer { add_client }.serve());
|
||||
tokio::spawn(double_server);
|
||||
|
||||
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
|
||||
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let mut double_client =
|
||||
double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;
|
||||
|
||||
|
||||
@@ -4,9 +4,6 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! [](https://crates.io/crates/tarpc)
|
||||
//! [](https://gitter.im/tarpc/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
//!
|
||||
//! *Disclaimer*: This is not an official Google product.
|
||||
//!
|
||||
//! tarpc is an RPC framework for rust with a focus on ease of use. Defining a
|
||||
@@ -50,7 +47,7 @@
|
||||
//! Add to your `Cargo.toml` dependencies:
|
||||
//!
|
||||
//! ```toml
|
||||
//! tarpc = "0.20.0"
|
||||
//! tarpc = "0.22.0"
|
||||
//! ```
|
||||
//!
|
||||
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -201,11 +198,13 @@
|
||||
//! items expanded by a `service!` invocation.
|
||||
#![deny(missing_docs)]
|
||||
#![allow(clippy::type_complexity)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
pub mod rpc;
|
||||
pub use rpc::*;
|
||||
|
||||
#[cfg(feature = "serde-transport")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "serde-transport")))]
|
||||
pub mod serde_transport;
|
||||
|
||||
pub mod trace;
|
||||
@@ -215,7 +214,6 @@ pub mod trace;
|
||||
/// Rpc methods are specified, mirroring trait syntax:
|
||||
///
|
||||
/// ```
|
||||
/// # fn main() {}
|
||||
/// #[tarpc::service]
|
||||
/// trait Service {
|
||||
/// /// Say hello
|
||||
@@ -234,3 +232,59 @@ pub mod trace;
|
||||
/// * `Client` -- a client stub with a fn for each RPC.
|
||||
/// * `fn new_stub` -- creates a new Client stub.
|
||||
pub use tarpc_plugins::service;
|
||||
|
||||
/// A utility macro that can be used for RPC server implementations.
|
||||
///
|
||||
/// Syntactic sugar to make using async functions in the server implementation
|
||||
/// easier. It does this by rewriting code like this, which would normally not
|
||||
/// compile because async functions are disallowed in trait implementations:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use tarpc::context;
|
||||
/// # use std::net::SocketAddr;
|
||||
/// #[tarpc::service]
|
||||
/// trait World {
|
||||
/// async fn hello(name: String) -> String;
|
||||
/// }
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct HelloServer(SocketAddr);
|
||||
///
|
||||
/// #[tarpc::server]
|
||||
/// impl World for HelloServer {
|
||||
/// async fn hello(self, _: context::Context, name: String) -> String {
|
||||
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Into code like this, which matches the service trait definition:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use tarpc::context;
|
||||
/// # use std::pin::Pin;
|
||||
/// # use futures::Future;
|
||||
/// # use std::net::SocketAddr;
|
||||
/// #[derive(Clone)]
|
||||
/// struct HelloServer(SocketAddr);
|
||||
///
|
||||
/// #[tarpc::service]
|
||||
/// trait World {
|
||||
/// async fn hello(name: String) -> String;
|
||||
/// }
|
||||
///
|
||||
/// impl World for HelloServer {
|
||||
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
|
||||
///
|
||||
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
|
||||
/// + Send>> {
|
||||
/// Box::pin(async move {
|
||||
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
/// })
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Note that this won't touch functions unless they have been annotated with
|
||||
/// `async`, meaning that this should not break existing code.
|
||||
pub use tarpc_plugins::server;
|
||||
|
||||
@@ -32,12 +32,14 @@ pub(crate) mod util;
|
||||
|
||||
pub use crate::{client::Client, server::Server, trace, transport::sealed::Transport};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::task::*;
|
||||
use std::{io, time::SystemTime};
|
||||
use std::{fmt::Display, io, time::SystemTime};
|
||||
|
||||
/// A message from a client to a server.
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[non_exhaustive]
|
||||
pub enum ClientMessage<T> {
|
||||
/// A request initiated by a user. The server responds to a request by invoking a
|
||||
/// service-provided request handler. The handler completes with a [`response`](Response), which
|
||||
@@ -58,12 +60,11 @@ pub enum ClientMessage<T> {
|
||||
/// The ID of the request to cancel.
|
||||
request_id: u64,
|
||||
},
|
||||
#[doc(hidden)]
|
||||
_NonExhaustive,
|
||||
}
|
||||
|
||||
/// A request from a client to a server.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Request<T> {
|
||||
/// Trace context, deadline, and other cross-cutting concerns.
|
||||
@@ -72,26 +73,22 @@ pub struct Request<T> {
|
||||
pub id: u64,
|
||||
/// The request body.
|
||||
pub message: T,
|
||||
#[doc(hidden)]
|
||||
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
|
||||
/// A response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Response<T> {
|
||||
/// The ID of the request being responded to.
|
||||
pub request_id: u64,
|
||||
/// The response body, or an error if the request failed.
|
||||
pub message: Result<T, ServerError>,
|
||||
#[doc(hidden)]
|
||||
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
|
||||
/// An error response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ServerError {
|
||||
#[cfg_attr(
|
||||
@@ -106,9 +103,6 @@ pub struct ServerError {
|
||||
pub kind: io::ErrorKind,
|
||||
/// A message describing more detail about the error that occurred.
|
||||
pub detail: Option<String>,
|
||||
#[doc(hidden)]
|
||||
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
|
||||
impl From<ServerError> for io::Error {
|
||||
@@ -125,3 +119,30 @@ impl<T> Request<T> {
|
||||
}
|
||||
|
||||
pub(crate) type PollIo<T> = Poll<Option<io::Result<T>>>;
|
||||
pub(crate) trait PollContext<T> {
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static;
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C;
|
||||
}
|
||||
|
||||
impl<T> PollContext<T> for PollIo<T> {
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
{
|
||||
self.map(|o| o.map(|r| r.context(context)))
|
||||
}
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C,
|
||||
{
|
||||
self.map(|o| o.map(|r| r.with_context(f)))
|
||||
}
|
||||
}
|
||||
@@ -104,6 +104,7 @@ where
|
||||
|
||||
/// Settings that control the behavior of the client.
|
||||
#[derive(Clone, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct Config {
|
||||
/// The number of requests that can be in flight at once.
|
||||
/// `max_in_flight_requests` controls the size of the map used by the client
|
||||
@@ -113,8 +114,6 @@ pub struct Config {
|
||||
/// `pending_requests_buffer` controls the size of the channel clients use
|
||||
/// to communicate with the request dispatch task.
|
||||
pub pending_request_buffer: usize,
|
||||
#[doc(hidden)]
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
@@ -122,7 +121,6 @@ impl Default for Config {
|
||||
Config {
|
||||
max_in_flight_requests: 1_000,
|
||||
pending_request_buffer: 100,
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -137,12 +135,14 @@ pub struct NewClient<C, D> {
|
||||
pub dispatch: D,
|
||||
}
|
||||
|
||||
impl<C, D> NewClient<C, D>
|
||||
impl<C, D, E> NewClient<C, D>
|
||||
where
|
||||
D: Future<Output = io::Result<()>> + Send + 'static,
|
||||
D: Future<Output = Result<(), E>> + Send + 'static,
|
||||
E: std::fmt::Display,
|
||||
{
|
||||
/// Helper method to spawn the dispatch on the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn spawn(self) -> io::Result<C> {
|
||||
use log::error;
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::{
|
||||
context,
|
||||
trace::SpanId,
|
||||
util::{Compact, TimeUntil},
|
||||
ClientMessage, PollIo, Request, Response, Transport,
|
||||
ClientMessage, PollContext, PollIo, Request, Response, Transport,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
@@ -78,14 +78,21 @@ impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Call<'a, Req, Resp> {
|
||||
#[pin]
|
||||
fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>,
|
||||
fut: tokio::time::Timeout<AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>>,
|
||||
}
|
||||
|
||||
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
|
||||
type Output = io::Result<Resp>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.as_mut().project().fut.poll(cx)
|
||||
let resp = ready!(self.as_mut().project().fut.poll(cx));
|
||||
Poll::Ready(match resp {
|
||||
Ok(resp) => resp,
|
||||
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"Client dropped expired request.".to_string(),
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,13 +104,6 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
|
||||
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
|
||||
|
||||
let timeout = ctx.deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Queuing request with timeout {:?}.",
|
||||
ctx.trace_id(),
|
||||
timeout,
|
||||
);
|
||||
|
||||
let (response_completion, response) = oneshot::channel();
|
||||
let cancellation = self.cancellation.clone();
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
@@ -116,7 +116,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
response_completion,
|
||||
})),
|
||||
DispatchResponse {
|
||||
response: tokio::time::timeout(timeout, response),
|
||||
response,
|
||||
complete: false,
|
||||
request_id,
|
||||
cancellation,
|
||||
@@ -128,9 +128,16 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
|
||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
||||
/// resolves to the response.
|
||||
pub fn call(&mut self, context: context::Context, request: Req) -> Call<Req, Resp> {
|
||||
pub fn call(&mut self, ctx: context::Context, request: Req) -> Call<Req, Resp> {
|
||||
let timeout = ctx.deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Queuing request with timeout {:?}.",
|
||||
ctx.trace_id(),
|
||||
timeout,
|
||||
);
|
||||
|
||||
Call {
|
||||
fut: AndThenIdent::new(self.send(context, request)),
|
||||
fut: tokio::time::timeout(timeout, AndThenIdent::new(self.send(ctx, request))),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -140,7 +147,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
#[pin_project(PinnedDrop)]
|
||||
#[derive(Debug)]
|
||||
struct DispatchResponse<Resp> {
|
||||
response: tokio::time::Timeout<oneshot::Receiver<Response<Resp>>>,
|
||||
response: oneshot::Receiver<Response<Resp>>,
|
||||
ctx: context::Context,
|
||||
complete: bool,
|
||||
cancellation: RequestCancellation,
|
||||
@@ -152,24 +159,15 @@ impl<Resp> Future for DispatchResponse<Resp> {
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
|
||||
let resp = ready!(self.response.poll_unpin(cx));
|
||||
|
||||
self.complete = true;
|
||||
Poll::Ready(match resp {
|
||||
Ok(resp) => {
|
||||
self.complete = true;
|
||||
match resp {
|
||||
Ok(resp) => Ok(resp.message?),
|
||||
Err(oneshot::Canceled) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
||||
}
|
||||
}
|
||||
Ok(resp) => Ok(resp.message?),
|
||||
Err(oneshot::Canceled) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
||||
}
|
||||
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"Client dropped expired request.".to_string(),
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -189,7 +187,7 @@ impl<Resp> PinnedDrop for DispatchResponse<Resp> {
|
||||
// closing the receiver before sending the cancel message, it is guaranteed that if the
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.get_mut().close();
|
||||
self.response.close();
|
||||
let request_id = self.request_id;
|
||||
self.cancellation.cancel(request_id);
|
||||
}
|
||||
@@ -385,9 +383,7 @@ where
|
||||
context: context::Context {
|
||||
deadline: dispatch_request.ctx.deadline,
|
||||
trace_context: dispatch_request.ctx.trace_context,
|
||||
_non_exhaustive: (),
|
||||
},
|
||||
_non_exhaustive: (),
|
||||
});
|
||||
self.as_mut().project().transport.start_send(request)?;
|
||||
self.as_mut().project().in_flight_requests.insert(
|
||||
@@ -444,11 +440,22 @@ impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
type Output = io::Result<()>;
|
||||
type Output = anyhow::Result<()>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
|
||||
loop {
|
||||
match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
|
||||
match (
|
||||
self.as_mut()
|
||||
.pump_read(cx)
|
||||
.context("failed to read from transport")?,
|
||||
self.as_mut()
|
||||
.pump_write(cx)
|
||||
.context("failed to write to transport")?,
|
||||
) {
|
||||
(Poll::Ready(None), _) => {
|
||||
info!("Shutdown: read half closed, so shutting down.");
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
(read, Poll::Ready(None)) => {
|
||||
if self.as_mut().project().in_flight_requests.is_empty() {
|
||||
info!("Shutdown: write half closed, and no requests in flight.");
|
||||
@@ -632,11 +639,12 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project(project = TryChainProj)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
#[derive(Debug)]
|
||||
enum TryChain<Fut1, Fut2> {
|
||||
First(Fut1),
|
||||
Second(Fut2),
|
||||
First(#[pin] Fut1),
|
||||
Second(#[pin] Fut2),
|
||||
Empty,
|
||||
}
|
||||
|
||||
@@ -658,7 +666,7 @@ where
|
||||
}
|
||||
|
||||
fn poll<F>(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
f: F,
|
||||
) -> Poll<Result<Fut2::Ok, Fut2::Error>>
|
||||
@@ -667,31 +675,28 @@ where
|
||||
{
|
||||
let mut f = Some(f);
|
||||
|
||||
// Safe to call `get_unchecked_mut` because we won't move the futures.
|
||||
let this = unsafe { Pin::get_unchecked_mut(self) };
|
||||
|
||||
loop {
|
||||
let output = match this {
|
||||
TryChain::First(fut1) => {
|
||||
let output = match self.as_mut().project() {
|
||||
TryChainProj::First(fut1) => {
|
||||
// Poll the first future
|
||||
match unsafe { Pin::new_unchecked(fut1) }.try_poll(cx) {
|
||||
match fut1.try_poll(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(output) => output,
|
||||
}
|
||||
}
|
||||
TryChain::Second(fut2) => {
|
||||
TryChainProj::Second(fut2) => {
|
||||
// Poll the second future
|
||||
return unsafe { Pin::new_unchecked(fut2) }.try_poll(cx);
|
||||
return fut2.try_poll(cx);
|
||||
}
|
||||
TryChain::Empty => {
|
||||
TryChainProj::Empty => {
|
||||
panic!("future must not be polled after it returned `Poll::Ready`");
|
||||
}
|
||||
};
|
||||
|
||||
*this = TryChain::Empty; // Drop fut1
|
||||
self.set(TryChain::Empty); // Drop fut1
|
||||
let f = f.take().unwrap();
|
||||
match f(output) {
|
||||
TryChainAction::Future(fut2) => *this = TryChain::Second(fut2),
|
||||
TryChainAction::Future(fut2) => self.set(TryChain::Second(fut2)),
|
||||
TryChainAction::Output(output) => return Poll::Ready(output),
|
||||
}
|
||||
}
|
||||
@@ -716,24 +721,21 @@ mod tests {
|
||||
prelude::*,
|
||||
task::*,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn dispatch_response_cancels_on_timeout() {
|
||||
let (_response_completion, response) = oneshot::channel();
|
||||
async fn dispatch_response_cancels_on_drop() {
|
||||
let (cancellation, mut canceled_requests) = cancellations();
|
||||
let resp = DispatchResponse::<u64> {
|
||||
// Timeout in the past should cause resp to error out when polled.
|
||||
response: tokio::time::timeout(Duration::from_secs(0), response),
|
||||
let (_, response) = oneshot::channel();
|
||||
drop(DispatchResponse::<u32> {
|
||||
response,
|
||||
cancellation,
|
||||
complete: false,
|
||||
request_id: 3,
|
||||
cancellation,
|
||||
ctx: context::current(),
|
||||
};
|
||||
let _ = futures::poll!(resp);
|
||||
});
|
||||
// resp's drop() is run, which should send a cancel message.
|
||||
assert!(canceled_requests.0.try_next().unwrap() == Some(3));
|
||||
assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3));
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
@@ -768,7 +770,6 @@ mod tests {
|
||||
Response {
|
||||
request_id: 0,
|
||||
message: Ok("hello".into()),
|
||||
_non_exhaustive: (),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -822,7 +823,7 @@ mod tests {
|
||||
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
|
||||
// map.
|
||||
let mut resp = send_request(&mut channel, "hi").await;
|
||||
resp.response.get_mut().close();
|
||||
resp.response.close();
|
||||
|
||||
assert!(dispatch.poll_next_request(cx).is_pending());
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
//! client to server and is used by the server to enforce response deadlines.
|
||||
|
||||
use crate::trace::{self, TraceId};
|
||||
use static_assertions::assert_impl_all;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
/// A request context that carries request-scoped information like deadlines and trace information.
|
||||
@@ -16,6 +17,7 @@ use std::time::{Duration, SystemTime};
|
||||
/// The context should not be stored directly in a server implementation, because the context will
|
||||
/// be different for each request in scope.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Context {
|
||||
/// When the client expects the request to be complete by. The server should cancel the request
|
||||
@@ -35,11 +37,10 @@ pub struct Context {
|
||||
/// include the same `trace_id` as that included on the original request. This way,
|
||||
/// users can trace related actions across a distributed system.
|
||||
pub trace_context: trace::Context,
|
||||
#[doc(hidden)]
|
||||
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
|
||||
pub(crate) _non_exhaustive: (),
|
||||
}
|
||||
|
||||
assert_impl_all!(Context: Send, Sync);
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
fn ten_seconds_from_now() -> SystemTime {
|
||||
SystemTime::now() + Duration::from_secs(10)
|
||||
@@ -51,7 +52,6 @@ pub fn current() -> Context {
|
||||
Context {
|
||||
deadline: SystemTime::now() + Duration::from_secs(10),
|
||||
trace_context: trace::Context::new_root(),
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -144,8 +144,9 @@ where
|
||||
ThrottlerStream::new(self, n)
|
||||
}
|
||||
|
||||
/// Responds to all requests with `server`.
|
||||
/// Responds to all requests with [`server::serve`](Serve).
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn respond_with<S>(self, server: S) -> Running<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
@@ -197,11 +198,16 @@ where
|
||||
Self::new(Config::default(), transport)
|
||||
}
|
||||
|
||||
/// Returns the inner transport.
|
||||
/// Returns the inner transport over which messages are sent and received.
|
||||
pub fn get_ref(&self) -> &T {
|
||||
self.transport.get_ref()
|
||||
}
|
||||
|
||||
/// Returns the inner transport over which messages are sent and received.
|
||||
pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
|
||||
self.project().transport.get_pin_mut()
|
||||
}
|
||||
|
||||
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
|
||||
// It's possible the request was already completed, so it's fine
|
||||
// if this is None.
|
||||
@@ -304,7 +310,6 @@ where
|
||||
} => {
|
||||
self.as_mut().cancel_request(&trace_context, request_id);
|
||||
}
|
||||
ClientMessage::_NonExhaustive => unreachable!(),
|
||||
},
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
@@ -401,6 +406,11 @@ where
|
||||
C: Channel,
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
/// Returns the inner channel over which messages are sent and received.
|
||||
pub fn get_pin_channel(self: Pin<&mut Self>) -> Pin<&mut C> {
|
||||
self.project().channel
|
||||
}
|
||||
|
||||
fn pump_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
@@ -569,11 +579,9 @@ where
|
||||
"Response did not complete before deadline of {}s.",
|
||||
format_rfc3339(self.deadline)
|
||||
)),
|
||||
_non_exhaustive: (),
|
||||
})
|
||||
}
|
||||
},
|
||||
_non_exhaustive: (),
|
||||
});
|
||||
*self.as_mut().project().state = RespState::PollReady;
|
||||
}
|
||||
@@ -647,27 +655,26 @@ where
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||
S::Fut: Send + 'static,
|
||||
{
|
||||
/// Runs the client handler until completion by spawning each
|
||||
/// Runs the client handler until completion by [spawning](tokio::spawn) each
|
||||
/// request handler onto the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn execute(self) -> impl Future<Output = ()> {
|
||||
use log::info;
|
||||
|
||||
self.try_for_each(|request_handler| {
|
||||
async {
|
||||
tokio::spawn(request_handler);
|
||||
Ok(())
|
||||
}
|
||||
self.try_for_each(|request_handler| async {
|
||||
tokio::spawn(request_handler);
|
||||
Ok(())
|
||||
})
|
||||
.unwrap_or_else(|e| info!("ClientHandler errored out: {}", e))
|
||||
.map_ok(|()| log::info!("ClientHandler finished."))
|
||||
.unwrap_or_else(|e| log::info!("ClientHandler errored out: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
/// A future that drives the server by spawning channels and request handlers on the default
|
||||
/// A future that drives the server by [spawning](tokio::spawn) channels and request handlers on the default
|
||||
/// executor.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub struct Running<St, Se> {
|
||||
#[pin]
|
||||
incoming: St,
|
||||
@@ -687,8 +694,6 @@ where
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
use log::info;
|
||||
|
||||
while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) {
|
||||
tokio::spawn(
|
||||
channel
|
||||
@@ -696,7 +701,7 @@ where
|
||||
.execute(),
|
||||
);
|
||||
}
|
||||
info!("Server shutting down.");
|
||||
log::info!("Server shutting down.");
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,6 @@ use fnv::FnvHashMap;
|
||||
use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
|
||||
use log::{debug, info, trace};
|
||||
use pin_project::pin_project;
|
||||
use raii_counter::{Counter, WeakCounter};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||
@@ -32,7 +31,7 @@ where
|
||||
dropped_keys: mpsc::UnboundedReceiver<K>,
|
||||
#[pin]
|
||||
dropped_keys_tx: mpsc::UnboundedSender<K>,
|
||||
key_counts: FnvHashMap<K, TrackerPrototype<K>>,
|
||||
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
|
||||
keymaker: F,
|
||||
}
|
||||
|
||||
@@ -42,37 +41,22 @@ where
|
||||
pub struct TrackedChannel<C, K> {
|
||||
#[pin]
|
||||
inner: C,
|
||||
tracker: Tracker<K>,
|
||||
tracker: Arc<Tracker<K>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Debug)]
|
||||
struct Tracker<K> {
|
||||
key: Option<Arc<K>>,
|
||||
counter: Counter,
|
||||
key: Option<K>,
|
||||
dropped_keys: mpsc::UnboundedSender<K>,
|
||||
}
|
||||
|
||||
impl<K> Drop for Tracker<K> {
|
||||
fn drop(&mut self) {
|
||||
if self.counter.count() <= 1 {
|
||||
// Don't care if the listener is dropped.
|
||||
match Arc::try_unwrap(self.key.take().unwrap()) {
|
||||
Ok(key) => {
|
||||
let _ = self.dropped_keys.unbounded_send(key);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
// Don't care if the listener is dropped.
|
||||
let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TrackerPrototype<K> {
|
||||
key: Weak<K>,
|
||||
counter: WeakCounter,
|
||||
dropped_keys: mpsc::UnboundedSender<K>,
|
||||
}
|
||||
|
||||
impl<C, K> Stream for TrackedChannel<C, K>
|
||||
where
|
||||
C: Stream,
|
||||
@@ -181,7 +165,7 @@ where
|
||||
trace!(
|
||||
"[{}] Opening channel ({}/{}) channels for key.",
|
||||
key,
|
||||
tracker.counter.count(),
|
||||
Arc::strong_count(&tracker),
|
||||
self.as_mut().project().channels_per_key
|
||||
);
|
||||
|
||||
@@ -191,28 +175,22 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Tracker<K>, K> {
|
||||
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
|
||||
let channels_per_key = self.channels_per_key;
|
||||
let dropped_keys = self.dropped_keys_tx.clone();
|
||||
let key_counts = &mut self.as_mut().project().key_counts;
|
||||
match key_counts.entry(key.clone()) {
|
||||
Entry::Vacant(vacant) => {
|
||||
let key = Arc::new(key);
|
||||
let counter = WeakCounter::new();
|
||||
|
||||
vacant.insert(TrackerPrototype {
|
||||
key: Arc::downgrade(&key),
|
||||
counter: counter.clone(),
|
||||
dropped_keys: dropped_keys.clone(),
|
||||
});
|
||||
Ok(Tracker {
|
||||
let tracker = Arc::new(Tracker {
|
||||
key: Some(key),
|
||||
counter: counter.upgrade(),
|
||||
dropped_keys,
|
||||
})
|
||||
});
|
||||
|
||||
vacant.insert(Arc::downgrade(&tracker));
|
||||
Ok(tracker)
|
||||
}
|
||||
Entry::Occupied(o) => {
|
||||
let count = o.get().counter.count();
|
||||
Entry::Occupied(mut o) => {
|
||||
let count = o.get().strong_count();
|
||||
if count >= channels_per_key.try_into().unwrap() {
|
||||
info!(
|
||||
"[{}] Opened max channels from key ({}/{}).",
|
||||
@@ -220,16 +198,15 @@ where
|
||||
);
|
||||
Err(key)
|
||||
} else {
|
||||
let TrackerPrototype {
|
||||
key,
|
||||
counter,
|
||||
dropped_keys,
|
||||
} = o.get().clone();
|
||||
Ok(Tracker {
|
||||
counter: counter.upgrade(),
|
||||
key: Some(key.upgrade().unwrap()),
|
||||
dropped_keys,
|
||||
})
|
||||
Ok(o.get().upgrade().unwrap_or_else(|| {
|
||||
let tracker = Arc::new(Tracker {
|
||||
key: Some(key),
|
||||
dropped_keys,
|
||||
});
|
||||
|
||||
*o.get_mut() = Arc::downgrade(&tracker);
|
||||
tracker
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -302,12 +279,10 @@ fn ctx() -> Context<'static> {
|
||||
#[test]
|
||||
fn tracker_drop() {
|
||||
use assert_matches::assert_matches;
|
||||
use raii_counter::Counter;
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded();
|
||||
Tracker {
|
||||
key: Some(Arc::new(1)),
|
||||
counter: Counter::new(),
|
||||
key: Some(1),
|
||||
dropped_keys: tx,
|
||||
};
|
||||
assert_matches!(rx.try_next(), Ok(Some(1)));
|
||||
@@ -317,17 +292,15 @@ fn tracker_drop() {
|
||||
fn tracked_channel_stream() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
use raii_counter::Counter;
|
||||
|
||||
let (chan_tx, chan) = mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded();
|
||||
let channel = TrackedChannel {
|
||||
inner: chan,
|
||||
tracker: Tracker {
|
||||
key: Some(Arc::new(1)),
|
||||
counter: Counter::new(),
|
||||
tracker: Arc::new(Tracker {
|
||||
key: Some(1),
|
||||
dropped_keys,
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
chan_tx.unbounded_send("test").unwrap();
|
||||
@@ -339,17 +312,15 @@ fn tracked_channel_stream() {
|
||||
fn tracked_channel_sink() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
use raii_counter::Counter;
|
||||
|
||||
let (chan, mut chan_rx) = mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded();
|
||||
let channel = TrackedChannel {
|
||||
inner: chan,
|
||||
tracker: Tracker {
|
||||
key: Some(Arc::new(1)),
|
||||
counter: Counter::new(),
|
||||
tracker: Arc::new(Tracker {
|
||||
key: Some(1),
|
||||
dropped_keys,
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
pin_mut!(channel);
|
||||
@@ -371,12 +342,12 @@ fn channel_filter_increment_channels_for_key() {
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||
assert_eq!(tracker1.counter.count(), 1);
|
||||
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||
let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||
assert_eq!(tracker1.counter.count(), 2);
|
||||
assert_eq!(Arc::strong_count(&tracker1), 2);
|
||||
assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
|
||||
drop(tracker2);
|
||||
assert_eq!(tracker1.counter.count(), 1);
|
||||
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -395,20 +366,20 @@ fn channel_filter_handle_new_channel() {
|
||||
.as_mut()
|
||||
.handle_new_channel(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
assert_eq!(channel1.tracker.counter.count(), 1);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||
|
||||
let channel2 = filter
|
||||
.as_mut()
|
||||
.handle_new_channel(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
assert_eq!(channel1.tracker.counter.count(), 2);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||
|
||||
assert_matches!(
|
||||
filter.handle_new_channel(TestChannel { key: "key" }),
|
||||
Err("key")
|
||||
);
|
||||
drop(channel2);
|
||||
assert_eq!(channel1.tracker.counter.count(), 1);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -429,14 +400,14 @@ fn channel_filter_poll_listener() {
|
||||
.unwrap();
|
||||
let channel1 =
|
||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
||||
assert_eq!(channel1.tracker.counter.count(), 1);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||
|
||||
new_channels
|
||||
.unbounded_send(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
let _channel2 =
|
||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
||||
assert_eq!(channel1.tracker.counter.count(), 2);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||
|
||||
new_channels
|
||||
.unbounded_send(TestChannel { key: "key" })
|
||||
@@ -444,7 +415,7 @@ fn channel_filter_poll_listener() {
|
||||
let key =
|
||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
|
||||
assert_eq!(key, "key");
|
||||
assert_eq!(channel1.tracker.counter.count(), 2);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
// Copyright 2020 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::server::{Channel, Config};
|
||||
use crate::{context, Request, Response};
|
||||
use fnv::FnvHashSet;
|
||||
@@ -87,11 +93,9 @@ impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
||||
context: context::Context {
|
||||
deadline: SystemTime::UNIX_EPOCH,
|
||||
trace_context: Default::default(),
|
||||
_non_exhaustive: (),
|
||||
},
|
||||
id,
|
||||
message,
|
||||
_non_exhaustive: (),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
// Copyright 2020 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use super::{Channel, Config};
|
||||
use crate::{Response, ServerError};
|
||||
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
|
||||
@@ -61,9 +67,7 @@ where
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::WouldBlock,
|
||||
detail: Some("Server throttled the request.".into()),
|
||||
_non_exhaustive: (),
|
||||
}),
|
||||
_non_exhaustive: (),
|
||||
})?;
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
@@ -311,7 +315,6 @@ fn throttler_start_send() {
|
||||
.start_send(Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
_non_exhaustive: (),
|
||||
})
|
||||
.unwrap();
|
||||
assert!(throttler.inner.in_flight_requests.is_empty());
|
||||
@@ -320,7 +323,6 @@ fn throttler_start_send() {
|
||||
Some(&Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
_non_exhaustive: ()
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a [`Transport`] trait as well as implementations.
|
||||
//! Provides a [`Transport`](sealed::Transport) trait as well as implementations.
|
||||
//!
|
||||
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`]
|
||||
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`](sealed::Transport)
|
||||
//! can be plugged in, using whatever protocol it wants.
|
||||
|
||||
use futures::prelude::*;
|
||||
@@ -11,6 +11,7 @@ use std::{
|
||||
};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
|
||||
pub mod serde;
|
||||
|
||||
/// Extension trait for [SystemTimes](SystemTime) in the future, i.e. deadlines.
|
||||
@@ -15,9 +15,10 @@ pub fn serialize_epoch_secs<S>(system_time: &SystemTime, serializer: S) -> Resul
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
const ZERO_SECS: Duration = Duration::from_secs(0);
|
||||
system_time
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or(Duration::from_secs(0))
|
||||
.unwrap_or(ZERO_SECS)
|
||||
.as_secs() // Only care about second precision
|
||||
.serialize(serializer)
|
||||
}
|
||||
|
||||
@@ -14,15 +14,25 @@ use serde::{Deserialize, Serialize};
|
||||
use std::{error::Error, io, pin::Pin};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_serde::{Framed as SerdeFramed, *};
|
||||
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
|
||||
use tokio_util::codec::{
|
||||
length_delimited::{self, LengthDelimitedCodec},
|
||||
Framed,
|
||||
};
|
||||
|
||||
/// A transport that serializes to, and deserializes from, a [`TcpStream`].
|
||||
/// A transport that serializes to, and deserializes from, a byte stream.
|
||||
#[pin_project]
|
||||
pub struct Transport<S, Item, SinkItem, Codec> {
|
||||
#[pin]
|
||||
inner: SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>,
|
||||
}
|
||||
|
||||
impl<S, Item, SinkItem, Codec> Transport<S, Item, SinkItem, Codec> {
|
||||
/// Returns the inner transport over which messages are sent and received.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
self.inner.get_ref().get_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, Item, SinkItem, Codec, CodecError> Stream for Transport<S, Item, SinkItem, Codec>
|
||||
where
|
||||
S: AsyncWrite + AsyncRead,
|
||||
@@ -83,6 +93,22 @@ fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
|
||||
poll.map(|ready| ready.map_err(|e| io::Error::new(io::ErrorKind::Other, e)))
|
||||
}
|
||||
|
||||
/// Constructs a new transport from a framed transport and a serialization codec.
|
||||
pub fn new<S, Item, SinkItem, Codec>(
|
||||
framed_io: Framed<S, LengthDelimitedCodec>,
|
||||
codec: Codec,
|
||||
) -> Transport<S, Item, SinkItem, Codec>
|
||||
where
|
||||
S: AsyncWrite + AsyncRead,
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
SinkItem: Serialize,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
{
|
||||
Transport {
|
||||
inner: SerdeFramed::new(framed_io, codec),
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, Item, SinkItem, Codec> From<(S, Codec)> for Transport<S, Item, SinkItem, Codec>
|
||||
where
|
||||
S: AsyncWrite + AsyncRead,
|
||||
@@ -90,10 +116,8 @@ where
|
||||
SinkItem: Serialize,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
{
|
||||
fn from((inner, codec): (S, Codec)) -> Self {
|
||||
Transport {
|
||||
inner: SerdeFramed::new(Framed::new(inner, LengthDelimitedCodec::new()), codec),
|
||||
}
|
||||
fn from((io, codec): (S, Codec)) -> Self {
|
||||
new(Framed::new(io, LengthDelimitedCodec::new()), codec)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,34 +151,65 @@ pub mod tcp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new JSON transport that reads from and writes to `io`.
|
||||
pub fn new<Item, SinkItem, Codec>(
|
||||
io: TcpStream,
|
||||
codec: Codec,
|
||||
) -> Transport<TcpStream, Item, SinkItem, Codec>
|
||||
/// A connection Future that also exposes the length-delimited framing config.
|
||||
#[pin_project]
|
||||
pub struct Connect<T, Item, SinkItem, CodecFn> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
codec_fn: CodecFn,
|
||||
config: length_delimited::Builder,
|
||||
ghost: PhantomData<(fn(SinkItem), fn() -> Item)>,
|
||||
}
|
||||
|
||||
impl<T, Item, SinkItem, Codec, CodecFn> Future for Connect<T, Item, SinkItem, CodecFn>
|
||||
where
|
||||
T: Future<Output = io::Result<TcpStream>>,
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
SinkItem: Serialize,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
Transport::from((io, codec))
|
||||
type Output = io::Result<Transport<TcpStream, Item, SinkItem, Codec>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||
let io = ready!(self.as_mut().project().inner.poll(cx))?;
|
||||
Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)())))
|
||||
}
|
||||
}
|
||||
|
||||
/// Connects to `addr`, wrapping the connection in a JSON transport.
|
||||
pub async fn connect<A, Item, SinkItem, Codec>(
|
||||
impl<T, Item, SinkItem, CodecFn> Connect<T, Item, SinkItem, CodecFn> {
|
||||
/// Returns an immutable reference to the length-delimited codec's config.
|
||||
pub fn config(&self) -> &length_delimited::Builder {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the length-delimited codec's config.
|
||||
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||
&mut self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Connects to `addr`, wrapping the connection in a TCP transport.
|
||||
pub fn connect<A, Item, SinkItem, Codec, CodecFn>(
|
||||
addr: A,
|
||||
codec: Codec,
|
||||
) -> io::Result<Transport<TcpStream, Item, SinkItem, Codec>>
|
||||
codec_fn: CodecFn,
|
||||
) -> Connect<impl Future<Output = io::Result<TcpStream>>, Item, SinkItem, CodecFn>
|
||||
where
|
||||
A: ToSocketAddrs,
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
SinkItem: Serialize,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
Ok(new(TcpStream::connect(addr).await?, codec))
|
||||
Connect {
|
||||
inner: TcpStream::connect(addr),
|
||||
codec_fn,
|
||||
config: LengthDelimitedCodec::builder(),
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Listens on `addr`, wrapping accepted connections in JSON transports.
|
||||
/// Listens on `addr`, wrapping accepted connections in TCP transports.
|
||||
pub async fn listen<A, Item, SinkItem, Codec, CodecFn>(
|
||||
addr: A,
|
||||
codec_fn: CodecFn,
|
||||
@@ -171,18 +226,20 @@ pub mod tcp {
|
||||
listener,
|
||||
codec_fn,
|
||||
local_addr,
|
||||
config: LengthDelimitedCodec::builder(),
|
||||
ghost: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// A [`TcpListener`] that wraps connections in JSON transports.
|
||||
/// A [`TcpListener`] that wraps connections in [transports](Transport).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||
listener: TcpListener,
|
||||
local_addr: SocketAddr,
|
||||
codec_fn: CodecFn,
|
||||
ghost: PhantomData<(Item, SinkItem, Codec)>,
|
||||
config: length_delimited::Builder,
|
||||
ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>,
|
||||
}
|
||||
|
||||
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||
@@ -190,6 +247,16 @@ pub mod tcp {
|
||||
pub fn local_addr(&self) -> SocketAddr {
|
||||
self.local_addr
|
||||
}
|
||||
|
||||
/// Returns an immutable reference to the length-delimited codec's config.
|
||||
pub fn config(&self) -> &length_delimited::Builder {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the length-delimited codec's config.
|
||||
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||
&mut self.config
|
||||
}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
|
||||
@@ -204,7 +271,7 @@ pub mod tcp {
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let next =
|
||||
ready!(Pin::new(&mut self.as_mut().project().listener.incoming()).poll_next(cx)?);
|
||||
Poll::Ready(next.map(|conn| Ok(new(conn, (self.codec_fn)()))))
|
||||
Poll::Ready(next.map(|conn| Ok(new(self.config.new_framed(conn), (self.codec_fn)()))))
|
||||
}
|
||||
}
|
||||
}
|
||||
5
tarpc/tests/compile_fail.rs
Normal file
5
tarpc/tests/compile_fail.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
#[test]
|
||||
fn ui() {
|
||||
let t = trybuild::TestCases::new();
|
||||
t.compile_fail("tests/compile_fail/*.rs");
|
||||
}
|
||||
15
tarpc/tests/compile_fail/tarpc_server_missing_async.rs
Normal file
15
tarpc/tests/compile_fail/tarpc_server_missing_async.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
#[tarpc::service]
|
||||
trait World {
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
struct HelloServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
fn hello(name: String) -> String {
|
||||
format!("Hello, {}!", name)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
19
tarpc/tests/compile_fail/tarpc_server_missing_async.stderr
Normal file
19
tarpc/tests/compile_fail/tarpc_server_missing_async.stderr
Normal file
@@ -0,0 +1,19 @@
|
||||
error: not all trait items implemented, missing: `HelloFut`
|
||||
--> $DIR/tarpc_server_missing_async.rs:9:1
|
||||
|
|
||||
9 | impl World for HelloServer {
|
||||
| ^^^^
|
||||
|
||||
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
|
||||
--> $DIR/tarpc_server_missing_async.rs:10:5
|
||||
|
|
||||
10 | fn hello(name: String) -> String {
|
||||
| ^^
|
||||
|
||||
error[E0433]: failed to resolve: use of undeclared type or module `serde`
|
||||
--> $DIR/tarpc_server_missing_async.rs:1:1
|
||||
|
|
||||
1 | #[tarpc::service]
|
||||
| ^^^^^^^^^^^^^^^^^ use of undeclared type or module `serde`
|
||||
|
|
||||
= note: this error originates in an attribute macro (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
6
tarpc/tests/compile_fail/tarpc_service_arg_pat.rs
Normal file
6
tarpc/tests/compile_fail/tarpc_service_arg_pat.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
#[tarpc::service]
|
||||
trait World {
|
||||
async fn pat((a, b): (u8, u32));
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
5
tarpc/tests/compile_fail/tarpc_service_arg_pat.stderr
Normal file
5
tarpc/tests/compile_fail/tarpc_service_arg_pat.stderr
Normal file
@@ -0,0 +1,5 @@
|
||||
error: patterns aren't allowed in RPC args
|
||||
--> $DIR/tarpc_service_arg_pat.rs:3:18
|
||||
|
|
||||
3 | async fn pat((a, b): (u8, u32));
|
||||
| ^^^^^^
|
||||
6
tarpc/tests/compile_fail/tarpc_service_fn_new.rs
Normal file
6
tarpc/tests/compile_fail/tarpc_service_fn_new.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
#[tarpc::service]
|
||||
trait World {
|
||||
async fn new();
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
5
tarpc/tests/compile_fail/tarpc_service_fn_new.stderr
Normal file
5
tarpc/tests/compile_fail/tarpc_service_fn_new.stderr
Normal file
@@ -0,0 +1,5 @@
|
||||
error: method name conflicts with generated fn `WorldClient::new`
|
||||
--> $DIR/tarpc_service_fn_new.rs:3:14
|
||||
|
|
||||
3 | async fn new();
|
||||
| ^^^
|
||||
6
tarpc/tests/compile_fail/tarpc_service_fn_serve.rs
Normal file
6
tarpc/tests/compile_fail/tarpc_service_fn_serve.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
#[tarpc::service]
|
||||
trait World {
|
||||
async fn serve();
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
5
tarpc/tests/compile_fail/tarpc_service_fn_serve.stderr
Normal file
5
tarpc/tests/compile_fail/tarpc_service_fn_serve.stderr
Normal file
@@ -0,0 +1,5 @@
|
||||
error: method name conflicts with generated fn `World::serve`
|
||||
--> $DIR/tarpc_service_fn_serve.rs:3:14
|
||||
|
|
||||
3 | async fn serve();
|
||||
| ^^^^^
|
||||
@@ -1,6 +1,6 @@
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{
|
||||
future::{ready, Ready},
|
||||
future::{join_all, ready, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use std::io;
|
||||
@@ -10,6 +10,7 @@ use tarpc::{
|
||||
server::{self, BaseChannel, Channel, Handler},
|
||||
transport::channel,
|
||||
};
|
||||
use tokio::join;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
#[tarpc_plugins::service]
|
||||
@@ -70,7 +71,7 @@ async fn serde() -> io::Result<()> {
|
||||
.respond_with(Server.serve()),
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default()).await?;
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?;
|
||||
|
||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
||||
@@ -110,3 +111,59 @@ async fn concurrent() -> io::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn concurrent_join() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
|
||||
let mut c = client.clone();
|
||||
let req1 = c.add(context::current(), 1, 2);
|
||||
|
||||
let mut c = client.clone();
|
||||
let req2 = c.add(context::current(), 3, 4);
|
||||
|
||||
let mut c = client.clone();
|
||||
let req3 = c.hey(context::current(), "Tim".to_string());
|
||||
|
||||
let (resp1, resp2, resp3) = join!(req1, req2, req3);
|
||||
assert_matches!(resp1, Ok(3));
|
||||
assert_matches!(resp2, Ok(7));
|
||||
assert_matches!(resp3, Ok(ref s) if s == "Hey, Tim.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn concurrent_join_all() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
|
||||
let mut c1 = client.clone();
|
||||
let mut c2 = client.clone();
|
||||
|
||||
let req1 = c1.add(context::current(), 1, 2);
|
||||
let req2 = c2.add(context::current(), 3, 4);
|
||||
|
||||
let responses = join_all(vec![req1, req2]).await;
|
||||
assert_matches!(responses[0], Ok(3));
|
||||
assert_matches!(responses[1], Ok(7));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user