140 Commits

Author SHA1 Message Date
Tim Kuehn
7e521768ab Prepare for v0.21.0 release. 2020-06-26 20:05:02 -07:00
Tim Kuehn
e9b1e7d101 Use #[non_exhaustive] in lieu of _NonExhaustive enum variant. 2020-06-26 19:47:20 -07:00
Taiki Endo
f0322fb892 Remove uses of pin_project::project attribute
pin-project will deprecate the project attribute due to some unfixable
limitations.

Refs: https://github.com/taiki-e/pin-project/issues/225
2020-06-05 20:34:44 -07:00
Patrick Elsen
617daebb88 Add tarpc::server proc-macro as syntactic sugar for async methods. (#302)
The tarpc::server proc-macro can be used to annotate implementations of
services to turn async functions into the proper declarations needed
for tarpc to be able to call them.

This uses the assert_type_eq crate to check that the transformations
applied by the tarpc::server proc macro are correct and lead to code
that compiles.
2020-05-16 10:25:25 -07:00
Tim Kuehn
a11d4fff58 Remove raii_counter 2020-04-22 02:13:02 -07:00
Tim
bf42a04d83 Move the request timeout so that it surrounds the entire call, not just the response future. (#295)
* Move the request timeout so that it surrounds the entire call, not just the response future.

This will enable the timeout earlier, so that a backlog in the outbound request buffer can not cause requests to stall indefinitely.

* Run cargo fmt
2020-02-25 14:42:40 -08:00
Tim Kuehn
06528d6953 Fix clippy lint. 2019-12-19 12:28:26 -08:00
Tim Kuehn
9f00395746 Replace _non_exhaustive fields with #[non_exhaustive] attribute.
The attribute landed on stable rust (1.40.0) today.

Fixes https://github.com/google/tarpc/issues/275
2019-12-19 12:14:34 -08:00
Tim Kuehn
e0674cd57f Make pre-push run on rust stable. 2019-12-19 12:06:06 -08:00
Tim Kuehn
7e49bd9ee7 Clean up badges a bit. 2019-12-16 13:21:00 -08:00
Tim Kuehn
8a1baa9c4e Remove usage of unsafe in rpc::client::channel.
pin_project is actually able to handle the complexities of enum Futures.
2019-12-16 11:10:57 -08:00
Oleg Nosov
31c713d188 Allow raw identifiers + fixed naming + place all code generation methods in impl (#291)
Allows defining services using raw identifiers like:

```rust
pub mod service {
    #[tarpc::service]
    pub trait r#trait {
        async fn r#fn(x: i32) -> Result<u8, String>;
    }
}
```

Also:

- Refactored names (ident -> type)
- All code generation methods placed in impl
2019-12-12 10:13:57 -08:00
Tim Kuehn
d905bc1591 Prepare for tarpc release v0.20.0 2019-12-11 20:47:56 -08:00
Tim Kuehn
7f946c7f83 Make tokio a hard dependency.
Fixes #289
2019-12-11 20:08:36 -08:00
Tim Kuehn
36cfdb6c6f Fix tokio dependency for example-service 2019-12-11 20:01:06 -08:00
Tim Kuehn
dbabe9774f Clean up proc macro code to make clippy happy.
I made a separate TokenStream-returning fn for each item in the previously-huge quote block.
The service fn now primarily performs the duty of creating idents and joining all the TokenStreams.
2019-12-11 17:20:03 -08:00
Tim Kuehn
deb041b8d3 Replace travis-ci badge with github CI workflow badge 2019-12-11 12:54:56 -08:00
Oleg Nosov
85d49477f5 Updated and simplified macros (#290)
* syn updated to latest version
* quote updated to latest version
* proc-macro-2 updated to latest version
* Performance improvements
* Don't create unnecessary TokenStreams for output types
2019-12-11 12:28:24 -08:00
Tim Kuehn
45af6ccdeb Workaround for pubsub example hanging.
The publisher client isn't being dropped when the async fn returns. It
could potentially be something strange in the ThreadPool executor.
2019-12-07 22:01:41 -08:00
Tim Kuehn
917c0c5e2d Use tokio::time::delay_for in lieu of thread::sleep. 2019-12-07 21:28:45 -08:00
Artem Vorotnikov
bbbd43e282 Unify serde transports.
This PR obsoletes the JSON and Bincode transports and instead introduces a unified transport that
is generic over any tokio-serde serialization format as well as AsyncRead + AsyncWrite medium.
This comes with a slight hit for usability (having to manually specify the underlying transport
and codec), but it can be alleviated by making custom freestanding connect and listen fns.
2019-12-07 20:58:08 -08:00
Artem Vorotnikov
f945392b5a Use tokio/stream feature for json-transport 2019-12-07 09:54:33 -08:00
Artem Vorotnikov
f4060779e4 Add GitHub workflow 2019-12-05 20:13:14 -08:00
Artem Vorotnikov
7cc8d9640b Fix clippy warnings 2019-12-05 17:39:53 -08:00
Artem Vorotnikov
7f871f03ef Improve Travis configuration (#282)
* Improve Travis configuration

* Replace 0.0.0.0 with localhost in tests
2019-11-28 14:06:35 -08:00
Artem Vorotnikov
709b966150 Update to Tokio 0.2 and futures 0.3 (#277) 2019-11-27 19:53:44 -08:00
Artem Vorotnikov
5e19b79aa4 Unite most of tarpc into a single crate 2019-11-26 13:08:18 -08:00
Tim Kuehn
6eb806907a Replace Gitter badge with Discord badge. 2019-11-22 14:28:24 -08:00
Tim Kuehn
8250ca31ff Remove --no-default-features from pre-push hook.
It seemingly doesn't work at the root of a virtual workspace. Not sure if this is new behavior or just a new explicit error message.
2019-11-15 17:19:08 -08:00
Tim Kuehn
7cd776143b Fix typo 2019-11-15 17:12:00 -08:00
Artem Vorotnikov
5f6c3d7d98 Port to pin-project 2019-10-09 14:12:24 -07:00
Artem Vorotnikov
915fe3ed4e Use the JSON transport in examples 2019-10-08 19:18:49 -07:00
Artem Vorotnikov
d8c7b9feb2 JSON transport: use Tokio resolver for connect() 2019-10-08 18:03:25 -07:00
Artem Vorotnikov
5ab3866d96 Add Unpin note 2019-10-08 17:15:17 -07:00
Artem Vorotnikov
184ea42033 Upgrade json-transport to Tokio 0.2 2019-10-08 17:15:17 -07:00
Artem Vorotnikov
014c209b8e Do not serialize _non_exhaustive field 2019-10-03 13:09:26 -07:00
Artem Vorotnikov
e91005855c Remove remaining feature flags 2019-10-02 13:07:37 -07:00
Artem Vorotnikov
46bcc0f559 tokio 0.2.0-alpha.4 2019-08-30 09:29:18 -07:00
Artem Vorotnikov
61322ebf41 Clippy fixes 2019-08-29 11:34:38 -07:00
Artem Vorotnikov
db0c9c4182 Cut type_alias_impl_trait feature flag 2019-08-29 11:34:38 -07:00
Artem Vorotnikov
9ee3011687 Update to Tokio 0.3.0-alpha.3 2019-08-29 11:34:38 -07:00
Artem Vorotnikov
5aa4a2cef6 tokio 0.2.0-alpha.2 2019-08-19 23:13:06 -07:00
Artem Vorotnikov
f38a172523 Format code with rustfmt 2019-08-19 13:20:21 -07:00
Tim Kuehn
66dbca80b2 Add missing feature, "compat", back to json-transport dependency on futures-preview. 2019-08-14 09:16:44 -07:00
Tim
61377dd4ff Fix comment in example service
It referred to bincode instead of json.
2019-08-14 08:32:49 -07:00
Tim
cd03f3ff8c Don't mention 'static optional in readme
This isn't supported by the service attribute.
2019-08-13 08:49:11 -07:00
Tim Kuehn
9479963773 Don't enable serde1 by default. I forgot it gives bad compile errors to people who don't have serde in their Cargo.toml. 2019-08-09 01:21:31 -07:00
Tim Kuehn
f974533bf7 Use real crate names rather than internal aliases. It's less confusing for people reading examples. 2019-08-09 01:16:06 -07:00
Tim Kuehn
d560ac6197 Update to the latest rustc nightly. 2019-08-09 01:08:20 -07:00
Tim Kuehn
1cdff15412 Fix needless verbosity in readme 2019-08-09 00:50:06 -07:00
Tim Kuehn
f8ba7d9f4e Make tokio1 serde1 default features 2019-08-08 22:06:09 -07:00
Tim Kuehn
41c1aafaf7 Update tokio to v0.2.0-alpha.1
As part of this, I made an optional tokio feature which, when enabled,
adds utility functions that spawn on the default tokio executor. This
allows for the removal of the runtime crate.

On the one hand, this makes the spawning utils slightly less generic. On
the other hand:

- The fns are just helpers and are easily rewritten by the user.
- Tokio is the clear dominant futures executor, so most people will just
  use these versions.
2019-08-08 21:53:36 -07:00
Tim Kuehn
75d1e877be Update README to talk about deadlines a bit more precisely. 2019-08-08 20:45:37 -07:00
Tim Kuehn
88e1cf558b Generate README.md from cargo readme 2019-08-08 20:31:04 -07:00
Tim Kuehn
50879d2acb Don't bake in Send + 'static.
Send + 'static was baked in to make it possible to spawn futures onto
the default executor. We can accomplish the same thing by offering
helper fns that do the spawning while not requiring it for the rest of
the functionality.

Fixes https://github.com/google/tarpc/issues/212
2019-08-07 13:39:48 -07:00
Tim
13cb14a119 Merge pull request #248 from tikue/service-idents
With this change, the service definitions don't need to be isolated in their own modules.

Given:

```rust
#[tarpc::service]
trait World { ... }
```

Before this would generate the following items
------
- `trait World`
- `fn serve`
- `struct Client`
- `fn new_stub`

`// Implementation details below`
- `enum Request`
- `enum Response`
- `enum ResponseFut`

And now these items
------
- `trait World {    ...    fn serve }`
- `struct WorldClient ... impl WorldClient {    ...    async fn new }`

`// Implementation details below`
- `enum WorldRequest`
- `enum WorldResponse`
- `enum WorldResponseFut`
- `struct ServeWorld` (new manual closure impl because you can't use impl Trait in trait fns)
```
2019-08-05 12:23:35 -07:00
Tim Kuehn
22ef6b7800 Choose a slightly less obvious name for Serve impl.
To hopefully avoid most collisions.
2019-07-30 21:46:16 -07:00
Tim Kuehn
e48e6dfe67 Add nice error message for ident collisions 2019-07-30 21:31:22 -07:00
Tim Kuehn
1b58914d59 Move generated functions under their corresponding items.
- fn serve -> Service::serve
- fn new_stub -> Client::new

This allows the generated function names to remain consistent across
service definitions while preventing collisions.
2019-07-30 20:45:58 -07:00
Tim Kuehn
2f24842b2d Add service name to generated items.
With this change, the service definitions don't need to be isolated in their own modules.
2019-07-30 00:52:30 -07:00
Tim Kuehn
5c485fe608 Add some tests for snake to camel case conversion. 2019-07-30 00:52:30 -07:00
Tim Kuehn
b0319e7db9 Remove macros.rs 2019-07-30 00:51:29 -07:00
Tim Kuehn
a4d9581888 Remove service_registry example 2019-07-29 23:17:08 -07:00
Tim Kuehn
fb5022b1c0 cargo fmt 2019-07-29 22:08:53 -07:00
Tim Kuehn
abb0b5b3ac Rewrite to use proc_macro_attribute 2019-07-29 22:04:04 -07:00
Artem Vorotnikov
49f2641e3c Port to runtime crate 2019-07-29 08:36:06 -07:00
Tim
650c60fe44 Merge pull request #246 from google/rustfmt
Reformat all code using rustfmt
2019-07-22 17:53:48 -07:00
Artem Vorotnikov
1d0bbcb36c Reformat all code using rustfmt 2019-07-23 03:44:16 +03:00
Tim Kuehn
c456ad7fa5 Fix typo 2019-07-22 14:15:27 -07:00
Tim Kuehn
537446a5c9 Remove use of unstable feature 'arbitrary_self_types'.
Turns out, this actually wasn't needed, with some minor refactoring.
2019-07-19 00:48:59 -07:00
Tim Kuehn
94b5b2c431 Add tests for rpc/server/filter.rs 2019-07-16 21:48:11 -07:00
Tim Kuehn
9863433fea Remove unstable feature 'async_closure' 2019-07-16 11:17:18 -07:00
Tim Kuehn
9a27465a25 Remove use of unstable feature 'try_trait' 2019-07-16 11:08:53 -07:00
Tim Kuehn
263cfe1435 Remove unused unstable feature 'integer_atomics' 2019-07-16 10:27:59 -07:00
Tim
6ae5302a70 Merge pull request #240 from tikue/filter-refactor 2019-07-15 23:04:20 -07:00
Tim Kuehn
c67b7283e7 Move bench outside crate. 2019-07-15 22:43:58 -07:00
Tim Kuehn
7b6e98da7b Replace transport integration tests with unit tests.
I want 'cargo test' to run faster.
2019-07-15 22:40:58 -07:00
Tim Kuehn
15b65fa20f Replace usage of Once and unsafe code with once_cell crate. 2019-07-15 20:05:10 -07:00
Tim Kuehn
372900173a Merge origin/master => tikue/filter-refactor 2019-07-15 19:04:56 -07:00
Tim Kuehn
1089415451 Make server methods more composable.
-- Connection Limits

The problem with having ConnectionFilter default-enabled is elaborated on in https://github.com/google/tarpc/issues/217. The gist of it is not all servers want a policy based on `SocketAddr`. This PR allows customizing the behavior of ConnectionFilter, at the cost of not having it enabled by default. However, enabling it is as simple as one line:

incoming.max_channels_per_key(10, ip_addr)

The second argument is a key function that takes the user-chosen transport and returns some hashable, equatable, cloneable key. In the above example, it returns an `IpAddr`.

This also allows the `Transport` trait to have the addr fns removed, which means it has become simply an alias for `Stream + Sink`.

-- Per-Channel Request Throttling

With respect to Channel's throttling behavior, the same argument applies. There isn't a one size fits all solution to throttling requests, and the policy applied by tarpc is just one of potentially many solutions. As such, `Channel` is now a trait that offers a few combinators, one of which is throttling:

channel.max_concurrent_requests(10).respond_with(serve(Server))

This functionality is also available on the existing `Handler` trait, which applies it to all incoming channels and can be used in tandem with connection limits:

incoming
    .max_channels_per_key(10, ip_addr)
    .max_concurrent_requests_per_channel(10).respond_with(serve(Server))

-- Global Request Throttling

I've entirely removed the overall request limit enforced across all channels. This functionality is easily gotten back via [`StreamExt::buffer_unordered`](https://rust-lang-nursery.github.io/futures-api-docs/0.3.0-alpha.1/futures/stream/trait.StreamExt.html#method.buffer_unordered), with the difference being that the previous behavior allowed you to spawn channels onto different threads, whereas `buffer_unordered ` means the `Channels` are handled on a single thread (the per-request handlers are still spawned). Considering the existing options, I don't believe that the benefit provided by this functionality held its own.
2019-07-15 19:01:46 -07:00
Tim Kuehn
8dbeeff0eb Fix some lint warnings 2019-07-15 18:21:11 -07:00
iovxw
85312d430c Update to futures-preview 0.3.0-alpha.17 (#238)
* Update to futures-preview 0.3.0-alpha.17

* Update feature gate

async_closure was moved out from async_await
2019-07-15 18:20:19 -07:00
Adam Wright
9843af9e00 Reflow some text in the readme (#239) 2019-07-15 17:53:56 -07:00
Tim Kuehn
a6bd423ef0 Remove use of external crate 'libtest'. 2019-07-15 17:52:27 -07:00
Kevin Ji
146496d08c README: Use the SVG Travis badge (#236) 2019-06-08 10:31:08 -07:00
Tim Kuehn
b562051c38 Bump tarpc-lib to 0.6.1 to fix request cancellation issue. 2019-05-22 01:33:00 -07:00
Tim Kuehn
fe164ca368 Fix bug where expired request wasn't propagating cancellation.
DispatchResponse was incorrectly marking itself as complete even when
expiring without receiving a response. This can cause a chain of
deleterious effects:

- Request cancellation won't propagate when request timers expire.
- Which causes client dispatch to have an inconsistent in-flight request
  map containing stale IDs.
- Which can cause clients to hang rather than exiting.
2019-05-22 01:29:01 -07:00
Artem Vorotnikov
950ad5187c Add JSON transport (#219) 2019-05-20 18:45:41 -07:00
Tim Kuehn
e6ab69c314 Keep commented-out code in each block so that rustdoc is happy. 2019-05-15 16:31:11 -07:00
Tim Kuehn
373dcbed57 Clarify dependencies required for README example
Fixes https://github.com/google/tarpc/issues/232
2019-05-15 15:40:25 -07:00
Tim Kuehn
ce9c057b1b Remove await!() macro from readme 2019-05-13 10:16:25 -07:00
Tim Kuehn
6745cee72c Bump tarpc to v0.18.0 2019-05-11 13:00:35 -07:00
Artem Vorotnikov
31abea18b3 Update to futures-preview 0.3.0-alpha.16 (#230) 2019-05-11 15:18:52 -04:00
Tim Kuehn
593ac135ce Remove stable features from doc examples 2019-04-30 13:18:39 -07:00
Tim Kuehn
05a924d27f Bump tarpc version to 0.17.0 2019-04-30 13:01:45 -07:00
Artem Vorotnikov
af9d71ed0d Bump futures to 0.3.0-alpha.15 (#226) 2019-04-28 20:13:06 -07:00
Tim Kuehn
9b90f6ae51 Bump to v0.16.0 2019-04-16 10:46:53 -07:00
Tim
bbfc8ac352 Merge pull request #216 from vorot93/futures-master
* Use upstream sink compat shims
* Port to new Sink trait introduced in e101c891f04aba34ee29c6a8cd8321563c7e0161
* rustfmt
* Port to std::task::Context
* Add Google license header to bincode-transport/src/compat.rs
* Remove compat for it is no longer needed
* future::join as freestanding function
* Simplify dependencies
* Depend on futures-preview 0.3.0-alpha.14
* Fix infinite recursion
2019-04-16 08:43:10 -07:00
Tim
ad86a967ba Fix infinite recursion 2019-04-16 18:27:42 +03:00
Artem Vorotnikov
58a0eced19 Depend on futures-preview 0.3.0-alpha.14 2019-04-15 21:16:20 +03:00
Artem Vorotnikov
46fffd13e7 Simplify dependencies 2019-04-15 21:14:25 +03:00
Artem Vorotnikov
6c8d4be462 future::join as freestanding function 2019-04-15 20:30:04 +03:00
Artem Vorotnikov
e3a517bf0d Remove compat and transmute for they are no longer needed 2019-04-15 20:24:09 +03:00
Artem Vorotnikov
f4e22bdc2e Port to std::task::Context 2019-04-15 20:22:15 +03:00
Artem Vorotnikov
46f56fbdc0 Add Google license header to bincode-transport/src/compat.rs 2019-04-15 20:22:15 +03:00
Artem Vorotnikov
8665655592 Fix test client breakage by 9100ea46f997f24d4bc8c1764d0fe3ff8226ad2a 2019-04-15 20:22:15 +03:00
Artem Vorotnikov
4569d26d81 rustfmt 2019-04-15 20:22:15 +03:00
Artem Vorotnikov
b8b92ddb5f Workaround for stack overflow caused by 2a95710db0e2d85094938776ebb4f270bc389c41 2019-04-15 20:16:48 +03:00
Artem Vorotnikov
8dd3390876 Port to new Sink trait introduced in e101c891f04aba34ee29c6a8cd8321563c7e0161 2019-04-15 20:16:48 +03:00
Artem Vorotnikov
06c420b60c Use upstream sink compat shims 2019-04-15 20:16:48 +03:00
Artem Vorotnikov
a7fb4d22cc Switch to master branch of futures-preview 2019-04-15 20:16:48 +03:00
Tim
b1cd5f34e5 Don't panic in pump_write when a client is dropped and there are more calls to poll. (#221)
This can happen in cases where a response is being read and the client isn't around.

Fixes #220
2019-04-15 09:42:53 -07:00
Artem Vorotnikov
088e5f8f2c Remove deprecated feature from bincode dependency (#218) 2019-04-04 10:34:11 -07:00
Tim Kuehn
4e0be5b626 Publish tarpc v0.15.0 2019-03-26 21:13:41 -07:00
Artem Vorotnikov
5516034bbc Use libtest crate (#213) 2019-03-24 22:29:01 -07:00
Artem Vorotnikov
06544faa5a Update to futures 0.3.0-alpha.13 (#211) 2019-02-26 09:32:41 -08:00
Tim Kuehn
39737b720a Cargo fmt 2019-01-17 10:37:16 -08:00
Tim Kuehn
0f36985440 Update for latest changes to futures.
Fixes #209.
2019-01-17 10:37:03 -08:00
Tyler Bindon
959bb691cd Update regex to match diffs output by cargo fmt. (#208)
It appears the header of the diffs output by cargo fmt have changed. It now says "Diff in /blah/blah/blah.rs at line 99:" Matching on lines starting with + or - should be more future-proof against changes to the surroundings.
2018-12-09 01:59:35 -08:00
Tim
2a3162c5fa Cargo feature 'rename-dependency' is stabilized 2018-11-21 11:03:41 -08:00
Tim Kuehn
0cc976b729 cargo fmt 2018-11-06 17:01:27 -08:00
Tim Kuehn
4d2d3f24c6 Address Clippy lints 2018-11-06 17:00:15 -08:00
Tim Kuehn
2c7c64841f Add symlink tarpc/README.md -> README.md 2018-10-29 16:11:01 -07:00
Tim Kuehn
4ea142d0f3 Remove coverage badge.
It hasn't been updated in over 2 years.
2018-10-29 11:40:09 -07:00
Tim Kuehn
00751d2518 external_doc doesn't work with crates.io yet :( 2018-10-29 11:05:09 -07:00
Tim Kuehn
4394a52b65 Add doc tests to .travis.yml 2018-10-29 10:55:12 -07:00
Tim Kuehn
70938501d7 Use eternal_doc for tarpc package. This will ensure our README is always up-to-date. 2018-10-29 10:53:34 -07:00
Tim Kuehn
d5f5cf4300 Bump versions. 2018-10-29 10:43:41 -07:00
Tim Kuehn
e2c4164d8c Remove unused feature enablements from tarpc 2018-10-25 11:44:38 -07:00
Tim Kuehn
78124ef7a8 Cargo fmt 2018-10-25 11:44:18 -07:00
Tim Kuehn
096d354b7e Remove unused features 2018-10-25 11:41:08 -07:00
Tim
7ad0e4b070 Service registry (#204)
# Changes

## Client is now a trait
And `Channel<Req, Resp>` implements `Client<Req, Resp>`. Previously, `Client<Req, Resp>` was a thin wrapper around `Channel<Req, Resp>`.

This was changed to allow for mapping the request and response types. For example, you can take a `channel: Channel<Req, Resp>` and do:

```rust
channel
    .with_request(|req: Req2| -> Req { ... })
    .map_response(|resp: Resp| -> Resp2 { ... })
```

...which returns a type that implements `Client<Req2, Resp2>`.

### Why would you want to map request and response types?

The main benefit of this is that it enables creating different client types backed by the same channel. For example, you could run multiple clients multiplexing requests over a single `TcpStream`. I have a demo in `tarpc/examples/service_registry.rs` showing how you might do this with a bincode transport. I am considering factoring out the service registry portion of that to an actual library, because it's doing pretty cool stuff. For this PR, though, it'll just be part of the example.

## Client::new is now client::new

This is pretty minor, but necessary because async fns can't currently exist on traits. I changed `Server::new` to match this as well.

## Macro-generated Clients are generic over the backing Client.

This is a natural consequence of the above change. However, it is transparent to the user by keeping `Channel<Req, Resp>` as the default type for the `<C: Client>` type parameter. `new_stub` returns `Client<Channel<Req, Resp>>`, and other clients can be created via the `From` trait.

## example-service/ now has two binaries, one for client and one for server.

This serves as a "realistic" example of how one might set up a service. The other examples all run the client and server in the same binary, which isn't realistic in distributed systems use cases.

## `service!` trait fns take self by value.

Services are already cloned per request, so this just passes on that flexibility to the trait implementers.

# Open Questions

In the service registry example, multiple services are running on a single port, and thus multiple clients are sending requests over a single `TcpStream`. This has implications for throttling: [`max_in_flight_requests_per_connection`](https://github.com/google/tarpc/blob/master/rpc/src/server/mod.rs#L57-L60) will set a maximum for the sum of requests for all clients sharing a single connection. I think this is reasonable behavior, but users may expect this setting to act like `max_in_flight_requests_per_client`.

Fixes #103 #153 #205
2018-10-25 11:22:55 -07:00
Tim
64755d5329 Update futures 2018-10-19 11:19:25 -07:00
Tim Kuehn
3071422132 Helper fn to create transports 2018-10-18 00:24:26 -07:00
Tim Kuehn
8847330dbe impl From<S> for bincode::Transport<S> 2018-10-18 00:24:08 -07:00
Tim Kuehn
6d396520f4 Don't allow empty service invocations 2018-10-18 00:23:34 -07:00
Tim Kuehn
79a2f7fe2f Replace tokio-serde-bincode with async-bincode 2018-10-17 20:24:31 -07:00
Tim Kuehn
af66841f68 Remove keyword 2018-10-17 11:59:09 -07:00
Tim
1ab4cfdff9 Make Request and Resonse enums' docs public, because they show up in the serve fn. 2018-10-16 23:02:52 -07:00
Tim
f7e03eeeb7 Fix up readme 2018-10-16 22:28:57 -07:00
60 changed files with 5048 additions and 4113 deletions

66
.github/workflows/main.yml vendored Normal file
View File

@@ -0,0 +1,66 @@
on: [push, pull_request]
name: Continuous integration
jobs:
check:
name: Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: check
args: --all-features
test:
name: Test Suite
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: test
args: --all-features
fmt:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add rustfmt
- uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
clippy:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add clippy
- uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-features -- -D warnings

View File

@@ -1,12 +0,0 @@
language: rust
rust:
- nightly
sudo: false
cache: cargo
os:
- osx
- linux
script:
- cargo test --all --all-features

View File

@@ -2,9 +2,6 @@
members = [
"example-service",
"rpc",
"trace",
"bincode-transport",
"tarpc",
"plugins",
]

164
README.md
View File

@@ -1,9 +1,18 @@
## tarpc: Tim & Adam's RPC lib
[![Travis-CI Status](https://travis-ci.org/google/tarpc.png?branch=master)](https://travis-ci.org/google/tarpc)
[![Coverage Status](https://coveralls.io/repos/github/google/tarpc/badge.svg?branch=master)](https://coveralls.io/github/google/tarpc?branch=master)
[![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE)
[![Latest Version](https://img.shields.io/crates/v/tarpc.svg)](https://crates.io/crates/tarpc)
[![Join the chat at https://gitter.im/tarpc/Lobby](https://badges.gitter.im/tarpc/Lobby.svg)](https://gitter.im/tarpc/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Crates.io][crates-badge]][crates-url]
[![MIT licensed][mit-badge]][mit-url]
[![Build status][gh-actions-badge]][gh-actions-url]
[![Discord chat][discord-badge]][discord-url]
[crates-badge]: https://img.shields.io/crates/v/tarpc.svg
[crates-url]: https://crates.io/crates/tarpc
[mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg
[mit-url]: LICENSE
[gh-actions-badge]: https://github.com/google/tarpc/workflows/Continuous%20integration/badge.svg
[gh-actions-url]: https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22
[discord-badge]: https://img.shields.io/discord/647529123996237854.svg?logo=discord&style=flat-square
[discord-url]: https://discord.gg/gXwpdSt
# tarpc
*Disclaimer*: This is not an official Google product.
@@ -13,7 +22,7 @@ writing a server is taken care of for you.
[Documentation](https://docs.rs/crate/tarpc/)
## What is an RPC framework?
### What is an RPC framework?
"RPC" stands for "Remote Procedure Call," a function call where the work of
producing the return value is being done somewhere else. When an rpc function is
invoked, behind the scenes the function contacts some other process somewhere
@@ -26,126 +35,129 @@ architectures. Two well-known ones are [gRPC](http://www.grpc.io) and
tarpc differentiates itself from other RPC frameworks by defining the schema in code,
rather than in a separate language such as .proto. This means there's no separate compilation
process, and no cognitive context switching between different languages. Additionally, it
works with the community-backed library serde: any serde-serializable type can be used as
arguments to tarpc fns.
process, and no context switching between different languages.
## Usage
**NB**: *this example is for master. Are you looking for other
[versions](https://docs.rs/tarpc)?*
Some other features of tarpc:
- Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
used as a transport to connect the client and server.
- `Send` optional: if the transport doesn't require it, neither does tarpc!
- Cascading cancellation: dropping a request will send a cancellation message to the server.
The server will cease any unfinished work on the request, subsequently cancelling any of its
own requests, repeating for the entire chain of transitive dependencies.
- Configurable deadlines and deadline propagation: request deadlines default to 10s if
unspecified. The server will automatically cease work when the deadline has passed. Any
requests sent by the server that use the request context will propagate the request deadline.
For example, if a server is handling a request with a 10s deadline, does 2s of work, then
sends a request to another server, that server will see an 8s deadline.
- Serde serialization: enabling the `serde1` Cargo feature will make service requests and
responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
### Usage
Add to your `Cargo.toml` dependencies:
```toml
tarpc = "0.12.0"
tarpc-plugins = "0.4.0"
tarpc = { version = "0.21.0", features = ["full"] }
```
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
These generated types make it easy and ergonomic to write servers with less boilerplate.
Simply implement the generated service trait, and you're off to the races!
The `service!` macro expands to a collection of items that form an
rpc service. In the above example, the macro is called within the
`hello_service` module. This module will contain a `Client` stub and `Service` trait. There is
These generated types make it easy and ergonomic to write servers without dealing with serialization
directly. Simply implement one of the generated traits, and you're off to the
races!
### Example
## Example:
For this example, in addition to tarpc, also add two other dependencies to
your `Cargo.toml`:
Here's a small service.
```toml
futures = "0.3"
tokio = "0.2"
```
In the following example, we use an in-process channel for communication between
client and server. In real code, you will likely communicate over the network.
For a more real-world example, see [example-service](example-service).
First, let's set up the dependencies and service definition.
```rust
#![feature(plugin, futures_api, pin, arbitrary_self_types, await_macro, async_await)]
#![plugin(tarpc_plugins)]
use futures::{
compat::TokioDefaultSpawner,
future::{self, Ready},
prelude::*,
spawn,
};
use tarpc::rpc::{
use tarpc::{
client, context,
server::{self, Handler, Server},
server::{self, Handler},
};
use std::io;
// This is the service definition. It looks a lot like a trait definition.
// It defines one RPC, hello, which takes one arg, name, and returns a String.
tarpc::service! {
rpc hello(name: String) -> String;
#[tarpc::service]
trait World {
/// Returns a greeting for name.
async fn hello(name: String) -> String;
}
```
// This is the type that implements the generated Service trait. It is the business logic
This service definition generates a trait called `World`. Next we need to
implement it for our Server struct.
```rust
// This is the type that implements the generated World trait. It is the business logic
// and is used to start the server.
#[derive(Clone)]
struct HelloServer;
impl Service for HelloServer {
impl World for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
// an associated type representing the future output by the fn.
type HelloFut = Ready<String>;
fn hello(&self, _: context::Context, name: String) -> Self::HelloFut {
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
}
}
```
async fn run() -> io::Result<()> {
// bincode_transport is provided by the associated crate bincode-transport. It makes it easy
// to start up a serde-powered bincode serialization strategy over TCP.
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = transport.local_addr();
Lastly let's write our `main` that will start the server. While this example uses an
[in-process
channel](https://docs.rs/tarpc/0.18.0/tarpc/transport/channel/struct.UnboundedChannel.html),
tarpc also ships bincode and JSON
tokio-net based TCP transports that are generic over all serializable types.
// The server is configured with the defaults.
let server = Server::new(server::Config::default())
// Server can listen on any type that implements the Transport trait.
.incoming(transport)
// Close the stream after the client connects
.take(1)
// serve is generated by the service! macro. It takes as input any type implementing
// the generated Service trait.
.respond_with(serve(HelloServer));
```rust
#[tokio::main]
async fn main() -> io::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
spawn!(server).unwrap();
let server = server::new(server::Config::default())
// incoming() takes a stream of transports such as would be returned by
// TcpListener::incoming (but a stream instead of an iterator).
.incoming(stream::once(future::ready(server_transport)))
.respond_with(HelloServer.serve());
let transport = await!(bincode_transport::connect(&addr))?;
tokio::spawn(server);
// new_stub is generated by the service! macro. Like Server, it takes a config and any
// Transport as input, and returns a Client, also generated by the macro.
// by the service mcro.
let mut client = await!(new_stub(client::Config::default(), transport));
// WorldClient is generated by the macro. It has a constructor `new` that takes a config and
// any Transport as input
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
// The client has an RPC method for each RPC defined in service!. It takes the same args
// as defined, with the addition of a Context, which is always the first arg. The Context
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = await!(client.hello(context::current(), "Stim".to_string()))?;
let hello = client.hello(context::current(), "Stim".to_string()).await?;
println!("{}", hello);
Ok(())
}
fn main() {
tokio::run(run()
.map_err(|e| eprintln!("Oh no: {}", e))
.boxed()
.compat(TokioDefaultSpawner),
);
}
```
## Service Documentation
### Service Documentation
Use `cargo doc` as you normally would to see the documentation created for all
items expanded by a `service!` invocation.
## Contributing
To contribute to tarpc, please see [CONTRIBUTING](CONTRIBUTING.md).
## License
tarpc is distributed under the terms of the MIT license.
See [LICENSE](LICENSE) for details.
License: MIT

View File

@@ -1,3 +1,36 @@
## 0.21.0 (2020-06-26)
### New Features
A new proc macro, `#[tarpc::server]` was added! This enables service impls to elide the boilerplate
of specifying associated types for each RPC. With the ubiquity of async-await, most code won't have
nameable futures and will just be boxing the return type anyway. This macro does that for you.
### Breaking Changes
- Enums had _non_exhaustive fields replaced with the #[non_exhaustive] attribute.
### Bug Fixes
- https://github.com/google/tarpc/issues/304
A race condition in code that limits number of connections per client caused occasional panics.
- https://github.com/google/tarpc/pull/295
Made request timeouts account for time spent in the outbound buffer. Previously, a large outbound
queue would lead to requests not timing out correctly.
## 0.20.0 (2019-12-11)
### Breaking Changes
1. tarpc has updated its tokio dependency to the latest 0.2 version.
2. The tarpc crates have been unified into just `tarpc`, with new Cargo features to enable
functionality.
- The bincode-transport and json-transport crates are deprecated and superseded by
the `serde_transport` module, which unifies much of the logic present in both crates.
## 0.13.0 (2018-10-16)
### Breaking Changes

View File

@@ -1,43 +0,0 @@
cargo-features = ["rename-dependency"]
[package]
name = "tarpc-bincode-transport"
version = "0.1.0"
authors = ["Tim Kuehn <tikue@google.com>"]
edition = '2018'
license = "MIT"
documentation = "https://docs.rs/tarpc-bincode-transport"
homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "bincode", "serde", "tarpc"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
description = "A bincode-based transport for tarpc services."
[dependencies]
bincode = { version = "1.0", features = ["i128"] }
bytes = "0.4"
futures_legacy = { version = "0.1", package = "futures" }
pin-utils = "0.1.0-alpha.2"
rpc = { package = "tarpc-lib", version = "0.1", path = "../rpc", features = ["serde1"] }
serde = "1.0"
tokio = "0.1"
tokio-io = "0.1"
tokio-serde-bincode = "0.1"
tokio-tcp = "0.1"
tokio-serde = "0.2"
[target.'cfg(not(test))'.dependencies]
futures-preview = { version = "0.3.0-alpha.8", features = ["compat"] }
[dev-dependencies]
futures-preview = { version = "0.3.0-alpha.8", features = ["compat", "tokio-compat"] }
env_logger = "0.5"
humantime = "1.0"
log = "0.4"
rand = "0.5"
tokio = "0.1"
tokio-executor = "0.1"
tokio-reactor = "0.1"
tokio-serde = "0.2"
tokio-timer = "0.2"

View File

@@ -1 +0,0 @@
edition = "Edition2018"

View File

@@ -1,286 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! A TCP [`Transport`] that serializes as bincode.
#![feature(
futures_api,
pin,
arbitrary_self_types,
underscore_imports,
await_macro,
async_await,
)]
#![deny(missing_docs, missing_debug_implementations)]
mod vendored;
use bytes::{Bytes, BytesMut};
use crate::vendored::tokio_serde_bincode::{IoErrorWrapper, ReadBincode, WriteBincode};
use futures::{
Poll,
compat::{Compat01As03, Future01CompatExt, Stream01CompatExt},
prelude::*,
ready, task,
};
use futures_legacy::{
executor::{
self as executor01, Notify as Notify01, NotifyHandle as NotifyHandle01,
UnsafeNotify as UnsafeNotify01,
},
sink::SinkMapErr as SinkMapErr01,
sink::With as With01,
stream::MapErr as MapErr01,
Async as Async01, AsyncSink as AsyncSink01, Sink as Sink01, Stream as Stream01,
};
use pin_utils::unsafe_pinned;
use serde::{Deserialize, Serialize};
use std::{fmt, io, marker::PhantomData, net::SocketAddr, pin::Pin, task::LocalWaker};
use tokio::codec::{Framed, LengthDelimitedCodec, length_delimited};
use tokio_tcp::{self, TcpListener, TcpStream};
/// Returns a new bincode transport that reads from and writes to `io`.
pub fn new<Item, SinkItem>(io: TcpStream) -> Transport<Item, SinkItem>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
let peer_addr = io.peer_addr();
let local_addr = io.local_addr();
let inner = length_delimited::Builder::new()
.max_frame_length(8_000_000)
.new_framed(io)
.map_err(IoErrorWrapper as _)
.sink_map_err(IoErrorWrapper as _)
.with(freeze as _);
let inner = WriteBincode::new(inner);
let inner = ReadBincode::new(inner);
Transport {
inner,
staged_item: None,
peer_addr,
local_addr,
}
}
fn freeze(bytes: BytesMut) -> Result<Bytes, IoErrorWrapper> {
Ok(bytes.freeze())
}
/// Connects to `addr`, wrapping the connection in a bincode transport.
pub async fn connect<Item, SinkItem>(addr: &SocketAddr) -> io::Result<Transport<Item, SinkItem>>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
let stream = await!(TcpStream::connect(addr).compat())?;
Ok(new(stream))
}
/// Listens on `addr`, wrapping accepted connections in bincode transports.
pub fn listen<Item, SinkItem>(addr: &SocketAddr) -> io::Result<Incoming<Item, SinkItem>>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
let listener = TcpListener::bind(addr)?;
let local_addr = listener.local_addr()?;
let incoming = listener.incoming().compat();
Ok(Incoming {
incoming,
local_addr,
ghost: PhantomData,
})
}
/// A [`TcpListener`] that wraps connections in bincode transports.
#[derive(Debug)]
pub struct Incoming<Item, SinkItem> {
incoming: Compat01As03<tokio_tcp::Incoming>,
local_addr: SocketAddr,
ghost: PhantomData<(Item, SinkItem)>,
}
impl<Item, SinkItem> Incoming<Item, SinkItem> {
unsafe_pinned!(incoming: Compat01As03<tokio_tcp::Incoming>);
/// Returns the address being listened on.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
}
impl<Item, SinkItem> Stream for Incoming<Item, SinkItem>
where
Item: for<'a> Deserialize<'a>,
SinkItem: Serialize,
{
type Item = io::Result<Transport<Item, SinkItem>>;
fn poll_next(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<Self::Item>> {
let next = ready!(self.incoming().poll_next(waker)?);
Poll::Ready(next.map(|conn| Ok(new(conn))))
}
}
/// A transport that serializes to, and deserializes from, a [`TcpStream`].
pub struct Transport<Item, SinkItem> {
inner: ReadBincode<
WriteBincode<
With01<
SinkMapErr01<
MapErr01<
Framed<tokio_tcp::TcpStream, LengthDelimitedCodec>,
fn(std::io::Error) -> IoErrorWrapper,
>,
fn(std::io::Error) -> IoErrorWrapper,
>,
BytesMut,
fn(BytesMut) -> Result<Bytes, IoErrorWrapper>,
Result<Bytes, IoErrorWrapper>
>,
SinkItem,
>,
Item,
>,
staged_item: Option<SinkItem>,
peer_addr: io::Result<SocketAddr>,
local_addr: io::Result<SocketAddr>,
}
impl<Item, SinkItem> fmt::Debug for Transport<Item, SinkItem> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Transport")
}
}
impl<Item, SinkItem> Stream for Transport<Item, SinkItem>
where
Item: for<'a> Deserialize<'a>,
{
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<io::Result<Item>>> {
unsafe {
let inner = &mut Pin::get_mut_unchecked(self).inner;
let mut compat = inner.compat();
let compat = Pin::new_unchecked(&mut compat);
match ready!(compat.poll_next(waker)) {
None => Poll::Ready(None),
Some(Ok(next)) => Poll::Ready(Some(Ok(next))),
Some(Err(e)) => Poll::Ready(Some(Err(e.0))),
}
}
}
}
impl<Item, SinkItem> Sink for Transport<Item, SinkItem>
where
SinkItem: Serialize,
{
type SinkItem = SinkItem;
type SinkError = io::Error;
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
let me = unsafe { Pin::get_mut_unchecked(self) };
assert!(me.staged_item.is_none());
me.staged_item = Some(item);
Ok(())
}
fn poll_ready(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<()>> {
let notify = &WakerToHandle(waker);
executor01::with_notify(notify, 0, move || {
let me = unsafe { Pin::get_mut_unchecked(self) };
match me.staged_item.take() {
Some(staged_item) => match me.inner.start_send(staged_item)? {
AsyncSink01::Ready => Poll::Ready(Ok(())),
AsyncSink01::NotReady(item) => {
me.staged_item = Some(item);
Poll::Pending
}
},
None => Poll::Ready(Ok(())),
}
})
}
fn poll_flush(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<()>> {
let notify = &WakerToHandle(waker);
executor01::with_notify(notify, 0, move || {
let me = unsafe { Pin::get_mut_unchecked(self) };
match me.inner.poll_complete()? {
Async01::Ready(()) => Poll::Ready(Ok(())),
Async01::NotReady => Poll::Pending,
}
})
}
fn poll_close(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<()>> {
let notify = &WakerToHandle(waker);
executor01::with_notify(notify, 0, move || {
let me = unsafe { Pin::get_mut_unchecked(self) };
match me.inner.get_mut().close()? {
Async01::Ready(()) => Poll::Ready(Ok(())),
Async01::NotReady => Poll::Pending,
}
})
}
}
impl<Item, SinkItem> rpc::Transport for Transport<Item, SinkItem>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
type Item = Item;
type SinkItem = SinkItem;
fn peer_addr(&self) -> io::Result<SocketAddr> {
// TODO: should just access from the inner transport.
// https://github.com/alexcrichton/tokio-serde-bincode/issues/4
Ok(*self.peer_addr.as_ref().unwrap())
}
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(*self.local_addr.as_ref().unwrap())
}
}
#[derive(Clone, Debug)]
struct WakerToHandle<'a>(&'a LocalWaker);
#[derive(Debug)]
struct NotifyWaker(task::Waker);
impl Notify01 for NotifyWaker {
fn notify(&self, _: usize) {
self.0.wake();
}
}
unsafe impl UnsafeNotify01 for NotifyWaker {
unsafe fn clone_raw(&self) -> NotifyHandle01 {
let ptr = Box::new(NotifyWaker(self.0.clone()));
NotifyHandle01::new(Box::into_raw(ptr))
}
unsafe fn drop_raw(&self) {
let ptr: *const dyn UnsafeNotify01 = self;
drop(Box::from_raw(ptr as *mut dyn UnsafeNotify01));
}
}
impl<'a> From<WakerToHandle<'a>> for NotifyHandle01 {
fn from(handle: WakerToHandle<'a>) -> NotifyHandle01 {
unsafe { NotifyWaker(handle.0.clone().into_waker()).clone_raw() }
}
}

View File

@@ -1,7 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
pub(crate) mod tokio_serde_bincode;

View File

@@ -1,224 +0,0 @@
//! `Stream` and `Sink` adaptors for serializing and deserializing values using
//! Bincode.
//!
//! This crate provides adaptors for going from a stream or sink of buffers
//! ([`Bytes`]) to a stream or sink of values by performing Bincode encoding or
//! decoding. It is expected that each yielded buffer contains a single
//! serialized Bincode value. The specific strategy by which this is done is left
//! up to the user. One option is to use using [`length_delimited`] from
//! [tokio-io].
//!
//! [`Bytes`]: https://docs.rs/bytes/0.4/bytes/struct.Bytes.html
//! [`length_delimited`]: http://alexcrichton.com/tokio-io/tokio_io/codec/length_delimited/index.html
//! [tokio-io]: http://github.com/alexcrichton/tokio-io
//! [examples]: https://github.com/carllerche/tokio-serde-json/tree/master/examples
#![allow(missing_debug_implementations)]
use bincode::Error;
use bytes::{Bytes, BytesMut};
use futures_legacy::{Poll, Sink, StartSend, Stream};
use serde::{Deserialize, Serialize};
use std::io;
use tokio_serde::{Deserializer, FramedRead, FramedWrite, Serializer};
use std::marker::PhantomData;
/// Adapts a stream of Bincode encoded buffers to a stream of values by
/// deserializing them.
///
/// `ReadBincode` implements `Stream` by polling the inner buffer stream and
/// deserializing the buffer as Bincode. It expects that each yielded buffer
/// represents a single Bincode value and does not contain any extra trailing
/// bytes.
pub(crate) struct ReadBincode<T, U> {
inner: FramedRead<T, U, Bincode<U>>,
}
/// Adapts a buffer sink to a value sink by serializing the values as Bincode.
///
/// `WriteBincode` implements `Sink` by serializing the submitted values to a
/// buffer. The buffer is then sent to the inner stream, which is responsible
/// for handling framing on the wire.
pub(crate) struct WriteBincode<T: Sink, U> {
inner: FramedWrite<T, U, Bincode<U>>,
}
struct Bincode<T> {
ghost: PhantomData<T>,
}
impl<T, U> ReadBincode<T, U>
where
T: Stream<Error = IoErrorWrapper>,
U: for<'de> Deserialize<'de>,
Bytes: From<T::Item>,
{
/// Creates a new `ReadBincode` with the given buffer stream.
pub fn new(inner: T) -> ReadBincode<T, U> {
let json = Bincode { ghost: PhantomData };
ReadBincode {
inner: FramedRead::new(inner, json),
}
}
}
impl<T, U> ReadBincode<T, U> {
/// Returns a mutable reference to the underlying stream wrapped by
/// `ReadBincode`.
///
/// Note that care should be taken to not tamper with the underlying stream
/// of data coming in as it may corrupt the stream of frames otherwise
/// being worked with.
pub fn get_mut(&mut self) -> &mut T {
self.inner.get_mut()
}
}
impl<T, U> Stream for ReadBincode<T, U>
where
T: Stream<Error = IoErrorWrapper>,
U: for<'de> Deserialize<'de>,
Bytes: From<T::Item>,
{
type Item = U;
type Error = <T as Stream>::Error;
fn poll(&mut self) -> Poll<Option<U>, Self::Error> {
self.inner.poll()
}
}
impl<T, U> Sink for ReadBincode<T, U>
where
T: Sink,
{
type SinkItem = T::SinkItem;
type SinkError = T::SinkError;
fn start_send(&mut self, item: T::SinkItem) -> StartSend<T::SinkItem, T::SinkError> {
self.get_mut().start_send(item)
}
fn poll_complete(&mut self) -> Poll<(), T::SinkError> {
self.get_mut().poll_complete()
}
fn close(&mut self) -> Poll<(), T::SinkError> {
self.get_mut().close()
}
}
pub(crate) struct IoErrorWrapper(pub io::Error);
impl From<Box<bincode::ErrorKind>> for IoErrorWrapper {
fn from(e: Box<bincode::ErrorKind>) -> Self {
IoErrorWrapper(match *e {
bincode::ErrorKind::Io(e) => e,
bincode::ErrorKind::InvalidUtf8Encoding(e) => {
io::Error::new(io::ErrorKind::InvalidInput, e)
}
bincode::ErrorKind::InvalidBoolEncoding(e) => {
io::Error::new(io::ErrorKind::InvalidInput, e.to_string())
}
bincode::ErrorKind::InvalidTagEncoding(e) => {
io::Error::new(io::ErrorKind::InvalidInput, e.to_string())
}
bincode::ErrorKind::InvalidCharEncoding => {
io::Error::new(io::ErrorKind::InvalidInput, "Invalid char encoding")
}
bincode::ErrorKind::DeserializeAnyNotSupported => {
io::Error::new(io::ErrorKind::InvalidInput, "Deserialize Any not supported")
}
bincode::ErrorKind::SizeLimit => {
io::Error::new(io::ErrorKind::InvalidInput, "Size limit exceeded")
}
bincode::ErrorKind::SequenceMustHaveLength => {
io::Error::new(io::ErrorKind::InvalidInput, "Sequence must have length")
}
bincode::ErrorKind::Custom(s) => io::Error::new(io::ErrorKind::Other, s),
})
}
}
impl From<IoErrorWrapper> for io::Error {
fn from(wrapper: IoErrorWrapper) -> io::Error {
wrapper.0
}
}
impl<T, U> WriteBincode<T, U>
where
T: Sink<SinkItem = BytesMut, SinkError = IoErrorWrapper>,
U: Serialize,
{
/// Creates a new `WriteBincode` with the given buffer sink.
pub fn new(inner: T) -> WriteBincode<T, U> {
let json = Bincode { ghost: PhantomData };
WriteBincode {
inner: FramedWrite::new(inner, json),
}
}
}
impl<T: Sink, U> WriteBincode<T, U> {
/// Returns a mutable reference to the underlying sink wrapped by
/// `WriteBincode`.
///
/// Note that care should be taken to not tamper with the underlying sink as
/// it may corrupt the sequence of frames otherwise being worked with.
pub fn get_mut(&mut self) -> &mut T {
self.inner.get_mut()
}
}
impl<T, U> Sink for WriteBincode<T, U>
where
T: Sink<SinkItem = BytesMut, SinkError = IoErrorWrapper>,
U: Serialize,
{
type SinkItem = U;
type SinkError = <T as Sink>::SinkError;
fn start_send(&mut self, item: U) -> StartSend<U, Self::SinkError> {
self.inner.start_send(item)
}
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
self.inner.poll_complete()
}
fn close(&mut self) -> Poll<(), Self::SinkError> {
self.inner.poll_complete()
}
}
impl<T, U> Stream for WriteBincode<T, U>
where
T: Stream + Sink,
{
type Item = T::Item;
type Error = T::Error;
fn poll(&mut self) -> Poll<Option<T::Item>, T::Error> {
self.get_mut().poll()
}
}
impl<T> Deserializer<T> for Bincode<T>
where
T: for<'de> Deserialize<'de>,
{
type Error = Error;
fn deserialize(&mut self, src: &Bytes) -> Result<T, Error> {
bincode::deserialize(src)
}
}
impl<T: Serialize> Serializer<T> for Bincode<T> {
type Error = Error;
fn serialize(&mut self, item: &T) -> Result<BytesMut, Self::Error> {
bincode::serialize(item).map(Into::into)
}
}

View File

@@ -1,116 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Tests client/server control flow.
#![feature(
test,
integer_atomics,
futures_api,
generators,
await_macro,
async_await
)]
extern crate test;
use self::test::stats::Stats;
use futures::{compat::TokioDefaultSpawner, prelude::*};
use rpc::{
client::{self, Client},
context,
server::{self, Handler, Server},
};
use std::{
io,
time::{Duration, Instant},
};
async fn bench() -> io::Result<()> {
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
tokio_executor::spawn(
Server::<u32, u32>::new(server::Config::default())
.incoming(listener)
.take(1)
.respond_with(|_ctx, request| futures::future::ready(Ok(request)))
.unit_error()
.boxed()
.compat()
);
let conn = await!(tarpc_bincode_transport::connect(&addr))?;
let client = &mut await!(Client::<u32, u32>::new(client::Config::default(), conn))?;
let total = 10_000usize;
let mut successful = 0u32;
let mut unsuccessful = 0u32;
let mut durations = vec![];
for _ in 1..=total {
let now = Instant::now();
let response = await!(client.call(context::current(), 0u32));
let elapsed = now.elapsed();
match response {
Ok(_) => successful += 1,
Err(_) => unsuccessful += 1,
};
durations.push(elapsed);
}
let durations_nanos = durations
.iter()
.map(|duration| duration.as_secs() as f64 * 1E9 + duration.subsec_nanos() as f64)
.collect::<Vec<_>>();
let (lower, median, upper) = durations_nanos.quartiles();
println!("Of {} runs:", durations_nanos.len());
println!("\tSuccessful: {}", successful);
println!("\tUnsuccessful: {}", unsuccessful);
println!(
"\tMean: {:?}",
Duration::from_nanos(durations_nanos.mean() as u64)
);
println!("\tMedian: {:?}", Duration::from_nanos(median as u64));
println!(
"\tStd Dev: {:?}",
Duration::from_nanos(durations_nanos.std_dev() as u64)
);
println!(
"\tMin: {:?}",
Duration::from_nanos(durations_nanos.min() as u64)
);
println!(
"\tMax: {:?}",
Duration::from_nanos(durations_nanos.max() as u64)
);
println!(
"\tQuartiles: ({:?}, {:?}, {:?})",
Duration::from_nanos(lower as u64),
Duration::from_nanos(median as u64),
Duration::from_nanos(upper as u64)
);
Ok(())
}
#[test]
fn bench_small_packet() -> io::Result<()> {
env_logger::init();
rpc::init(TokioDefaultSpawner);
tokio::run(
bench()
.map_err(|e| panic!(e.to_string()))
.boxed()
.compat(),
);
println!("done");
Ok(())
}

View File

@@ -1,152 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Tests client/server control flow.
#![feature(generators, await_macro, async_await, futures_api,)]
use futures::{
compat::{Future01CompatExt, TokioDefaultSpawner},
prelude::*,
stream,
};
use log::{info, trace};
use rand::distributions::{Distribution, Normal};
use rpc::{
client::{self, Client},
context,
server::{self, Server},
};
use std::{
io,
time::{Duration, Instant, SystemTime},
};
use tokio::timer::Delay;
pub trait AsDuration {
/// Delay of 0 if self is in the past
fn as_duration(&self) -> Duration;
}
impl AsDuration for SystemTime {
fn as_duration(&self) -> Duration {
self.duration_since(SystemTime::now()).unwrap_or_default()
}
}
async fn run() -> io::Result<()> {
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let server = Server::<String, String>::new(server::Config::default())
.incoming(listener)
.take(1)
.for_each(async move |channel| {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let handler = channel.respond_with(move |ctx, request| {
// Sleep for a time sampled from a normal distribution with:
// - mean: 1/2 the deadline.
// - std dev: 1/2 the deadline.
let deadline: Duration = ctx.deadline.as_duration();
let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64;
let distribution =
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.);
let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.);
let delay = Duration::from_millis(delay_millis as u64);
trace!(
"[{}/{}] Responding to request in {:?}.",
ctx.trace_id(),
client_addr,
delay,
);
let wait = Delay::new(Instant::now() + delay).compat();
async move {
await!(wait).unwrap();
Ok(request)
}
});
tokio_executor::spawn(handler.unit_error().boxed().compat());
});
tokio_executor::spawn(server.unit_error().boxed().compat());
let conn = await!(tarpc_bincode_transport::connect(&addr))?;
let client = await!(Client::<String, String>::new(
client::Config::default(),
conn
))?;
// Proxy service
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let proxy_server = Server::<String, String>::new(server::Config::default())
.incoming(listener)
.take(1)
.for_each(move |channel| {
let client = client.clone();
async move {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let handler = channel.respond_with(move |ctx, request| {
trace!("[{}/{}] Proxying request.", ctx.trace_id(), client_addr);
let mut client = client.clone();
async move { await!(client.call(ctx, request)) }
});
tokio_executor::spawn(handler.unit_error().boxed().compat());
}
});
tokio_executor::spawn(proxy_server.unit_error().boxed().compat());
let mut config = client::Config::default();
config.max_in_flight_requests = 10;
config.pending_request_buffer = 10;
let client = await!(Client::<String, String>::new(
config,
await!(tarpc_bincode_transport::connect(&addr))?
))?;
// Make 3 speculative requests, returning only the quickest.
let mut clients: Vec<_> = (1..=3u32).map(|_| client.clone()).collect();
let mut requests = vec![];
for client in &mut clients {
let mut ctx = context::current();
ctx.deadline = SystemTime::now() + Duration::from_millis(200);
let trace_id = *ctx.trace_id();
let response = client.call(ctx, "ping".into());
requests.push(response.map(move |r| (trace_id, r)));
}
let (fastest_response, _) = await!(stream::futures_unordered(requests).into_future());
let (trace_id, resp) = fastest_response.unwrap();
info!("[{}] fastest_response = {:?}", trace_id, resp);
Ok::<_, io::Error>(())
}
#[test]
fn cancel_slower() -> io::Result<()> {
env_logger::init();
rpc::init(TokioDefaultSpawner);
tokio::run(
run()
.boxed()
.map_err(|e| panic!(e))
.compat(),
);
Ok(())
}

View File

@@ -1,120 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Tests client/server control flow.
#![feature(generators, await_macro, async_await, futures_api,)]
use futures::{
compat::{Future01CompatExt, TokioDefaultSpawner},
prelude::*,
};
use log::{error, info, trace};
use rand::distributions::{Distribution, Normal};
use rpc::{
client::{self, Client},
context,
server::{self, Server},
};
use std::{
io,
time::{Duration, Instant, SystemTime},
};
use tokio::timer::Delay;
pub trait AsDuration {
/// Delay of 0 if self is in the past
fn as_duration(&self) -> Duration;
}
impl AsDuration for SystemTime {
fn as_duration(&self) -> Duration {
self.duration_since(SystemTime::now()).unwrap_or_default()
}
}
async fn run() -> io::Result<()> {
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let server = Server::<String, String>::new(server::Config::default())
.incoming(listener)
.take(1)
.for_each(async move |channel| {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let handler = channel.respond_with(move |ctx, request| {
// Sleep for a time sampled from a normal distribution with:
// - mean: 1/2 the deadline.
// - std dev: 1/2 the deadline.
let deadline: Duration = ctx.deadline.as_duration();
let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64;
let distribution =
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.);
let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.);
let delay = Duration::from_millis(delay_millis as u64);
trace!(
"[{}/{}] Responding to request in {:?}.",
ctx.trace_id(),
client_addr,
delay,
);
let sleep = Delay::new(Instant::now() + delay).compat();
async {
await!(sleep).unwrap();
Ok(request)
}
});
tokio_executor::spawn(handler.unit_error().boxed().compat());
});
tokio_executor::spawn(server.unit_error().boxed().compat());
let mut config = client::Config::default();
config.max_in_flight_requests = 10;
config.pending_request_buffer = 10;
let conn = await!(tarpc_bincode_transport::connect(&addr))?;
let client = await!(Client::<String, String>::new(config, conn))?;
let clients = (1..=100u32).map(|_| client.clone()).collect::<Vec<_>>();
for mut client in clients {
let ctx = context::current();
tokio_executor::spawn(
async move {
let trace_id = *ctx.trace_id();
let response = client.call(ctx, "ping".into());
match await!(response) {
Ok(response) => info!("[{}] response: {}", trace_id, response),
Err(e) => error!("[{}] request error: {:?}: {}", trace_id, e.kind(), e),
}
}.unit_error().boxed().compat()
);
}
Ok(())
}
#[test]
fn ping_pong() -> io::Result<()> {
env_logger::init();
rpc::init(TokioDefaultSpawner);
tokio::run(
run()
.map_ok(|_| println!("done"))
.map_err(|e| panic!(e.to_string()))
.boxed()
.compat(),
);
Ok(())
}

View File

@@ -1,26 +1,25 @@
cargo-features = ["rename-dependency"]
[package]
name = "tarpc-example-service"
version = "0.1.0"
version = "0.6.0"
authors = ["Tim Kuehn <tikue@google.com>"]
edition = "2018"
license = "MIT"
documentation = "https://docs.rs/tarpc-example-service"
homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "server", "api", "microservices", "example"]
keywords = ["rpc", "network", "server", "microservices", "example"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
description = "An example server built on tarpc."
[dependencies]
bincode-transport = { package = "tarpc-bincode-transport", version = "0.1", path = "../bincode-transport" }
futures-preview = { version = "0.3.0-alpha.8", features = ["compat", "tokio-compat"] }
clap = "2.0"
futures = "0.3"
serde = { version = "1.0" }
tarpc = { version = "0.13", path = "../tarpc", features = ["serde1"] }
tokio = "0.1"
tokio-executor = "0.1"
tarpc = { version = "0.21", path = "../tarpc", features = ["full"] }
tokio = { version = "0.2", features = ["full"] }
tokio-serde = { version = "0.6", features = ["json"] }
env_logger = "0.6"
[lib]
name = "service"
@@ -28,4 +27,8 @@ path = "src/lib.rs"
[[bin]]
name = "server"
path = "src/main.rs"
path = "src/server.rs"
[[bin]]
name = "client"
path = "src/client.rs"

View File

@@ -0,0 +1,58 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use clap::{App, Arg};
use std::{io, net::SocketAddr};
use tarpc::{client, context};
use tokio_serde::formats::Json;
#[tokio::main]
async fn main() -> io::Result<()> {
let flags = App::new("Hello Client")
.version("0.1")
.author("Tim <tikue@google.com>")
.about("Say hello!")
.arg(
Arg::with_name("server_addr")
.long("server_addr")
.value_name("ADDRESS")
.help("Sets the server address to connect to.")
.required(true)
.takes_value(true),
)
.arg(
Arg::with_name("name")
.short("n")
.long("name")
.value_name("STRING")
.help("Sets the name to say hello to.")
.required(true)
.takes_value(true),
)
.get_matches();
let server_addr = flags.value_of("server_addr").unwrap();
let server_addr = server_addr
.parse::<SocketAddr>()
.unwrap_or_else(|e| panic!(r#"--server_addr value "{}" invalid: {}"#, server_addr, e));
let name = flags.value_of("name").unwrap().into();
let transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default()).await?;
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
// config and any Transport as input.
let mut client = service::WorldClient::new(client::Config::default(), transport).spawn()?;
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = client.hello(context::current(), name).await?;
println!("{}", hello);
Ok(())
}

View File

@@ -4,18 +4,10 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
futures_api,
pin,
arbitrary_self_types,
await_macro,
async_await,
proc_macro_hygiene,
)]
// This is the service definition. It looks a lot like a trait definition.
// It defines one RPC, hello, which takes one arg, name, and returns a String.
tarpc::service! {
/// This is the service definition. It looks a lot like a trait definition.
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
#[tarpc::service]
pub trait World {
/// Returns a greeting for name.
rpc hello(name: String) -> String;
async fn hello(name: String) -> String;
}

View File

@@ -1,85 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
futures_api,
pin,
arbitrary_self_types,
await_macro,
async_await,
)]
use futures::{
compat::TokioDefaultSpawner,
future::{self, Ready},
prelude::*,
};
use tarpc::{
client, context,
server::{self, Handler, Server},
};
use std::io;
// This is the type that implements the generated Service trait. It is the business logic
// and is used to start the server.
#[derive(Clone)]
struct HelloServer;
impl service::Service for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
// an associated type representing the future output by the fn.
type HelloFut = Ready<String>;
fn hello(&self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
}
}
async fn run() -> io::Result<()> {
// bincode_transport is provided by the associated crate bincode-transport. It makes it easy
// to start up a serde-powered bincode serialization strategy over TCP.
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = transport.local_addr();
// The server is configured with the defaults.
let server = Server::new(server::Config::default())
// Server can listen on any type that implements the Transport trait.
.incoming(transport)
// Close the stream after the client connects
.take(1)
// serve is generated by the service! macro. It takes as input any type implementing
// the generated Service trait.
.respond_with(service::serve(HelloServer));
tokio_executor::spawn(server.unit_error().boxed().compat());
let transport = await!(bincode_transport::connect(&addr))?;
// new_stub is generated by the service! macro. Like Server, it takes a config and any
// Transport as input, and returns a Client, also generated by the macro.
// by the service mcro.
let mut client = await!(service::new_stub(client::Config::default(), transport))?;
// The client has an RPC method for each RPC defined in service!. It takes the same args
// as defined, with the addition of a Context, which is always the first arg. The Context
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = await!(client.hello(context::current(), "Stim".to_string()))?;
println!("{}", hello);
Ok(())
}
fn main() {
tarpc::init(TokioDefaultSpawner);
tokio::run(run()
.map_err(|e| eprintln!("Oh no: {}", e))
.boxed()
.compat()
);
}

View File

@@ -0,0 +1,89 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use clap::{App, Arg};
use futures::{
future::{self, Ready},
prelude::*,
};
use service::World;
use std::{
io,
net::{IpAddr, SocketAddr},
};
use tarpc::{
context,
server::{self, Channel, Handler},
};
use tokio_serde::formats::Json;
// This is the type that implements the generated World trait. It is the business logic
// and is used to start the server.
#[derive(Clone)]
struct HelloServer(SocketAddr);
impl World for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
// an associated type representing the future output by the fn.
type HelloFut = Ready<String>;
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!(
"Hello, {}! You are connected from {:?}.",
name, self.0
))
}
}
#[tokio::main]
async fn main() -> io::Result<()> {
env_logger::init();
let flags = App::new("Hello Server")
.version("0.1")
.author("Tim <tikue@google.com>")
.about("Say hello!")
.arg(
Arg::with_name("port")
.short("p")
.long("port")
.value_name("NUMBER")
.help("Sets the port number to listen on")
.required(true)
.takes_value(true),
)
.get_matches();
let port = flags.value_of("port").unwrap();
let port = port
.parse()
.unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e));
let server_addr = (IpAddr::from([0, 0, 0, 0]), port);
// JSON transport is provided by the json_transport tarpc module. It makes it easy
// to start up a serde-powered json serialization strategy over TCP.
tarpc::serde_transport::tcp::listen(&server_addr, Json::default)
.await?
// Ignore accept errors.
.filter_map(|r| future::ready(r.ok()))
.map(server::BaseChannel::with_defaults)
// Limit channels to 1 per IP.
.max_channels_per_key(1, |t| t.as_ref().peer_addr().unwrap().ip())
// serve is generated by the service attribute. It takes as input any type implementing
// the generated World trait.
.map(|channel| {
let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
channel.respond_with(server.serve()).execute()
})
// Max 10 channels.
.buffer_unordered(10)
.for_each(|_| async {})
.await;
Ok(())
}

View File

@@ -96,7 +96,7 @@ do
diff="$diff$(cargo fmt -- --skip-children --write-mode=diff $file)"
fi
done
if grep --quiet "^Diff at line" <<< "$diff"; then
if grep --quiet "^[-+]" <<< "$diff"; then
FMTRESULT=1
fi

View File

@@ -89,9 +89,13 @@ if [ "$?" == 0 ]; then
exit 1
fi
try_run "Building ... " cargo build --color=always
try_run "Testing ... " cargo test --color=always
try_run "Doc Test ... " cargo clean && cargo build --tests && rustdoc --test README.md --edition 2018 -L target/debug/deps -Z unstable-options
try_run "Building ... " cargo +stable build --color=always
try_run "Testing ... " cargo +stable test --color=always
try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
for EXAMPLE in $(cargo +stable run --example 2>&1 | grep ' ' | awk '{print $1}')
do
try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
done
fi

View File

@@ -1,7 +1,8 @@
[package]
name = "tarpc-plugins"
version = "0.5.0"
version = "0.8.0"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018"
license = "MIT"
documentation = "https://docs.rs/tarpc-plugins"
homepage = "https://github.com/google/tarpc"
@@ -11,14 +12,22 @@ categories = ["asynchronous", "network-programming"]
readme = "../README.md"
description = "Proc macros for tarpc."
[features]
serde1 = []
[badges]
travis-ci = { repository = "google/tarpc" }
[dependencies]
itertools = "0.7"
syn = { version = "0.15", features = ["full", "extra-traits"] }
quote = "0.6"
proc-macro2 = "0.4"
syn = { version = "1.0.11", features = ["full"] }
quote = "1.0.2"
proc-macro2 = "1.0.6"
[lib]
proc-macro = true
[dev-dependencies]
futures = "0.3"
serde = { version = "1.0", features = ["derive"] }
tarpc = { path = "../tarpc" }
assert-type-eq = "0.1.0"

View File

@@ -1 +1 @@
edition = "Edition2018"
edition = "2018"

View File

@@ -4,88 +4,729 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![recursion_limit = "512"]
extern crate proc_macro;
extern crate proc_macro2;
extern crate syn;
extern crate itertools;
extern crate quote;
extern crate syn;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
use syn::{
braced,
ext::IdentExt,
parenthesized,
parse::{Parse, ParseStream},
parse_macro_input, parse_quote, parse_str,
punctuated::Punctuated,
token::Comma,
Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool,
MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility,
};
use itertools::Itertools;
use quote::ToTokens;
use syn::{Ident, TraitItemType, TypePath, parse};
use proc_macro2::Span;
use std::str::FromStr;
struct Service {
attrs: Vec<Attribute>,
vis: Visibility,
ident: Ident,
rpcs: Vec<RpcMethod>,
}
#[proc_macro]
pub fn snake_to_camel(input: TokenStream) -> TokenStream {
let i = input.clone();
let mut assoc_type = parse::<TraitItemType>(input).unwrap_or_else(|_| panic!("Could not parse trait item from:\n{}", i));
struct RpcMethod {
attrs: Vec<Attribute>,
ident: Ident,
args: Vec<PatType>,
output: ReturnType,
}
let old_ident = convert(&mut assoc_type.ident);
impl Parse for Service {
fn parse(input: ParseStream) -> syn::Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;
let vis = input.parse()?;
input.parse::<Token![trait]>()?;
let ident: Ident = input.parse()?;
let content;
braced!(content in input);
let mut rpcs = Vec::<RpcMethod>::new();
while !content.is_empty() {
rpcs.push(content.parse()?);
}
for rpc in &rpcs {
if rpc.ident == "new" {
return Err(input.error(format!(
"method name conflicts with generated fn `{}Client::new`",
ident.unraw()
)));
}
if rpc.ident == "serve" {
return Err(input.error(format!(
"method name conflicts with generated fn `{}::serve`",
ident
)));
}
}
for mut attr in &mut assoc_type.attrs {
if let Some(pair) = attr.path.segments.first() {
if pair.value().ident == "doc" {
attr.tts = proc_macro2::TokenStream::from_str(&attr.tts.to_string().replace("{}", &old_ident)).unwrap();
Ok(Self {
attrs,
vis,
ident,
rpcs,
})
}
}
impl Parse for RpcMethod {
fn parse(input: ParseStream) -> syn::Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;
input.parse::<Token![async]>()?;
input.parse::<Token![fn]>()?;
let ident = input.parse()?;
let content;
parenthesized!(content in input);
let args: Punctuated<FnArg, Comma> = content.parse_terminated(FnArg::parse)?;
let args = args
.into_iter()
.map(|arg| match arg {
FnArg::Typed(captured) => match *captured.pat {
Pat::Ident(_) => Ok(captured),
_ => Err(input.error("patterns aren't allowed in RPC args")),
},
FnArg::Receiver(_) => Err(input.error("method args cannot start with self")),
})
.collect::<Result<_, _>>()?;
let output = input.parse()?;
input.parse::<Token![;]>()?;
Ok(Self {
attrs,
ident,
args,
output,
})
}
}
// If `derive_serde` meta item is not present, defaults to cfg!(feature = "serde1").
// `derive_serde` can only be true when serde1 is enabled.
struct DeriveSerde(bool);
impl Parse for DeriveSerde {
fn parse(input: ParseStream) -> syn::Result<Self> {
if input.is_empty() {
return Ok(Self(cfg!(feature = "serde1")));
}
match input.parse::<MetaNameValue>()? {
MetaNameValue {
ref path, ref lit, ..
} if path.segments.len() == 1
&& path.segments.first().unwrap().ident == "derive_serde" =>
{
match lit {
Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => {
Ok(Self(true))
}
Lit::Bool(LitBool { value: true, .. }) => {
Err(input
.error("To enable serde, first enable the `serde1` feature of tarpc"))
}
Lit::Bool(LitBool { value: false, .. }) => Ok(Self(false)),
_ => Err(input.error("`derive_serde` expects a value of type `bool`")),
}
}
_ => {
Err(input
.error("tarpc::service only supports one meta item, `derive_serde = {bool}`"))
}
}
}
}
/// Generates:
/// - service trait
/// - serve fn
/// - client stub struct
/// - new_stub client factory fn
/// - Request and Response enums
/// - ResponseFut Future
#[proc_macro_attribute]
pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
let derive_serde = parse_macro_input!(attr as DeriveSerde);
let unit_type: &Type = &parse_quote!(());
let Service {
ref attrs,
ref vis,
ref ident,
ref rpcs,
} = parse_macro_input!(input as Service);
let camel_case_fn_names: &Vec<_> = &rpcs
.iter()
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
.collect();
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
let derive_serialize = if derive_serde.0 {
Some(quote!(#[derive(serde::Serialize, serde::Deserialize)]))
} else {
None
};
ServiceGenerator {
response_fut_name,
service_ident: ident,
server_ident: &format_ident!("Serve{}", ident),
response_fut_ident: &Ident::new(&response_fut_name, ident.span()),
client_ident: &format_ident!("{}Client", ident),
request_ident: &format_ident!("{}Request", ident),
response_ident: &format_ident!("{}Response", ident),
vis,
args,
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
attrs,
rpcs,
return_types: &rpcs
.iter()
.map(|rpc| match rpc.output {
ReturnType::Type(_, ref ty) => ty,
ReturnType::Default => unit_type,
})
.collect::<Vec<_>>(),
arg_pats: &args
.iter()
.map(|args| args.iter().map(|arg| &*arg.pat).collect())
.collect::<Vec<_>>(),
camel_case_idents: &rpcs
.iter()
.zip(camel_case_fn_names.iter())
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
.collect::<Vec<_>>(),
future_types: &camel_case_fn_names
.iter()
.map(|name| parse_str(&format!("{}Fut", name)).unwrap())
.collect::<Vec<_>>(),
derive_serialize: derive_serialize.as_ref(),
}
.into_token_stream()
.into()
}
/// Transforms an async function into a sync one, returning a type declaration
/// for the return type (a future).
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
method.sig.asyncness = None;
// get either the return type or ().
let ret = match &method.sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, ret) => quote!(#ret),
};
// generate an identifier consisting of the method name to CamelCase with
// Fut appended to it.
let fut_name = snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut";
let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
// generate the updated return signature.
method.sig.output = parse_quote! {
-> ::core::pin::Pin<Box<
dyn ::core::future::Future<Output = #ret> + ::core::marker::Send
>>
};
// transform the body of the method into Box::pin(async move { body }).
let block = method.block.clone();
method.block = parse_quote! [{
Box::pin(async move
#block
)
}];
// generate and return type declaration for return type.
let t: ImplItemType = parse_quote! {
type #fut_name_ident = ::core::pin::Pin<Box<dyn ::core::future::Future<Output = #ret> + ::core::marker::Send>>;
};
t
}
/// Syntactic sugar to make using async functions in the server implementation
/// easier. It does this by rewriting code like this, which would normally not
/// compile because async functions are disallowed in trait implementations:
///
/// ```rust
/// # extern crate tarpc;
/// # use tarpc::context;
/// # use std::net::SocketAddr;
/// #[tarpc_plugins::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc_plugins::server]
/// impl World for HelloServer {
/// async fn hello(self, _: context::Context, name: String) -> String {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// }
/// }
/// ```
///
/// Into code like this, which matches the service trait definition:
///
/// ```rust
/// # extern crate tarpc;
/// # use tarpc::context;
/// # use std::pin::Pin;
/// # use futures::Future;
/// # use std::net::SocketAddr;
/// #[tarpc_plugins::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// impl World for HelloServer {
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
///
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
/// + Send>> {
/// Box::pin(async move {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// })
/// }
/// }
/// ```
///
/// Note that this won't touch functions unless they have been annotated with
/// `async`, meaning that this should not break existing code.
#[proc_macro_attribute]
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
let mut item = syn::parse_macro_input!(input as ItemImpl);
// the generated type declarations
let mut types: Vec<ImplItemType> = Vec::new();
for inner in &mut item.items {
if let ImplItem::Method(method) = inner {
let sig = &method.sig;
// if this function is declared async, transform it into a regular function
if sig.asyncness.is_some() {
let typedecl = transform_method(method);
types.push(typedecl);
}
}
}
assoc_type.into_token_stream().into()
// add the type declarations into the impl block
for t in types.into_iter() {
item.items.push(syn::ImplItem::Type(t));
}
TokenStream::from(quote!(#item))
}
#[proc_macro]
pub fn ty_snake_to_camel(input: TokenStream) -> TokenStream {
let mut path = parse::<TypePath>(input).unwrap();
// Only capitalize the final segment
convert(&mut path.path
.segments
.last_mut()
.unwrap()
.into_value()
.ident);
path.into_token_stream().into()
// Things needed to generate the service items: trait, serve impl, request/response enums, and
// the client stub.
struct ServiceGenerator<'a> {
service_ident: &'a Ident,
server_ident: &'a Ident,
response_fut_ident: &'a Ident,
response_fut_name: &'a str,
client_ident: &'a Ident,
request_ident: &'a Ident,
response_ident: &'a Ident,
vis: &'a Visibility,
attrs: &'a [Attribute],
rpcs: &'a [RpcMethod],
camel_case_idents: &'a [Ident],
future_types: &'a [Type],
method_idents: &'a [&'a Ident],
method_attrs: &'a [&'a [Attribute]],
args: &'a [&'a [PatType]],
return_types: &'a [&'a Type],
arg_pats: &'a [Vec<&'a Pat>],
derive_serialize: Option<&'a TokenStream2>,
}
/// Converts an ident in-place to CamelCase and returns the previous ident.
fn convert(ident: &mut Ident) -> String {
let ident_str = ident.to_string();
let mut camel_ty = String::new();
impl<'a> ServiceGenerator<'a> {
fn trait_service(&self) -> TokenStream2 {
let &Self {
attrs,
rpcs,
vis,
future_types,
return_types,
service_ident,
server_ident,
..
} = self;
{
// Find the first non-underscore and add it capitalized.
let mut chars = ident_str.chars();
let types_and_fns = rpcs
.iter()
.zip(future_types.iter())
.zip(return_types.iter())
.map(
|(
(
RpcMethod {
attrs, ident, args, ..
},
future_type,
),
output,
)| {
let ty_doc = format!("The response future returned by {}.", ident);
quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
// Find the first non-underscore char, uppercase it, and append it.
// Guaranteed to succeed because all idents must have at least one non-underscore char.
camel_ty.extend(chars.find(|&c| c != '_').unwrap().to_uppercase());
#( #attrs )*
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
}
},
);
// When we find an underscore, we remove it and capitalize the next char. To do this,
// we need to ensure the next char is not another underscore.
let mut chars = chars.coalesce(|c1, c2| {
if c1 == '_' && c2 == '_' {
Ok(c1)
} else {
Err((c1, c2))
quote! {
#( #attrs )*
#vis trait #service_ident: Clone {
#( #types_and_fns )*
/// Returns a serving function to use with tarpc::server::Server.
fn serve(self) -> #server_ident<Self> {
#server_ident { service: self }
}
}
});
}
}
while let Some(c) = chars.next() {
if c != '_' {
camel_ty.push(c);
} else if let Some(c) = chars.next() {
fn struct_server(&self) -> TokenStream2 {
let &Self {
vis, server_ident, ..
} = self;
quote! {
#[derive(Clone)]
#vis struct #server_ident<S> {
service: S,
}
}
}
fn impl_serve_for_server(&self) -> TokenStream2 {
let &Self {
request_ident,
server_ident,
service_ident,
response_ident,
response_fut_ident,
camel_case_idents,
arg_pats,
method_idents,
..
} = self;
quote! {
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
where S: #service_ident
{
type Resp = #response_ident;
type Fut = #response_fut_ident<S>;
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
match req {
#(
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
#response_fut_ident::#camel_case_idents(
#service_ident::#method_idents(
self.service, ctx, #( #arg_pats ),*
)
)
}
)*
}
}
}
}
}
fn enum_request(&self) -> TokenStream2 {
let &Self {
derive_serialize,
vis,
request_ident,
camel_case_idents,
args,
..
} = self;
quote! {
/// The request sent over the wire from the client to the server.
#[derive(Debug)]
#derive_serialize
#vis enum #request_ident {
#( #camel_case_idents{ #( #args ),* } ),*
}
}
}
fn enum_response(&self) -> TokenStream2 {
let &Self {
derive_serialize,
vis,
response_ident,
camel_case_idents,
return_types,
..
} = self;
quote! {
/// The response sent over the wire from the server to the client.
#[derive(Debug)]
#derive_serialize
#vis enum #response_ident {
#( #camel_case_idents(#return_types) ),*
}
}
}
fn enum_response_future(&self) -> TokenStream2 {
let &Self {
vis,
service_ident,
response_fut_ident,
camel_case_idents,
future_types,
..
} = self;
quote! {
/// A future resolving to a server response.
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
}
}
}
fn impl_debug_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_fut_name,
..
} = self;
quote! {
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct(#response_fut_name).finish()
}
}
}
}
fn impl_future_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_ident,
camel_case_idents,
..
} = self;
quote! {
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
type Output = #response_ident;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
-> std::task::Poll<#response_ident>
{
unsafe {
match std::pin::Pin::get_unchecked_mut(self) {
#(
#response_fut_ident::#camel_case_idents(resp) =>
std::pin::Pin::new_unchecked(resp)
.poll(cx)
.map(#response_ident::#camel_case_idents),
)*
}
}
}
}
}
}
fn struct_client(&self) -> TokenStream2 {
let &Self {
vis,
client_ident,
request_ident,
response_ident,
..
} = self;
quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
}
}
fn impl_from_for_client(&self) -> TokenStream2 {
let &Self {
client_ident,
request_ident,
response_ident,
..
} = self;
quote! {
impl<C> From<C> for #client_ident<C>
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
fn from(client: C) -> Self {
#client_ident(client)
}
}
}
}
fn impl_client_new(&self) -> TokenStream2 {
let &Self {
client_ident,
vis,
request_ident,
response_ident,
..
} = self;
quote! {
impl #client_ident {
/// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: tarpc::client::Config, transport: T)
-> tarpc::client::NewClient<
Self,
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>
>
where
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
{
let new_client = tarpc::client::new(config, transport);
tarpc::client::NewClient {
client: #client_ident(new_client.client),
dispatch: new_client.dispatch,
}
}
}
}
}
fn impl_client_rpc_methods(&self) -> TokenStream2 {
let &Self {
client_ident,
request_ident,
response_ident,
method_attrs,
vis,
method_idents,
args,
return_types,
arg_pats,
camel_case_idents,
..
} = self;
quote! {
impl<C> #client_ident<C>
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
#(
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&mut self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = tarpc::Client::call(&mut self.0, ctx, request);
async move {
match resp.await? {
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
_ => unreachable!(),
}
}
}
)*
}
}
}
}
impl<'a> ToTokens for ServiceGenerator<'a> {
fn to_tokens(&self, output: &mut TokenStream2) {
output.extend(vec![
self.trait_service(),
self.struct_server(),
self.impl_serve_for_server(),
self.enum_request(),
self.enum_response(),
self.enum_response_future(),
self.impl_debug_for_response_future(),
self.impl_future_for_response_future(),
self.struct_client(),
self.impl_from_for_client(),
self.impl_client_new(),
self.impl_client_rpc_methods(),
])
}
}
fn snake_to_camel(ident_str: &str) -> String {
let mut camel_ty = String::with_capacity(ident_str.len());
let mut last_char_was_underscore = true;
for c in ident_str.chars() {
match c {
'_' => last_char_was_underscore = true,
c if last_char_was_underscore => {
camel_ty.extend(c.to_uppercase());
last_char_was_underscore = false;
}
c => camel_ty.extend(c.to_lowercase()),
}
}
// The Fut suffix is hardcoded right now; this macro isn't really meant to be general-purpose.
camel_ty.push_str("Fut");
*ident = Ident::new(&camel_ty, Span::call_site());
ident_str
camel_ty.shrink_to_fit();
camel_ty
}
#[test]
fn snake_to_camel_basic() {
assert_eq!(snake_to_camel("abc_def"), "AbcDef");
}
#[test]
fn snake_to_camel_underscore_suffix() {
assert_eq!(snake_to_camel("abc_def_"), "AbcDef");
}
#[test]
fn snake_to_camel_underscore_prefix() {
assert_eq!(snake_to_camel("_abc_def"), "AbcDef");
}
#[test]
fn snake_to_camel_underscore_consecutive() {
assert_eq!(snake_to_camel("abc__def"), "AbcDef");
}
#[test]
fn snake_to_camel_capital_in_middle() {
assert_eq!(snake_to_camel("aBc_dEf"), "AbcDef");
}

144
plugins/tests/server.rs Normal file
View File

@@ -0,0 +1,144 @@
use assert_type_eq::assert_type_eq;
use futures::Future;
use std::pin::Pin;
use tarpc::context;
// these need to be out here rather than inside the function so that the
// assert_type_eq macro can pick them up.
#[tarpc::service]
trait Foo {
async fn two_part(s: String, i: i32) -> (String, i32);
async fn bar(s: String) -> String;
async fn baz();
}
#[test]
fn type_generation_works() {
#[tarpc::server]
impl Foo for () {
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
(s, i)
}
async fn bar(self, _: context::Context, s: String) -> String {
s
}
async fn baz(self, _: context::Context) {}
}
// the assert_type_eq macro can only be used once per block.
{
assert_type_eq!(
<() as Foo>::TwoPartFut,
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
);
}
{
assert_type_eq!(
<() as Foo>::BarFut,
Pin<Box<dyn Future<Output = String> + Send>>
);
}
{
assert_type_eq!(
<() as Foo>::BazFut,
Pin<Box<dyn Future<Output = ()> + Send>>
);
}
}
#[allow(non_camel_case_types)]
#[test]
fn raw_idents_work() {
type r#yield = String;
#[tarpc::service]
trait r#trait {
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async();
}
#[tarpc::server]
impl r#trait for () {
async fn r#await(
self,
_: context::Context,
r#struct: r#yield,
r#enum: i32,
) -> (r#yield, i32) {
(r#struct, r#enum)
}
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
r#impl
}
async fn r#async(self, _: context::Context) {}
}
}
#[test]
fn syntax() {
#[tarpc::service]
trait Syntax {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict();
async fn hello() -> String;
#[doc = "attr"]
async fn attr(s: String) -> String;
async fn no_args_no_return();
async fn no_args() -> ();
async fn one_arg(one: String) -> i32;
async fn two_args_no_return(one: String, two: u64);
async fn two_args(one: String, two: u64) -> String;
async fn no_args_ret_error() -> i32;
async fn one_arg_ret_error(one: String) -> String;
async fn no_arg_implicit_return_error();
#[doc = "attr"]
async fn one_arg_implicit_return_error(one: String);
}
#[tarpc::server]
impl Syntax for () {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
async fn hello(self, _: context::Context) -> String {
String::new()
}
async fn attr(self, _: context::Context, _s: String) -> String {
String::new()
}
async fn no_args_no_return(self, _: context::Context) {}
async fn no_args(self, _: context::Context) -> () {}
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
0
}
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
String::new()
}
async fn no_args_ret_error(self, _: context::Context) -> i32 {
0
}
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
String::new()
}
async fn no_arg_implicit_return_error(self, _: context::Context) {}
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
}
}

85
plugins/tests/service.rs Normal file
View File

@@ -0,0 +1,85 @@
use tarpc::context;
#[test]
fn att_service_trait() {
use futures::future::{ready, Ready};
#[tarpc::service]
trait Foo {
async fn two_part(s: String, i: i32) -> (String, i32);
async fn bar(s: String) -> String;
async fn baz();
}
impl Foo for () {
type TwoPartFut = Ready<(String, i32)>;
fn two_part(self, _: context::Context, s: String, i: i32) -> Self::TwoPartFut {
ready((s, i))
}
type BarFut = Ready<String>;
fn bar(self, _: context::Context, s: String) -> Self::BarFut {
ready(s)
}
type BazFut = Ready<()>;
fn baz(self, _: context::Context) -> Self::BazFut {
ready(())
}
}
}
#[allow(non_camel_case_types)]
#[test]
fn raw_idents() {
use futures::future::{ready, Ready};
type r#yield = String;
#[tarpc::service]
trait r#trait {
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async();
}
impl r#trait for () {
type AwaitFut = Ready<(r#yield, i32)>;
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
ready((r#struct, r#enum))
}
type FnFut = Ready<r#yield>;
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
ready(r#impl)
}
type AsyncFut = Ready<()>;
fn r#async(self, _: context::Context) -> Self::AsyncFut {
ready(())
}
}
}
#[test]
fn syntax() {
#[tarpc::service]
trait Syntax {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict();
async fn hello() -> String;
#[doc = "attr"]
async fn attr(s: String) -> String;
async fn no_args_no_return();
async fn no_args() -> ();
async fn one_arg(one: String) -> i32;
async fn two_args_no_return(one: String, two: u64);
async fn two_args(one: String, two: u64) -> String;
async fn no_args_ret_error() -> i32;
async fn one_arg_ret_error(one: String) -> String;
async fn no_arg_implicit_return_error();
#[doc = "attr"]
async fn one_arg_implicit_return_error(one: String);
}
}

View File

@@ -1,38 +0,0 @@
cargo-features = ["rename-dependency"]
[package]
name = "tarpc-lib"
version = "0.1.0"
authors = ["Tim Kuehn <tikue@google.com>"]
edition = '2018'
license = "MIT"
documentation = "https://docs.rs/tarpc-lib"
homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "server", "api", "microservices"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
description = "An RPC framework for Rust with a focus on ease of use."
[features]
default = []
serde1 = ["trace/serde", "serde", "serde/derive"]
[dependencies]
fnv = "1.0"
humantime = "1.0"
log = "0.4"
pin-utils = "0.1.0-alpha.2"
rand = "0.5"
tokio-timer = "0.2"
trace = { package = "tarpc-trace", version = "0.1", path = "../trace" }
serde = { optional = true, version = "1.0" }
[target.'cfg(not(test))'.dependencies]
futures-preview = { version = "0.3.0-alpha.8", features = ["compat"] }
[dev-dependencies]
futures-preview = { version = "0.3.0-alpha.8", features = ["compat", "tokio-compat"] }
futures-test-preview = { version = "0.3.0-alpha.8" }
env_logger = "0.5"
tokio = "0.1"

View File

@@ -1 +0,0 @@
edition = "Edition2018"

View File

@@ -1,714 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
context,
util::{deadline_compat, AsDuration, Compact},
ClientMessage, ClientMessageKind, Request, Response, Transport,
};
use fnv::FnvHashMap;
use futures::{
Poll,
channel::{mpsc, oneshot},
prelude::*,
ready,
stream::Fuse,
task::LocalWaker,
};
use humantime::format_rfc3339;
use log::{debug, error, info, trace};
use pin_utils::unsafe_pinned;
use std::{
io,
net::SocketAddr,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Instant,
};
use trace::SpanId;
use super::Config;
/// Handles communication from the client to request dispatch.
#[derive(Debug)]
pub(crate) struct Channel<Req, Resp> {
to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
/// Channel to send a cancel message to the dispatcher.
cancellation: RequestCancellation,
/// The ID to use for the next request to stage.
next_request_id: Arc<AtomicU64>,
server_addr: SocketAddr,
}
impl<Req, Resp> Clone for Channel<Req, Resp> {
fn clone(&self) -> Self {
Self {
to_dispatch: self.to_dispatch.clone(),
cancellation: self.cancellation.clone(),
next_request_id: self.next_request_id.clone(),
server_addr: self.server_addr,
}
}
}
impl<Req, Resp> Channel<Req, Resp> {
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves when the request is sent (not when the response is received).
pub(crate) async fn send(
&mut self,
mut ctx: context::Context,
request: Req,
) -> io::Result<DispatchResponse<Resp>> {
// Convert the context to the call context.
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
let timeout = ctx.deadline.as_duration();
let deadline = Instant::now() + timeout;
trace!(
"[{}/{}] Queuing request with deadline {} (timeout {:?}).",
ctx.trace_id(),
self.server_addr,
format_rfc3339(ctx.deadline),
timeout,
);
let (response_completion, response) = oneshot::channel();
let cancellation = self.cancellation.clone();
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
await!(self.to_dispatch.send(DispatchRequest {
ctx,
request_id,
request,
response_completion,
})).map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset))?;
Ok(DispatchResponse {
response: deadline_compat::Deadline::new(response, deadline),
complete: false,
request_id,
cancellation,
ctx,
server_addr: self.server_addr,
})
}
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves to the response.
pub(crate) async fn call(
&mut self,
context: context::Context,
request: Req,
) -> io::Result<Resp> {
let response_future = await!(self.send(context, request))?;
await!(response_future)
}
}
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
#[derive(Debug)]
pub struct DispatchResponse<Resp> {
response: deadline_compat::Deadline<oneshot::Receiver<Response<Resp>>>,
ctx: context::Context,
complete: bool,
cancellation: RequestCancellation,
request_id: u64,
server_addr: SocketAddr,
}
impl<Resp> DispatchResponse<Resp> {
unsafe_pinned!(server_addr: SocketAddr);
unsafe_pinned!(ctx: context::Context);
}
impl<Resp> Future for DispatchResponse<Resp> {
type Output = io::Result<Resp>;
fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<Resp>> {
let resp = ready!(self.response.poll_unpin(waker));
self.complete = true;
Poll::Ready(match resp {
Ok(resp) => Ok(resp.message?),
Err(e) => Err({
let trace_id = *self.ctx().trace_id();
let server_addr = *self.server_addr();
if e.is_elapsed() {
io::Error::new(
io::ErrorKind::TimedOut,
"Client dropped expired request.".to_string(),
)
} else if e.is_timer() {
let e = e.into_timer().unwrap();
if e.is_at_capacity() {
io::Error::new(
io::ErrorKind::Other,
"Cancelling request because an expiration could not be set \
due to the timer being at capacity."
.to_string(),
)
} else if e.is_shutdown() {
panic!("[{}/{}] Timer was shutdown", trace_id, server_addr)
} else {
panic!(
"[{}/{}] Unrecognized timer error: {}",
trace_id, server_addr, e
)
}
} else if e.is_inner() {
// The oneshot is Canceled when the dispatch task ends.
io::Error::from(io::ErrorKind::ConnectionReset)
} else {
panic!(
"[{}/{}] Unrecognized deadline error: {}",
trace_id, server_addr, e
)
}
}),
})
}
}
// Cancels the request when dropped, if not already complete.
impl<Resp> Drop for DispatchResponse<Resp> {
fn drop(&mut self) {
if !self.complete {
// The receiver needs to be closed to handle the edge case that the request has not
// yet been received by the dispatch task. It is possible for the cancel message to
// arrive before the request itself, in which case the request could get stuck in the
// dispatch map forever if the server never responds (e.g. if the server dies while
// responding). Even if the server does respond, it will have unnecessarily done work
// for a client no longer waiting for a response. To avoid this, the dispatch task
// checks if the receiver is closed before inserting the request in the map. By
// closing the receiver before sending the cancel message, it is guaranteed that if the
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self.response.get_mut().close();
self.cancellation.cancel(self.request_id);
}
}
}
/// Spawns a dispatch task on the default executor that manages the lifecycle of requests initiated
/// by the returned [`Channel`].
pub async fn spawn<Req, Resp, C>(
config: Config,
transport: C,
server_addr: SocketAddr,
) -> io::Result<Channel<Req, Resp>>
where
Req: Send,
Resp: Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>> + Send,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
crate::spawn(
RequestDispatch {
config,
server_addr,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
pending_requests: pending_requests.fuse(),
}.unwrap_or_else(move |e| error!("[{}] Connection broken: {}", server_addr, e))
).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
"Could not spawn client dispatch task. Is shutdown: {}",
e.is_shutdown()
),
)
})?;
Ok(Channel {
to_dispatch,
cancellation,
server_addr,
next_request_id: Arc::new(AtomicU64::new(0)),
})
}
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
struct RequestDispatch<Req, Resp, C> {
/// Writes requests to the wire and reads responses off the wire.
transport: Fuse<C>,
/// Requests waiting to be written to the wire.
pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
/// Requests that were dropped.
canceled_requests: CanceledRequests,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
/// The address of the server connected to.
server_addr: SocketAddr,
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
Req: Send,
Resp: Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>>,
{
unsafe_pinned!(server_addr: SocketAddr);
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>);
unsafe_pinned!(canceled_requests: CanceledRequests);
unsafe_pinned!(pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>);
unsafe_pinned!(transport: Fuse<C>);
fn pump_read(self: &mut Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<io::Result<()>>> {
Poll::Ready(match ready!(self.transport().poll_next(waker)?) {
Some(response) => {
self.complete(response);
Some(Ok(()))
}
None => {
trace!("[{}] read half closed", self.server_addr());
None
}
})
}
fn pump_write(self: &mut Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<io::Result<()>>> {
enum ReceiverStatus {
NotReady,
Closed,
}
let pending_requests_status = match self.poll_next_request(waker)? {
Poll::Ready(Some(dispatch_request)) => {
self.write_request(dispatch_request)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
let canceled_requests_status = match self.poll_next_cancellation(waker)? {
Poll::Ready(Some((context, request_id))) => {
self.write_cancel(context, request_id)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
match (pending_requests_status, canceled_requests_status) {
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
ready!(self.transport().poll_flush(waker)?);
Poll::Ready(None)
}
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self.transport().poll_flush(waker)?);
// Even if we fully-flush, we return Pending, because we have no more requests
// or cancellations right now.
Poll::Pending
}
}
}
/// Yields the next pending request, if one is ready to be sent.
fn poll_next_request(
self: &mut Pin<&mut Self>,
waker: &LocalWaker,
) -> Poll<Option<io::Result<DispatchRequest<Req, Resp>>>> {
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
info!(
"At in-flight request capacity ({}/{}).",
self.in_flight_requests().len(),
self.config.max_in_flight_requests
);
// No need to schedule a wakeup, because timers and responses are responsible
// for clearing out in-flight requests.
return Poll::Pending;
}
while let Poll::Pending = self.transport().poll_ready(waker)? {
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
ready!(self.transport().poll_flush(waker)?);
}
loop {
match ready!(self.pending_requests().poll_next_unpin(waker)) {
Some(request) => {
if request.response_completion.is_canceled() {
trace!(
"[{}] Request canceled before being sent.",
request.ctx.trace_id()
);
continue;
}
return Poll::Ready(Some(Ok(request)));
}
None => {
trace!("[{}] pending_requests closed", self.server_addr());
return Poll::Ready(None);
}
}
}
}
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
fn poll_next_cancellation(
self: &mut Pin<&mut Self>,
waker: &LocalWaker,
) -> Poll<Option<io::Result<(context::Context, u64)>>> {
while let Poll::Pending = self.transport().poll_ready(waker)? {
ready!(self.transport().poll_flush(waker)?);
}
loop {
match ready!(self.canceled_requests().poll_next_unpin(waker)) {
Some(request_id) => {
if let Some(in_flight_data) = self.in_flight_requests().remove(&request_id) {
self.in_flight_requests().compact(0.1);
debug!(
"[{}/{}] Removed request.",
in_flight_data.ctx.trace_id(),
self.server_addr()
);
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
}
}
None => {
trace!("[{}] canceled_requests closed.", self.server_addr());
return Poll::Ready(None);
}
}
}
}
fn write_request(
self: &mut Pin<&mut Self>,
dispatch_request: DispatchRequest<Req, Resp>,
) -> io::Result<()> {
let request_id = dispatch_request.request_id;
let request = ClientMessage {
trace_context: dispatch_request.ctx.trace_context,
message: ClientMessageKind::Request(Request {
id: request_id,
message: dispatch_request.request,
deadline: dispatch_request.ctx.deadline,
}),
};
self.transport().start_send(request)?;
self.in_flight_requests().insert(
request_id,
InFlightData {
ctx: dispatch_request.ctx,
response_completion: dispatch_request.response_completion,
},
);
Ok(())
}
fn write_cancel(
self: &mut Pin<&mut Self>,
context: context::Context,
request_id: u64,
) -> io::Result<()> {
let trace_id = *context.trace_id();
let cancel = ClientMessage {
trace_context: context.trace_context,
message: ClientMessageKind::Cancel { request_id },
};
self.transport().start_send(cancel)?;
trace!("[{}/{}] Cancel message sent.", trace_id, self.server_addr());
return Ok(());
}
/// Sends a server response to the client task that initiated the associated request.
fn complete(self: &mut Pin<&mut Self>, response: Response<Resp>) -> bool {
if let Some(in_flight_data) = self.in_flight_requests().remove(&response.request_id) {
self.in_flight_requests().compact(0.1);
trace!(
"[{}/{}] Received response.",
in_flight_data.ctx.trace_id(),
self.server_addr()
);
let _ = in_flight_data.response_completion.send(response);
return true;
}
debug!(
"[{}] No in-flight request found for request_id = {}.",
self.server_addr(),
response.request_id
);
// If the response completion was absent, then the request was already canceled.
false
}
}
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
Req: Send,
Resp: Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>>,
{
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<()>> {
trace!("[{}] RequestDispatch::poll", self.server_addr());
loop {
match (self.pump_read(waker)?, self.pump_write(waker)?) {
(read, write @ Poll::Ready(None)) => {
if self.in_flight_requests().is_empty() {
info!(
"[{}] Shutdown: write half closed, and no requests in flight.",
self.server_addr()
);
return Poll::Ready(Ok(()));
}
match read {
Poll::Ready(Some(())) => continue,
_ => {
trace!(
"[{}] read: {:?}, write: {:?}, (not ready)",
self.server_addr(),
read,
write,
);
return Poll::Pending;
}
}
}
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
trace!(
"[{}] read: {:?}, write: {:?}",
self.server_addr(),
read,
write,
)
}
(read, write) => {
trace!(
"[{}] read: {:?}, write: {:?} (not ready)",
self.server_addr(),
read,
write,
);
return Poll::Pending;
}
}
}
}
}
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
/// the lifecycle of the request.
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
ctx: context::Context,
request_id: u64,
request: Req,
response_completion: oneshot::Sender<Response<Resp>>,
}
struct InFlightData<Resp> {
ctx: context::Context,
response_completion: oneshot::Sender<Response<Resp>>,
}
/// Sends request cancellation signals.
#[derive(Debug, Clone)]
struct RequestCancellation(mpsc::UnboundedSender<u64>);
/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests. Additionally, each request has a clone
// of the sender, so the bounded channel would have the same behavior,
// since it guarantees a slot.
let (tx, rx) = mpsc::unbounded();
(RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
/// Cancels the request with ID `request_id`.
fn cancel(&mut self, request_id: u64) {
let _ = self.0.unbounded_send(request_id);
}
}
impl Stream for CanceledRequests {
type Item = u64;
fn poll_next(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<u64>> {
self.0.poll_next_unpin(waker)
}
}
#[cfg(test)]
mod tests {
use super::{CanceledRequests, Channel, RequestCancellation, RequestDispatch};
use crate::{
client::Config,
context,
transport::{self, channel::UnboundedChannel},
ClientMessage, Response,
};
use fnv::FnvHashMap;
use futures::{Poll, channel::mpsc, prelude::*};
use futures_test::task::{noop_local_waker_ref};
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
pin::Pin,
sync::atomic::AtomicU64,
sync::Arc,
};
#[test]
fn stage_request() {
let (mut dispatch, mut channel, _server_channel) = set_up();
// Test that a request future dropped before it's processed by dispatch will cause the request
// to not be added to the in-flight request map.
let _resp = tokio::runtime::current_thread::block_on_all(
channel
.send(context::current(), "hi".to_string())
.boxed()
.compat(),
);
let mut dispatch = Pin::new(&mut dispatch);
let waker = &noop_local_waker_ref();
let req = dispatch.poll_next_request(waker).ready();
assert!(req.is_some());
let req = req.unwrap();
assert_eq!(req.request_id, 0);
assert_eq!(req.request, "hi".to_string());
}
#[test]
fn stage_request_response_future_dropped() {
let (mut dispatch, mut channel, _server_channel) = set_up();
// Test that a request future dropped before it's processed by dispatch will cause the request
// to not be added to the in-flight request map.
let resp = tokio::runtime::current_thread::block_on_all(
channel
.send(context::current(), "hi".into())
.boxed()
.compat(),
).unwrap();
drop(resp);
drop(channel);
let mut dispatch = Pin::new(&mut dispatch);
let waker = &noop_local_waker_ref();
dispatch.poll_next_cancellation(waker).unwrap();
assert!(dispatch.poll_next_request(waker).ready().is_none());
}
#[test]
fn stage_request_response_future_closed() {
let (mut dispatch, mut channel, _server_channel) = set_up();
// Test that a request future that's closed its receiver but not yet canceled its request --
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
// map.
let resp = tokio::runtime::current_thread::block_on_all(
channel
.send(context::current(), "hi".into())
.boxed()
.compat(),
).unwrap();
drop(resp);
drop(channel);
let mut dispatch = Pin::new(&mut dispatch);
let waker = &noop_local_waker_ref();
assert!(dispatch.poll_next_request(waker).ready().is_none());
}
fn set_up() -> (
RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
Channel<String, String>,
UnboundedChannel<ClientMessage<String>, Response<String>>,
) {
let _ = env_logger::try_init();
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancel_tx, canceled_requests) = mpsc::unbounded();
let (client_channel, server_channel) = transport::channel::unbounded();
let dispatch = RequestDispatch::<String, String, _> {
transport: client_channel.fuse(),
pending_requests: pending_requests.fuse(),
canceled_requests: CanceledRequests(canceled_requests),
in_flight_requests: FnvHashMap::default(),
config: Config::default(),
server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
};
let cancellation = RequestCancellation(cancel_tx);
let channel = Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicU64::new(0)),
server_addr: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
};
(dispatch, channel, server_channel)
}
trait PollTest {
type T;
fn unwrap(self) -> Poll<Self::T>;
fn ready(self) -> Self::T;
}
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
where
E: ::std::fmt::Display + Send + 'static,
{
type T = Option<T>;
fn unwrap(self) -> Poll<Option<T>> {
match self {
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
Poll::Pending => Poll::Pending,
}
}
fn ready(self) -> Option<T> {
match self {
Poll::Ready(Some(Ok(t))) => Some(t),
Poll::Ready(None) => None,
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
Poll::Pending => panic!("Pending"),
}
}
}
}

View File

@@ -1,91 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a client that connects to a server and sends multiplexed requests.
use crate::{context::Context, ClientMessage, Response, Transport};
use log::warn;
use std::{
io,
net::{Ipv4Addr, SocketAddr},
};
mod dispatch;
/// Sends multiplexed requests to, and receives responses from, a server.
#[derive(Debug)]
pub struct Client<Req, Resp> {
/// Channel to send requests to the dispatch task.
channel: dispatch::Channel<Req, Resp>,
}
impl<Req, Resp> Clone for Client<Req, Resp> {
fn clone(&self) -> Self {
Client {
channel: self.channel.clone(),
}
}
}
/// Settings that control the behavior of the client.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
/// The number of requests that can be in flight at once.
/// `max_in_flight_requests` controls the size of the map used by the client
/// for storing pending requests.
pub max_in_flight_requests: usize,
/// The number of requests that can be buffered client-side before being sent.
/// `pending_requests_buffer` controls the size of the channel clients use
/// to communicate with the request dispatch task.
pub pending_request_buffer: usize,
}
impl Default for Config {
fn default() -> Self {
Config {
max_in_flight_requests: 1_000,
pending_request_buffer: 100,
}
}
}
impl<Req, Resp> Client<Req, Resp>
where
Req: Send,
Resp: Send,
{
/// Creates a new Client by wrapping a [`Transport`] and spawning a dispatch task
/// that manages the lifecycle of requests.
///
/// Must only be called from on an executor.
pub async fn new<T>(config: Config, transport: T) -> io::Result<Self>
where
T: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>> + Send,
{
let server_addr = transport.peer_addr().unwrap_or_else(|e| {
warn!(
"Setting peer to unspecified because peer could not be determined: {}",
e
);
SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0)
});
Ok(Client {
channel: await!(dispatch::spawn(config, transport, server_addr))?,
})
}
/// Initiates a request, sending it to the dispatch task.
///
/// Returns a [`Future`] that resolves to this client and the future response
/// once the request is successfully enqueued.
///
/// [`Future`]: futures::Future
pub async fn call(&mut self, ctx: Context, request: Req) -> io::Result<Resp> {
await!(self.channel.call(ctx, request))
}
}

View File

@@ -1,257 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
server::{Channel, Config},
util::Compact,
ClientMessage, Response, Transport,
};
use fnv::FnvHashMap;
use futures::{channel::mpsc, prelude::*, ready, stream::Fuse, task::{LocalWaker, Poll}};
use log::{debug, error, info, trace, warn};
use pin_utils::unsafe_pinned;
use std::{
collections::hash_map::Entry,
io,
marker::PhantomData,
net::{IpAddr, SocketAddr},
ops::Try,
option::NoneError,
pin::Pin,
};
/// Drops connections under configurable conditions:
///
/// 1. If the max number of connections is reached.
/// 2. If the max number of connections for a single IP is reached.
#[derive(Debug)]
pub struct ConnectionFilter<S, Req, Resp> {
listener: Fuse<S>,
closed_connections: mpsc::UnboundedSender<SocketAddr>,
closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>,
config: Config,
connections_per_ip: FnvHashMap<IpAddr, usize>,
open_connections: usize,
ghost: PhantomData<(Req, Resp)>,
}
enum NewConnection<Req, Resp, C> {
Filtered,
Accepted(Channel<Req, Resp, C>),
}
impl<Req, Resp, C> Try for NewConnection<Req, Resp, C> {
type Ok = Channel<Req, Resp, C>;
type Error = NoneError;
fn into_result(self) -> Result<Channel<Req, Resp, C>, NoneError> {
match self {
NewConnection::Filtered => Err(NoneError),
NewConnection::Accepted(channel) => Ok(channel),
}
}
fn from_error(_: NoneError) -> Self {
NewConnection::Filtered
}
fn from_ok(channel: Channel<Req, Resp, C>) -> Self {
NewConnection::Accepted(channel)
}
}
impl<S, Req, Resp> ConnectionFilter<S, Req, Resp> {
unsafe_pinned!(open_connections: usize);
unsafe_pinned!(config: Config);
unsafe_pinned!(connections_per_ip: FnvHashMap<IpAddr, usize>);
unsafe_pinned!(closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>);
unsafe_pinned!(listener: Fuse<S>);
/// Sheds new connections to stay under configured limits.
pub fn filter<C>(listener: S, config: Config) -> Self
where
S: Stream<Item = Result<C, io::Error>>,
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
let (closed_connections, closed_connections_rx) = mpsc::unbounded();
ConnectionFilter {
listener: listener.fuse(),
closed_connections,
closed_connections_rx,
config,
connections_per_ip: FnvHashMap::default(),
open_connections: 0,
ghost: PhantomData,
}
}
fn handle_new_connection<C>(self: &mut Pin<&mut Self>, stream: C) -> NewConnection<Req, Resp, C>
where
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
let peer = match stream.peer_addr() {
Ok(peer) => peer,
Err(e) => {
warn!("Could not get peer_addr of new connection: {}", e);
return NewConnection::Filtered;
}
};
let open_connections = *self.open_connections();
if open_connections >= self.config().max_connections {
warn!(
"[{}] Shedding connection because the maximum open connections \
limit is reached ({}/{}).",
peer,
open_connections,
self.config().max_connections
);
return NewConnection::Filtered;
}
let config = self.config.clone();
let open_connections_for_ip = self.increment_connections_for_ip(&peer)?;
*self.open_connections() += 1;
debug!(
"[{}] Opening channel ({}/{} connections for IP, {} total).",
peer,
open_connections_for_ip,
config.max_connections_per_ip,
self.open_connections(),
);
NewConnection::Accepted(Channel {
client_addr: peer,
closed_connections: self.closed_connections.clone(),
transport: stream.fuse(),
config,
ghost: PhantomData,
})
}
fn handle_closed_connection(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
*self.open_connections() -= 1;
debug!(
"[{}] Closing channel. {} open connections remaining.",
addr, self.open_connections
);
self.decrement_connections_for_ip(&addr);
self.connections_per_ip().compact(0.1);
}
fn increment_connections_for_ip(self: &mut Pin<&mut Self>, peer: &SocketAddr) -> Option<usize> {
let max_connections_per_ip = self.config().max_connections_per_ip;
let mut occupied;
let mut connections_per_ip = self.connections_per_ip();
let occupied = match connections_per_ip.entry(peer.ip()) {
Entry::Vacant(vacant) => vacant.insert(0),
Entry::Occupied(o) => {
if *o.get() < max_connections_per_ip {
// Store the reference outside the block to extend the lifetime.
occupied = o;
occupied.get_mut()
} else {
info!(
"[{}] Opened max connections from IP ({}/{}).",
peer,
o.get(),
max_connections_per_ip
);
return None;
}
}
};
*occupied += 1;
Some(*occupied)
}
fn decrement_connections_for_ip(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
let should_compact = match self.connections_per_ip().entry(addr.ip()) {
Entry::Vacant(_) => {
error!("[{}] Got vacant entry when closing connection.", addr);
return;
}
Entry::Occupied(mut occupied) => {
*occupied.get_mut() -= 1;
if *occupied.get() == 0 {
occupied.remove();
true
} else {
false
}
}
};
if should_compact {
self.connections_per_ip().compact(0.1);
}
}
fn poll_listener<C>(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<NewConnection<Req, Resp, C>>>>
where
S: Stream<Item = Result<C, io::Error>>,
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
match ready!(self.listener().poll_next_unpin(cx)?) {
Some(codec) => Poll::Ready(Some(Ok(self.handle_new_connection(codec)))),
None => Poll::Ready(None),
}
}
fn poll_closed_connections(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
match ready!(self.closed_connections_rx().poll_next_unpin(cx)) {
Some(addr) => {
self.handle_closed_connection(&addr);
Poll::Ready(Ok(()))
}
None => unreachable!("Holding a copy of closed_connections and didn't close it."),
}
}
}
impl<S, Req, Resp, T> Stream for ConnectionFilter<S, Req, Resp>
where
S: Stream<Item = Result<T, io::Error>>,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
type Item = io::Result<Channel<Req, Resp, T>>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<Channel<Req, Resp, T>>>> {
loop {
match (self.poll_listener(cx)?, self.poll_closed_connections(cx)?) {
(Poll::Ready(Some(NewConnection::Accepted(channel))), _) => {
return Poll::Ready(Some(Ok(channel)))
}
(Poll::Ready(Some(NewConnection::Filtered)), _) | (_, Poll::Ready(())) => {
trace!("Filtered a connection; {} open.", self.open_connections());
continue;
}
(Poll::Pending, Poll::Pending) => return Poll::Pending,
(Poll::Ready(None), Poll::Pending) => {
if *self.open_connections() > 0 {
trace!(
"Listener closed; {} open connections.",
self.open_connections()
);
return Poll::Pending;
}
trace!("Shutting down listener: all connections closed, and no more coming.");
return Poll::Ready(None);
}
}
}
}
}

View File

@@ -1,605 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a server that concurrently handles many connections sending multiplexed requests.
use crate::{
context::Context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage,
ClientMessageKind, Request, Response, ServerError, Transport,
};
use fnv::FnvHashMap;
use futures::{
channel::mpsc,
future::{abortable, AbortHandle},
prelude::*,
ready,
stream::Fuse,
task::{LocalWaker, Poll},
try_ready,
};
use humantime::format_rfc3339;
use log::{debug, error, info, trace, warn};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{
error::Error as StdError,
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
time::{Instant, SystemTime},
};
use tokio_timer::timeout;
use trace::{self, TraceId};
mod filter;
/// Manages clients, serving multiplexed requests over each connection.
#[derive(Debug)]
pub struct Server<Req, Resp> {
config: Config,
ghost: PhantomData<(Req, Resp)>,
}
/// Settings that control the behavior of the server.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
/// The maximum number of clients that can be connected to the server at once. When at the
/// limit, existing connections are honored and new connections are rejected.
pub max_connections: usize,
/// The maximum number of clients per IP address that can be connected to the server at once.
/// When an IP is at the limit, existing connections are honored and new connections on that IP
/// address are rejected.
pub max_connections_per_ip: usize,
/// The maximum number of requests that can be in flight for each client. When a client is at
/// the in-flight request limit, existing requests are fulfilled and new requests are rejected.
/// Rejected requests are sent a response error.
pub max_in_flight_requests_per_connection: usize,
/// The number of responses per client that can be buffered server-side before being sent.
/// `pending_response_buffer` controls the buffer size of the channel that a server's
/// response tasks use to send responses to the client handler task.
pub pending_response_buffer: usize,
}
impl Default for Config {
fn default() -> Self {
Config {
max_connections: 1_000_000,
max_connections_per_ip: 1_000,
max_in_flight_requests_per_connection: 1_000,
pending_response_buffer: 100,
}
}
}
impl<Req, Resp> Server<Req, Resp> {
/// Returns a new server with configuration specified `config`.
pub fn new(config: Config) -> Self {
Server {
config,
ghost: PhantomData,
}
}
/// Returns the config for this server.
pub fn config(&self) -> &Config {
&self.config
}
/// Returns a stream of the incoming connections to the server.
pub fn incoming<S, T>(
self,
listener: S,
) -> impl Stream<Item = io::Result<Channel<Req, Resp, T>>>
where
Req: Send,
Resp: Send,
S: Stream<Item = io::Result<T>>,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
self::filter::ConnectionFilter::filter(listener, self.config.clone())
}
}
/// The future driving the server.
#[derive(Debug)]
pub struct Running<S, F> {
incoming: S,
request_handler: F,
}
impl<S, F> Running<S, F> {
unsafe_pinned!(incoming: S);
unsafe_unpinned!(request_handler: F);
}
impl<S, T, Req, Resp, F, Fut> Future for Running<S, F>
where
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send + 'static,
F: FnMut(Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<()> {
while let Some(channel) = ready!(self.incoming().poll_next(cx)) {
match channel {
Ok(channel) => {
let peer = channel.client_addr;
if let Err(e) = crate::spawn(channel.respond_with(self.request_handler().clone()))
{
warn!("[{}] Failed to spawn connection handler: {:?}", peer, e);
}
}
Err(e) => {
warn!("Incoming connection error: {}", e);
}
}
}
info!("Server shutting down.");
return Poll::Ready(());
}
}
/// A utility trait enabling a stream to fluently chain a request handler.
pub trait Handler<T, Req, Resp>
where
Self: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send,
Resp: Send,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
/// Responds to all requests with `request_handler`.
fn respond_with<F, Fut>(self, request_handler: F) -> Running<Self, F>
where
F: FnMut(Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
Running {
incoming: self,
request_handler,
}
}
}
impl<T, Req, Resp, S> Handler<T, Req, Resp> for S
where
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send,
Resp: Send,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{}
/// Responds to all requests with `request_handler`.
/// The server end of an open connection with a client.
#[derive(Debug)]
pub struct Channel<Req, Resp, T> {
/// Writes responses to the wire and reads requests off the wire.
transport: Fuse<T>,
/// Signals the connection is closed when `Channel` is dropped.
closed_connections: mpsc::UnboundedSender<SocketAddr>,
/// Channel limits to prevent unlimited resource usage.
config: Config,
/// The address of the server connected to.
client_addr: SocketAddr,
/// Types the request and response.
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp, T> Drop for Channel<Req, Resp, T> {
fn drop(&mut self) {
trace!("[{}] Closing channel.", self.client_addr);
// Even in a bounded channel, each connection would have a guaranteed slot, so using
// an unbounded sender is actually no different. And, the bound is on the maximum number
// of open connections.
if self
.closed_connections
.unbounded_send(self.client_addr)
.is_err()
{
warn!(
"[{}] Failed to send closed connection message.",
self.client_addr
);
}
}
}
impl<Req, Resp, T> Channel<Req, Resp, T> {
unsafe_pinned!(transport: Fuse<T>);
}
impl<Req, Resp, T> Channel<Req, Resp, T>
where
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
Req: Send,
Resp: Send,
{
pub(crate) fn start_send(self: &mut Pin<&mut Self>, response: Response<Resp>) -> io::Result<()> {
self.transport().start_send(response)
}
pub(crate) fn poll_ready(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
self.transport().poll_ready(cx)
}
pub(crate) fn poll_flush(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
self.transport().poll_flush(cx)
}
pub(crate) fn poll_next(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<ClientMessage<Req>>>> {
self.transport().poll_next(cx)
}
/// Returns the address of the client connected to the channel.
pub fn client_addr(&self) -> &SocketAddr {
&self.client_addr
}
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
/// responses and resolves when the connection is closed.
pub fn respond_with<F, Fut>(self, f: F) -> impl Future<Output = ()>
where
F: FnMut(Context, Req) -> Fut + Send + 'static,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
Req: 'static,
Resp: 'static,
{
let (responses_tx, responses) = mpsc::channel(self.config.pending_response_buffer);
let responses = responses.fuse();
let peer = self.client_addr;
ClientHandler {
channel: self,
f,
pending_responses: responses,
responses_tx,
in_flight_requests: FnvHashMap::default(),
}.unwrap_or_else(move |e| {
info!("[{}] ClientHandler errored out: {}", peer, e);
})
}
}
#[derive(Debug)]
struct ClientHandler<Req, Resp, T, F> {
channel: Channel<Req, Resp, T>,
/// Responses waiting to be written to the wire.
pending_responses: Fuse<mpsc::Receiver<(Context, Response<Resp>)>>,
/// Handed out to request handlers to fan in responses.
responses_tx: mpsc::Sender<(Context, Response<Resp>)>,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
/// Request handler.
f: F,
}
impl<Req, Resp, T, F> ClientHandler<Req, Resp, T, F> {
unsafe_pinned!(channel: Channel<Req, Resp, T>);
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(Context, Response<Resp>)>>);
unsafe_pinned!(responses_tx: mpsc::Sender<(Context, Response<Resp>)>);
// For this to be safe, field f must be private, and code in this module must never
// construct PinMut<F>.
unsafe_unpinned!(f: F);
}
impl<Req, Resp, T, F, Fut> ClientHandler<Req, Resp, T, F>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
F: FnMut(Context, Req) -> Fut + Send + 'static,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
/// If at max in-flight requests, check that there's room to immediately write a throttled
/// response.
fn poll_ready_if_throttling(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
if self.in_flight_requests.len()
>= self.channel.config.max_in_flight_requests_per_connection
{
let peer = self.channel().client_addr;
while let Poll::Pending = self.channel().poll_ready(cx)? {
info!(
"[{}] In-flight requests at max ({}), and transport is not ready.",
peer,
self.in_flight_requests().len(),
);
try_ready!(self.channel().poll_flush(cx));
}
}
Poll::Ready(Ok(()))
}
fn pump_read(self: &mut Pin<&mut Self>, cx: &LocalWaker) -> Poll<Option<io::Result<()>>> {
ready!(self.poll_ready_if_throttling(cx)?);
Poll::Ready(match ready!(self.channel().poll_next(cx)?) {
Some(message) => {
match message.message {
ClientMessageKind::Request(request) => {
self.handle_request(message.trace_context, request)?;
}
ClientMessageKind::Cancel { request_id } => {
self.cancel_request(&message.trace_context, request_id);
}
}
Some(Ok(()))
}
None => {
trace!("[{}] Read half closed", self.channel.client_addr);
None
}
})
}
fn pump_write(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
read_half_closed: bool,
) -> Poll<Option<io::Result<()>>> {
match self.poll_next_response(cx)? {
Poll::Ready(Some((_, response))) => {
self.channel().start_send(response)?;
Poll::Ready(Some(Ok(())))
}
Poll::Ready(None) => {
// Shutdown can't be done before we finish pumping out remaining responses.
ready!(self.channel().poll_flush(cx)?);
Poll::Ready(None)
}
Poll::Pending => {
// No more requests to process, so flush any requests buffered in the transport.
ready!(self.channel().poll_flush(cx)?);
// Being here means there are no staged requests and all written responses are
// fully flushed. So, if the read half is closed and there are no in-flight
// requests, then we can close the write half.
if read_half_closed && self.in_flight_requests().is_empty() {
Poll::Ready(None)
} else {
Poll::Pending
}
}
}
}
fn poll_next_response(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<(Context, Response<Resp>)>>> {
// Ensure there's room to write a response.
while let Poll::Pending = self.channel().poll_ready(cx)? {
ready!(self.channel().poll_flush(cx)?);
}
let peer = self.channel().client_addr;
match ready!(self.pending_responses().poll_next(cx)) {
Some((ctx, response)) => {
if let Some(_) = self.in_flight_requests().remove(&response.request_id) {
self.in_flight_requests().compact(0.1);
}
trace!(
"[{}/{}] Staging response. In-flight requests = {}.",
ctx.trace_id(),
peer,
self.in_flight_requests().len(),
);
return Poll::Ready(Some(Ok((ctx, response))));
}
None => {
// This branch likely won't happen, since the ClientHandler is holding a Sender.
trace!("[{}] No new responses.", peer);
Poll::Ready(None)
}
}
}
fn handle_request(
self: &mut Pin<&mut Self>,
trace_context: trace::Context,
request: Request<Req>,
) -> io::Result<()> {
let request_id = request.id;
let peer = self.channel().client_addr;
let ctx = Context {
deadline: request.deadline,
trace_context,
};
let request = request.message;
if self.in_flight_requests().len()
>= self.channel().config.max_in_flight_requests_per_connection
{
debug!(
"[{}/{}] Client has reached in-flight request limit ({}/{}).",
ctx.trace_id(),
peer,
self.in_flight_requests().len(),
self.channel().config.max_in_flight_requests_per_connection
);
self.channel().start_send(Response {
request_id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
}),
})?;
return Ok(());
}
let deadline = ctx.deadline;
let timeout = deadline.as_duration();
trace!(
"[{}/{}] Received request with deadline {} (timeout {:?}).",
ctx.trace_id(),
peer,
format_rfc3339(deadline),
timeout,
);
let mut response_tx = self.responses_tx().clone();
let trace_id = *ctx.trace_id();
let response = self.f()(ctx.clone(), request);
let response = deadline_compat::Deadline::new(response, Instant::now() + timeout).then(
async move |result| {
let response = Response {
request_id,
message: match result {
Ok(message) => Ok(message),
Err(e) => Err(make_server_error(e, trace_id, peer, deadline)),
},
};
trace!("[{}/{}] Sending response.", trace_id, peer);
await!(response_tx.send((ctx, response)).unwrap_or_else(|_| ()));
},
);
let (abortable_response, abort_handle) = abortable(response);
crate::spawn(abortable_response.map(|_| ()))
.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
"Could not spawn response task. Is shutdown: {}",
e.is_shutdown()
),
)
})?;
self.in_flight_requests().insert(request_id, abort_handle);
Ok(())
}
fn cancel_request(self: &mut Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self.in_flight_requests().remove(&request_id) {
self.in_flight_requests().compact(0.1);
cancel_handle.abort();
let remaining = self.in_flight_requests().len();
trace!(
"[{}/{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
self.channel.client_addr,
remaining,
);
} else {
trace!(
"[{}/{}] Received cancellation, but response handler \
is already complete.",
trace_context.trace_id,
self.channel.client_addr
);
}
}
}
impl<Req, Resp, T, F, Fut> Future for ClientHandler<Req, Resp, T, F>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
F: FnMut(Context, Req) -> Fut + Send + 'static,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<io::Result<()>> {
trace!("[{}] ClientHandler::poll", self.channel.client_addr);
loop {
let read = self.pump_read(cx)?;
match (read, self.pump_write(cx, read == Poll::Ready(None))?) {
(Poll::Ready(None), Poll::Ready(None)) => {
info!("[{}] Client disconnected.", self.channel.client_addr);
return Poll::Ready(Ok(()));
}
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
trace!(
"[{}] read: {:?}, write: {:?}.",
self.channel.client_addr,
read,
write
)
}
(read, write) => {
trace!(
"[{}] read: {:?}, write: {:?} (not ready).",
self.channel.client_addr,
read,
write,
);
return Poll::Pending;
}
}
}
}
}
fn make_server_error(
e: timeout::Error<io::Error>,
trace_id: TraceId,
peer: SocketAddr,
deadline: SystemTime,
) -> ServerError {
if e.is_elapsed() {
debug!(
"[{}/{}] Response did not complete before deadline of {}s.",
trace_id,
peer,
format_rfc3339(deadline)
);
// No point in responding, since the client will have dropped the request.
ServerError {
kind: io::ErrorKind::TimedOut,
detail: Some(format!(
"Response did not complete before deadline of {}s.",
format_rfc3339(deadline)
)),
}
} else if e.is_timer() {
error!(
"[{}/{}] Response failed because of an issue with a timer: {}",
trace_id, peer, e
);
ServerError {
kind: io::ErrorKind::Other,
detail: Some(format!("{}", e)),
}
} else if e.is_inner() {
let e = e.into_inner().unwrap();
ServerError {
kind: e.kind(),
detail: Some(e.description().into()),
}
} else {
error!("[{}/{}] Unexpected response failure: {}", trace_id, peer, e);
ServerError {
kind: io::ErrorKind::Other,
detail: Some(format!("Server unexpectedly failed to respond: {}", e)),
}
}
}

View File

@@ -1,157 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Transports backed by in-memory channels.
use crate::Transport;
use futures::{channel::mpsc, task::{LocalWaker}, Poll, Sink, Stream};
use pin_utils::unsafe_pinned;
use std::pin::Pin;
use std::{
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
};
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
/// [`Sink`].
pub fn unbounded<SinkItem, Item>() -> (
UnboundedChannel<SinkItem, Item>,
UnboundedChannel<Item, SinkItem>,
) {
let (tx1, rx2) = mpsc::unbounded();
let (tx2, rx1) = mpsc::unbounded();
(
UnboundedChannel { tx: tx1, rx: rx1 },
UnboundedChannel { tx: tx2, rx: rx2 },
)
}
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
#[derive(Debug)]
pub struct UnboundedChannel<Item, SinkItem> {
rx: mpsc::UnboundedReceiver<Item>,
tx: mpsc::UnboundedSender<SinkItem>,
}
impl<Item, SinkItem> UnboundedChannel<Item, SinkItem> {
unsafe_pinned!(rx: mpsc::UnboundedReceiver<Item>);
unsafe_pinned!(tx: mpsc::UnboundedSender<SinkItem>);
}
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
type Item = Result<Item, io::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<Option<io::Result<Item>>> {
self.rx().poll_next(cx).map(|option| option.map(Ok))
}
}
impl<Item, SinkItem> Sink for UnboundedChannel<Item, SinkItem> {
type SinkItem = SinkItem;
type SinkError = io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<io::Result<()>> {
self.tx()
.poll_ready(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn start_send(mut self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.tx()
.start_send(item)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Result<(), Self::SinkError>> {
self.tx()
.poll_flush(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<io::Result<()>> {
self.tx()
.poll_close(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
}
impl<Item, SinkItem> Transport for UnboundedChannel<Item, SinkItem> {
type Item = Item;
type SinkItem = SinkItem;
fn peer_addr(&self) -> io::Result<SocketAddr> {
Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
}
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
}
}
#[cfg(test)]
mod tests {
use crate::{client::{self, Client}, context, server::{self, Handler, Server}, transport};
use futures::{prelude::*, stream, compat::TokioDefaultSpawner};
use log::trace;
use std::io;
#[test]
fn integration() {
let _ = env_logger::try_init();
crate::init(TokioDefaultSpawner);
let (client_channel, server_channel) = transport::channel::unbounded();
let server = Server::<String, u64>::new(server::Config::default())
.incoming(stream::once(future::ready(Ok(server_channel))))
.respond_with(|_ctx, request| {
future::ready(request.parse::<u64>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{:?} is not an int", request),
)
}))
});
let responses = async {
let mut client = await!(Client::new(client::Config::default(), client_channel))?;
let response1 = await!(client.call(context::current(), "123".into()));
let response2 = await!(client.call(context::current(), "abc".into()));
Ok::<_, io::Error>((response1, response2))
};
let (response1, response2) =
run_future(server.join(responses.unwrap_or_else(|e| panic!(e)))).1;
trace!("response1: {:?}, response2: {:?}", response1, response2);
assert!(response1.is_ok());
assert_eq!(response1.ok().unwrap(), 123);
assert!(response2.is_err());
assert_eq!(response2.err().unwrap().kind(), io::ErrorKind::InvalidInput);
}
fn run_future<F>(f: F) -> F::Output
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = futures::channel::oneshot::channel();
tokio::run(
f.map(|result| tx.send(result).unwrap_or_else(|_| unreachable!()))
.boxed()
.unit_error()
.compat(),
);
futures::executor::block_on(rx).unwrap()
}
}

View File

@@ -1,32 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a [`Transport`] trait as well as implementations.
//!
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`]
//! can be plugged in, using whatever protocol it wants.
use futures::prelude::*;
use std::{io, net::SocketAddr};
pub mod channel;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport
where
Self: Stream<Item = io::Result<<Self as Transport>::Item>>,
Self: Sink<SinkItem = <Self as Transport>::SinkItem, SinkError = io::Error>,
{
/// The type read off the transport.
type Item;
/// The type written to the transport.
type SinkItem;
/// The address of the remote peer this transport is in communication with.
fn peer_addr(&self) -> io::Result<SocketAddr>;
/// The address of the local half of this transport.
fn local_addr(&self) -> io::Result<SocketAddr>;
}

View File

@@ -1,69 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use futures::{
compat::{Compat01As03, Future01CompatExt},
prelude::*,
ready, task::{Poll, LocalWaker},
};
use pin_utils::unsafe_pinned;
use std::pin::Pin;
use std::time::Instant;
use tokio_timer::{timeout, Delay};
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct Deadline<T> {
future: T,
delay: Compat01As03<Delay>,
}
impl<T> Deadline<T> {
unsafe_pinned!(future: T);
unsafe_pinned!(delay: Compat01As03<Delay>);
/// Create a new `Deadline` that completes when `future` completes or when
/// `deadline` is reached.
pub fn new(future: T, deadline: Instant) -> Deadline<T> {
Deadline::new_with_delay(future, Delay::new(deadline))
}
pub(crate) fn new_with_delay(future: T, delay: Delay) -> Deadline<T> {
Deadline {
future,
delay: delay.compat(),
}
}
/// Gets a mutable reference to the underlying future in this deadline.
pub fn get_mut(&mut self) -> &mut T {
&mut self.future
}
}
impl<T> Future for Deadline<T>
where
T: TryFuture,
{
type Output = Result<T::Ok, timeout::Error<T::Error>>;
fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<Self::Output> {
// First, try polling the future
match self.future().try_poll(waker) {
Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)),
Poll::Pending => {}
Poll::Ready(Err(e)) => return Poll::Ready(Err(timeout::Error::inner(e))),
}
let delay = self.delay().poll_unpin(waker);
// Now check the timer
match ready!(delay) {
Ok(_) => Poll::Ready(Err(timeout::Error::elapsed())),
Err(e) => Poll::Ready(Err(timeout::Error::timer(e))),
}
}
}

View File

@@ -1,8 +1,6 @@
cargo-features = ["rename-dependency"]
[package]
name = "tarpc"
version = "0.13.0"
version = "0.21.0"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018"
license = "MIT"
@@ -15,24 +13,50 @@ readme = "../README.md"
description = "An RPC framework for Rust with a focus on ease of use."
[features]
serde1 = ["rpc/serde1", "serde", "serde/derive"]
default = []
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
tokio1 = []
serde-transport = ["tokio-serde", "tokio-util/codec"]
tcp = ["tokio/net", "tokio/stream"]
full = ["serde1", "tokio1", "serde-transport", "tcp"]
[badges]
travis-ci = { repository = "google/tarpc" }
[dependencies]
fnv = "1.0"
futures = "0.3"
humantime = "1.0"
log = "0.4"
serde = { optional = true, version = "1.0" }
tarpc-plugins = { path = "../plugins", version = "0.5.0" }
rpc = { package = "tarpc-lib", path = "../rpc", version = "0.1" }
[target.'cfg(not(test))'.dependencies]
futures-preview = "0.3.0-alpha.8"
pin-project = "0.4.17"
rand = "0.7"
tokio = { version = "0.2", features = ["time"] }
serde = { optional = true, version = "1.0", features = ["derive"] }
tokio-util = { optional = true, version = "0.2" }
tarpc-plugins = { path = "../plugins", version = "0.8" }
tokio-serde = { optional = true, version = "0.6" }
[dev-dependencies]
assert_matches = "1.0"
bytes = { version = "0.5", features = ["serde"] }
env_logger = "0.6"
futures = "0.3"
humantime = "1.0"
futures-preview = { version = "0.3.0-alpha.8", features = ["compat", "tokio-compat"] }
bincode-transport = { package = "tarpc-bincode-transport", version = "0.1", path = "../bincode-transport" }
env_logger = "0.5"
tokio = "0.1"
tokio-executor = "0.1"
log = "0.4"
pin-utils = "0.1.0-alpha"
tokio = { version = "0.2", features = ["full"] }
tokio-serde = { version = "0.6", features = ["json"] }
[[example]]
name = "server_calling_server"
required-features = ["full"]
[[example]]
name = "readme"
required-features = ["full"]
[[example]]
name = "pubsub"
required-features = ["full"]

1
tarpc/README.md Symbolic link
View File

@@ -0,0 +1 @@
../README.md

View File

@@ -4,46 +4,42 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
arbitrary_self_types,
pin,
futures_api,
await_macro,
async_await,
existential_type,
proc_macro_hygiene,
)]
use futures::{
future::{self, Ready},
prelude::*,
Future,
};
use rpc::{
client, context,
server::{self, Handler, Server},
};
use publisher::Publisher as _;
use std::{
collections::HashMap,
io,
net::SocketAddr,
pin::Pin,
sync::{Arc, Mutex},
thread,
time::Duration,
};
use subscriber::Subscriber as _;
use tarpc::{
client, context,
server::{self, Handler},
};
use tokio_serde::formats::Json;
pub mod subscriber {
tarpc::service! {
rpc receive(message: String);
#[tarpc::service]
pub trait Subscriber {
async fn receive(message: String);
}
}
pub mod publisher {
use std::net::SocketAddr;
tarpc::service! {
rpc broadcast(message: String);
rpc subscribe(id: u32, address: SocketAddr) -> Result<(), String>;
rpc unsubscribe(id: u32);
#[tarpc::service]
pub trait Publisher {
async fn broadcast(message: String);
async fn subscribe(id: u32, address: SocketAddr) -> Result<(), String>;
async fn unsubscribe(id: u32);
}
}
@@ -52,27 +48,26 @@ struct Subscriber {
id: u32,
}
impl subscriber::Service for Subscriber {
impl subscriber::Subscriber for Subscriber {
type ReceiveFut = Ready<()>;
fn receive(&self, _: context::Context, message: String) -> Self::ReceiveFut {
println!("{} received message: {}", self.id, message);
fn receive(self, _: context::Context, message: String) -> Self::ReceiveFut {
eprintln!("{} received message: {}", self.id, message);
future::ready(())
}
}
impl Subscriber {
async fn listen(id: u32, config: server::Config) -> io::Result<SocketAddr> {
let incoming = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = incoming.local_addr();
tokio_executor::spawn(
Server::new(config)
let incoming = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let addr = incoming.get_ref().local_addr();
tokio::spawn(
server::new(config)
.incoming(incoming)
.take(1)
.respond_with(subscriber::serve(Subscriber { id }))
.unit_error()
.boxed()
.compat()
.respond_with(Subscriber { id }.serve()),
);
Ok(addr)
}
@@ -80,7 +75,7 @@ impl Subscriber {
#[derive(Clone, Debug)]
struct Publisher {
clients: Arc<Mutex<HashMap<u32, subscriber::Client>>>,
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
}
impl Publisher {
@@ -91,101 +86,110 @@ impl Publisher {
}
}
impl publisher::Service for Publisher {
existential type BroadcastFut: Future<Output = ()>;
impl publisher::Publisher for Publisher {
type BroadcastFut = Pin<Box<dyn Future<Output = ()> + Send>>;
fn broadcast(&self, _: context::Context, message: String) -> Self::BroadcastFut {
async fn broadcast(clients: Arc<Mutex<HashMap<u32, subscriber::Client>>>, message: String) {
fn broadcast(self, _: context::Context, message: String) -> Self::BroadcastFut {
async fn broadcast(
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
message: String,
) {
let mut clients = clients.lock().unwrap().clone();
for client in clients.values_mut() {
// Ignore failing subscribers. In a real pubsub,
// you'd want to continually retry until subscribers
// ack.
let _ = await!(client.receive(context::current(), message.clone()));
let _ = client.receive(context::current(), message.clone()).await;
}
}
broadcast(self.clients.clone(), message)
broadcast(self.clients.clone(), message).boxed()
}
existential type SubscribeFut: Future<Output = Result<(), String>>;
type SubscribeFut = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
fn subscribe(&self, _: context::Context, id: u32, addr: SocketAddr) -> Self::SubscribeFut {
fn subscribe(self, _: context::Context, id: u32, addr: SocketAddr) -> Self::SubscribeFut {
async fn subscribe(
clients: Arc<Mutex<HashMap<u32, subscriber::Client>>>,
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
id: u32,
addr: SocketAddr,
) -> io::Result<()> {
let conn = await!(bincode_transport::connect(&addr))?;
let subscriber = await!(subscriber::new_stub(client::Config::default(), conn))?;
println!("Subscribing {}.", id);
let conn = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
let subscriber =
subscriber::SubscriberClient::new(client::Config::default(), conn).spawn()?;
eprintln!("Subscribing {}.", id);
clients.lock().unwrap().insert(id, subscriber);
Ok(())
}
subscribe(Arc::clone(&self.clients), id, addr).map_err(|e| e.to_string())
subscribe(Arc::clone(&self.clients), id, addr)
.map_err(|e| e.to_string())
.boxed()
}
existential type UnsubscribeFut: Future<Output = ()>;
type UnsubscribeFut = Pin<Box<dyn Future<Output = ()> + Send>>;
fn unsubscribe(&self, _: context::Context, id: u32) -> Self::UnsubscribeFut {
println!("Unsubscribing {}", id);
fn unsubscribe(self, _: context::Context, id: u32) -> Self::UnsubscribeFut {
eprintln!("Unsubscribing {}", id);
let mut clients = self.clients.lock().unwrap();
if let None = clients.remove(&id) {
if clients.remove(&id).is_none() {
eprintln!(
"Client {} not found. Existings clients: {:?}",
id, &*clients
);
}
future::ready(())
future::ready(()).boxed()
}
}
async fn run() -> io::Result<()> {
#[tokio::main]
async fn main() -> io::Result<()> {
env_logger::init();
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let publisher_addr = transport.local_addr();
tokio_executor::spawn(
Server::new(server::Config::default())
.incoming(transport)
let transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let publisher_addr = transport.get_ref().local_addr();
tokio::spawn(
transport
.take(1)
.respond_with(publisher::serve(Publisher::new()))
.unit_error()
.boxed()
.compat()
.map(server::BaseChannel::with_defaults)
.respond_with(Publisher::new().serve()),
);
let subscriber1 = await!(Subscriber::listen(0, server::Config::default()))?;
let subscriber2 = await!(Subscriber::listen(1, server::Config::default()))?;
let subscriber1 = Subscriber::listen(0, server::Config::default()).await?;
let subscriber2 = Subscriber::listen(1, server::Config::default()).await?;
let publisher_conn = bincode_transport::connect(&publisher_addr);
let publisher_conn = await!(publisher_conn)?;
let mut publisher = await!(publisher::new_stub(
client::Config::default(),
publisher_conn
))?;
let publisher_conn = tarpc::serde_transport::tcp::connect(publisher_addr, Json::default());
let publisher_conn = publisher_conn.await?;
let mut publisher =
publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?;
if let Err(e) = await!(publisher.subscribe(context::current(), 0, subscriber1))? {
if let Err(e) = publisher
.subscribe(context::current(), 0, subscriber1)
.await?
{
eprintln!("Couldn't subscribe subscriber 0: {}", e);
}
if let Err(e) = await!(publisher.subscribe(context::current(), 1, subscriber2))? {
if let Err(e) = publisher
.subscribe(context::current(), 1, subscriber2)
.await?
{
eprintln!("Couldn't subscribe subscriber 1: {}", e);
}
println!("Broadcasting...");
await!(publisher.broadcast(context::current(), "hello to all".to_string()))?;
await!(publisher.unsubscribe(context::current(), 1))?;
await!(publisher.broadcast(context::current(), "hi again".to_string()))?;
publisher
.broadcast(context::current(), "hello to all".to_string())
.await?;
publisher.unsubscribe(context::current(), 1).await?;
publisher
.broadcast(context::current(), "hi again".to_string())
.await?;
drop(publisher);
tokio::time::delay_for(Duration::from_millis(100)).await;
println!("Done.");
Ok(())
}
fn main() {
tokio::run(
run()
.boxed()
.map_err(|e| panic!(e))
.boxed()
.compat(),
);
thread::sleep(Duration::from_millis(100));
}

View File

@@ -4,88 +4,76 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
futures_api,
pin,
arbitrary_self_types,
await_macro,
async_await,
proc_macro_hygiene,
)]
use futures::{
future::{self, Ready},
prelude::*,
};
use rpc::{
client, context,
server::{self, Handler, Server},
};
use std::io;
use tarpc::{
client, context,
server::{BaseChannel, Channel},
};
use tokio_serde::formats::Json;
// This is the service definition. It looks a lot like a trait definition.
// It defines one RPC, hello, which takes one arg, name, and returns a String.
tarpc::service! {
rpc hello(name: String) -> String;
/// This is the service definition. It looks a lot like a trait definition.
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
#[tarpc::service]
pub trait World {
async fn hello(name: String) -> String;
}
// This is the type that implements the generated Service trait. It is the business logic
// and is used to start the server.
/// This is the type that implements the generated World trait. It is the business logic
/// and is used to start the server.
#[derive(Clone)]
struct HelloServer;
impl Service for HelloServer {
impl World for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
// an associated type representing the future output by the fn.
type HelloFut = Ready<String>;
fn hello(&self, _: context::Context, name: String) -> Self::HelloFut {
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
}
}
async fn run() -> io::Result<()> {
// bincode_transport is provided by the associated crate bincode-transport. It makes it easy
// to start up a serde-powered bincode serialization strategy over TCP.
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
#[tokio::main]
async fn main() -> io::Result<()> {
// tarpc_json_transport is provided by the associated crate json_transport. It makes it
// easy to start up a serde-powered JSON serialization strategy over TCP.
let mut transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default).await?;
let addr = transport.local_addr();
// The server is configured with the defaults.
let server = Server::new(server::Config::default())
// Server can listen on any type that implements the Transport trait.
.incoming(transport)
// Close the stream after the client connects
.take(1)
// serve is generated by the tarpc::service! macro. It takes as input any type implementing
// the generated Service trait.
.respond_with(serve(HelloServer));
let server = async move {
// For this example, we're just going to wait for one connection.
let client = transport.next().await.unwrap().unwrap();
tokio_executor::spawn(server.unit_error().boxed().compat());
// `Channel` is a trait representing a server-side connection. It is a trait to allow
// for some channels to be instrumented: for example, to track the number of open connections.
// BaseChannel is the most basic channel, simply wrapping a transport with no added
// functionality.
BaseChannel::with_defaults(client)
// serve_world is generated by the tarpc::service attribute. It takes as input any type
// implementing the generated World trait.
.respond_with(HelloServer.serve())
.execute()
.await;
};
tokio::spawn(server);
let transport = await!(bincode_transport::connect(&addr))?;
let transport = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
// new_stub is generated by the tarpc::service! macro. Like Server, it takes a config and any
// Transport as input, and returns a Client, also generated by the macro.
// by the service mcro.
let mut client = await!(new_stub(client::Config::default(), transport))?;
// WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that
// takes a config and any Transport as input.
let mut client = WorldClient::new(client::Config::default(), transport).spawn()?;
// The client has an RPC method for each RPC defined in tarpc::service!. It takes the same args
// as defined, with the addition of a Context, which is always the first arg. The Context
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = await!(client.hello(context::current(), "Stim".to_string()))?;
let hello = client.hello(context::current(), "Stim".to_string()).await?;
println!("{}", hello);
eprintln!("{}", hello);
Ok(())
}
fn main() {
tokio::run(
run()
.map_err(|e| eprintln!("Oh no: {}", e))
.boxed()
.compat(),
);
}

View File

@@ -4,38 +4,31 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
existential_type,
arbitrary_self_types,
pin,
futures_api,
await_macro,
async_await,
proc_macro_hygiene,
)]
use crate::{add::Service as AddService, double::Service as DoubleService};
use crate::{add::Add as AddService, double::Double as DoubleService};
use futures::{
future::{self, Ready},
prelude::*,
};
use rpc::{
use std::{io, pin::Pin};
use tarpc::{
client, context,
server::{self, Handler, Server},
server::{Handler, Server},
};
use std::io;
use tokio_serde::formats::Json;
pub mod add {
tarpc::service! {
#[tarpc::service]
pub trait Add {
/// Add two ints together.
rpc add(x: i32, y: i32) -> i32;
async fn add(x: i32, y: i32) -> i32;
}
}
pub mod double {
tarpc::service! {
#[tarpc::service]
pub trait Double {
/// 2 * x
rpc double(x: i32) -> Result<i32, String>;
async fn double(x: i32) -> Result<i32, String>;
}
}
@@ -45,67 +38,64 @@ struct AddServer;
impl AddService for AddServer {
type AddFut = Ready<i32>;
fn add(&self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
future::ready(x + y)
}
}
#[derive(Clone)]
struct DoubleServer {
add_client: add::Client,
add_client: add::AddClient,
}
impl DoubleService for DoubleServer {
existential type DoubleFut: Future<Output = Result<i32, String>> + Send;
type DoubleFut = Pin<Box<dyn Future<Output = Result<i32, String>> + Send>>;
fn double(&self, _: context::Context, x: i32) -> Self::DoubleFut {
async fn double(mut client: add::Client, x: i32) -> Result<i32, String> {
let result = await!(client.add(context::current(), x, x));
result.map_err(|e| e.to_string())
fn double(self, _: context::Context, x: i32) -> Self::DoubleFut {
async fn double(mut client: add::AddClient, x: i32) -> Result<i32, String> {
client
.add(context::current(), x, x)
.await
.map_err(|e| e.to_string())
}
double(self.add_client.clone(), x)
double(self.add_client.clone(), x).boxed()
}
}
async fn run() -> io::Result<()> {
let add_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = add_listener.local_addr();
let add_server = Server::new(server::Config::default())
#[tokio::main]
async fn main() -> io::Result<()> {
env_logger::init();
let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let addr = add_listener.get_ref().local_addr();
let add_server = Server::default()
.incoming(add_listener)
.take(1)
.respond_with(add::serve(AddServer));
tokio_executor::spawn(add_server.unit_error().boxed().compat());
.respond_with(AddServer.serve());
tokio::spawn(add_server);
let to_add_server = await!(bincode_transport::connect(&addr))?;
let add_client = await!(add::new_stub(client::Config::default(), to_add_server))?;
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?;
let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = double_listener.local_addr();
let double_server = rpc::Server::new(server::Config::default())
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let addr = double_listener.get_ref().local_addr();
let double_server = tarpc::Server::default()
.incoming(double_listener)
.take(1)
.respond_with(double::serve(DoubleServer { add_client }));
tokio_executor::spawn(double_server.unit_error().boxed().compat());
.respond_with(DoubleServer { add_client }.serve());
tokio::spawn(double_server);
let to_double_server = await!(bincode_transport::connect(&addr))?;
let mut double_client = await!(double::new_stub(
client::Config::default(),
to_double_server
))?;
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
let mut double_client =
double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;
for i in 1..=5 {
println!("{:?}", await!(double_client.double(context::current(), i))?);
eprintln!("{:?}", double_client.double(context::current(), i).await?);
}
Ok(())
}
fn main() {
env_logger::init();
tokio::run(
run()
.map_err(|e| panic!(e))
.boxed()
.compat(),
);
}

View File

@@ -1 +1 @@
edition = "Edition2018"
edition = "2018"

View File

@@ -4,10 +4,17 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! [![Latest Version](https://img.shields.io/crates/v/tarpc.svg)](https://crates.io/crates/tarpc)
//! [![Join the chat at https://gitter.im/tarpc/Lobby](https://badges.gitter.im/tarpc/Lobby.svg)](https://gitter.im/tarpc/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
//!
//! *Disclaimer*: This is not an official Google product.
//!
//! tarpc is an RPC framework for rust with a focus on ease of use. Defining a
//! service can be done in just a few lines of code, and most of the boilerplate of
//! writing a server is taken care of for you.
//!
//! [Documentation](https://docs.rs/crate/tarpc/)
//!
//! ## What is an RPC framework?
//! "RPC" stands for "Remote Procedure Call," a function call where the work of
//! producing the return value is being done somewhere else. When an rpc function is
@@ -21,116 +28,264 @@
//!
//! tarpc differentiates itself from other RPC frameworks by defining the schema in code,
//! rather than in a separate language such as .proto. This means there's no separate compilation
//! process, and no cognitive context switching between different languages. Additionally, it
//! works with the community-backed library serde: any serde-serializable type can be used as
//! arguments to tarpc fns.
//! process, and no context switching between different languages.
//!
//! Some other features of tarpc:
//! - Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
//! used as a transport to connect the client and server.
//! - `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
//! - Cascading cancellation: dropping a request will send a cancellation message to the server.
//! The server will cease any unfinished work on the request, subsequently cancelling any of its
//! own requests, repeating for the entire chain of transitive dependencies.
//! - Configurable deadlines and deadline propagation: request deadlines default to 10s if
//! unspecified. The server will automatically cease work when the deadline has passed. Any
//! requests sent by the server that use the request context will propagate the request deadline.
//! For example, if a server is handling a request with a 10s deadline, does 2s of work, then
//! sends a request to another server, that server will see an 8s deadline.
//! - Serde serialization: enabling the `serde1` Cargo feature will make service requests and
//! responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
//! be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
//!
//! ## Usage
//! Add to your `Cargo.toml` dependencies:
//!
//! ```toml
//! tarpc = "0.21.0"
//! ```
//!
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
//! These generated types make it easy and ergonomic to write servers with less boilerplate.
//! Simply implement the generated service trait, and you're off to the races!
//!
//! ## Example
//!
//! Here's a small service.
//! For this example, in addition to tarpc, also add two other dependencies to
//! your `Cargo.toml`:
//!
//! ```toml
//! futures = "0.3"
//! tokio = "0.2"
//! ```
//!
//! In the following example, we use an in-process channel for communication between
//! client and server. In real code, you will likely communicate over the network.
//! For a more real-world example, see [example-service](example-service).
//!
//! First, let's set up the dependencies and service definition.
//!
//! ```rust
//! #![feature(futures_api, pin, arbitrary_self_types, await_macro, async_await, proc_macro_hygiene)]
//!
//! # extern crate futures;
//!
//! use futures::{
//! compat::TokioDefaultSpawner,
//! future::{self, Ready},
//! prelude::*,
//! };
//! use tarpc::{
//! client, context,
//! server::{self, Handler, Server},
//! server::{self, Handler},
//! };
//! use std::io;
//!
//! // This is the service definition. It looks a lot like a trait definition.
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! tarpc::service! {
//! #[tarpc::service]
//! trait World {
//! /// Returns a greeting for name.
//! rpc hello(name: String) -> String;
//! async fn hello(name: String) -> String;
//! }
//! ```
//!
//! // This is the type that implements the generated Service trait. It is the business logic
//! This service definition generates a trait called `World`. Next we need to
//! implement it for our Server struct.
//!
//! ```rust
//! # extern crate futures;
//! # use futures::{
//! # future::{self, Ready},
//! # prelude::*,
//! # };
//! # use tarpc::{
//! # client, context,
//! # server::{self, Handler},
//! # };
//! # use std::io;
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! # #[tarpc::service]
//! # trait World {
//! # /// Returns a greeting for name.
//! # async fn hello(name: String) -> String;
//! # }
//! // This is the type that implements the generated World trait. It is the business logic
//! // and is used to start the server.
//! #[derive(Clone)]
//! struct HelloServer;
//!
//! impl Service for HelloServer {
//! impl World for HelloServer {
//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
//! // an associated type representing the future output by the fn.
//!
//! type HelloFut = Ready<String>;
//!
//! fn hello(&self, _: context::Context, name: String) -> Self::HelloFut {
//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
//! future::ready(format!("Hello, {}!", name))
//! }
//! }
//! ```
//!
//! async fn run() -> io::Result<()> {
//! // bincode_transport is provided by the associated crate bincode-transport. It makes it easy
//! // to start up a serde-powered bincode serialization strategy over TCP.
//! let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
//! let addr = transport.local_addr();
//! Lastly let's write our `main` that will start the server. While this example uses an
//! [in-process channel](rpc::transport::channel), tarpc also ships a generic [`serde_transport`]
//! behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
//! available behind the `tcp` feature.
//!
//! // The server is configured with the defaults.
//! let server = Server::new(server::Config::default())
//! // Server can listen on any type that implements the Transport trait.
//! .incoming(transport)
//! // Close the stream after the client connects
//! .take(1)
//! // serve is generated by the service! macro. It takes as input any type implementing
//! // the generated Service trait.
//! .respond_with(serve(HelloServer));
//! ```rust
//! # extern crate futures;
//! # use futures::{
//! # future::{self, Ready},
//! # prelude::*,
//! # };
//! # use tarpc::{
//! # client, context,
//! # server::{self, Handler},
//! # };
//! # use std::io;
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! # #[tarpc::service]
//! # trait World {
//! # /// Returns a greeting for name.
//! # async fn hello(name: String) -> String;
//! # }
//! # // This is the type that implements the generated World trait. It is the business logic
//! # // and is used to start the server.
//! # #[derive(Clone)]
//! # struct HelloServer;
//! # impl World for HelloServer {
//! # // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
//! # // an associated type representing the future output by the fn.
//! # type HelloFut = Ready<String>;
//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
//! # future::ready(format!("Hello, {}!", name))
//! # }
//! # }
//! #[tokio::main]
//! async fn main() -> io::Result<()> {
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
//!
//! tokio_executor::spawn(server.unit_error().boxed().compat());
//! let server = server::new(server::Config::default())
//! // incoming() takes a stream of transports such as would be returned by
//! // TcpListener::incoming (but a stream instead of an iterator).
//! .incoming(stream::once(future::ready(server_transport)))
//! .respond_with(HelloServer.serve());
//!
//! let transport = await!(bincode_transport::connect(&addr))?;
//! tokio::spawn(server);
//!
//! // new_stub is generated by the service! macro. Like Server, it takes a config and any
//! // Transport as input, and returns a Client, also generated by the macro.
//! // by the service mcro.
//! let mut client = await!(new_stub(client::Config::default(), transport))?;
//! // WorldClient is generated by the macro. It has a constructor `new` that takes a config and
//! // any Transport as input
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
//!
//! // The client has an RPC method for each RPC defined in service!. It takes the same args
//! // as defined, with the addition of a Context, which is always the first arg. The Context
//! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
//! // args as defined, with the addition of a Context, which is always the first arg. The Context
//! // specifies a deadline and trace information which can be helpful in debugging requests.
//! let hello = await!(client.hello(context::current(), "Stim".to_string()))?;
//! let hello = client.hello(context::current(), "Stim".to_string()).await?;
//!
//! println!("{}", hello);
//!
//! Ok(())
//! }
//!
//! fn main() {
//! tarpc::init(TokioDefaultSpawner);
//! tokio::run(run()
//! .map_err(|e| eprintln!("Oh no: {}", e))
//! .boxed()
//! .compat(),
//! );
//! }
//! ```
//!
//! ## Service Documentation
//!
//! Use `cargo doc` as you normally would to see the documentation created for all
//! items expanded by a `service!` invocation.
#![deny(missing_docs)]
#![allow(clippy::type_complexity)]
#![deny(missing_docs, missing_debug_implementations)]
#![feature(
futures_api,
pin,
await_macro,
async_await,
decl_macro,
)]
#![cfg_attr(test, feature(proc_macro_hygiene, arbitrary_self_types))]
#[doc(hidden)]
pub use futures;
pub mod rpc;
pub use rpc::*;
#[cfg(feature = "serde")]
#[doc(hidden)]
pub use serde;
#[doc(hidden)]
pub use tarpc_plugins::*;
/// Provides the macro used for constructing rpc services and client stubs.
#[macro_use]
mod macros;
#[cfg(feature = "serde-transport")]
pub mod serde_transport;
pub mod trace;
/// The main macro that creates RPC services.
///
/// Rpc methods are specified, mirroring trait syntax:
///
/// ```
/// #[tarpc::service]
/// trait Service {
/// /// Say hello
/// async fn hello(name: String) -> String;
/// }
/// ```
///
/// Attributes can be attached to each rpc. These attributes
/// will then be attached to the generated service traits'
/// corresponding `fn`s, as well as to the client stubs' RPCs.
///
/// The following items are expanded in the enclosing module:
///
/// * `trait Service` -- defines the RPC service.
/// * `fn serve` -- turns a service impl into a request handler.
/// * `Client` -- a client stub with a fn for each RPC.
/// * `fn new_stub` -- creates a new Client stub.
pub use tarpc_plugins::service;
/// A utility macro that can be used for RPC server implementations.
///
/// Syntactic sugar to make using async functions in the server implementation
/// easier. It does this by rewriting code like this, which would normally not
/// compile because async functions are disallowed in trait implementations:
///
/// ```rust
/// # use tarpc::context;
/// # use std::net::SocketAddr;
/// #[tarpc::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc::server]
/// impl World for HelloServer {
/// async fn hello(self, _: context::Context, name: String) -> String {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// }
/// }
/// ```
///
/// Into code like this, which matches the service trait definition:
///
/// ```rust
/// # use tarpc::context;
/// # use std::pin::Pin;
/// # use futures::Future;
/// # use std::net::SocketAddr;
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// impl World for HelloServer {
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
///
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
/// + Send>> {
/// Box::pin(async move {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// })
/// }
/// }
/// ```
///
/// Note that this won't touch functions unless they have been annotated with
/// `async`, meaning that this should not break existing code.
pub use tarpc_plugins::server;

View File

@@ -1,364 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#[cfg(feature = "serde")]
#[doc(hidden)]
#[macro_export]
macro_rules! add_serde_if_enabled {
($(#[$attr:meta])* -- $i:item) => {
$(#[$attr])*
#[derive($crate::serde::Serialize, $crate::serde::Deserialize)]
$i
}
}
#[cfg(not(feature = "serde"))]
#[doc(hidden)]
#[macro_export]
macro_rules! add_serde_if_enabled {
($(#[$attr:meta])* -- $i:item) => {
$(#[$attr])*
$i
}
}
/// The main macro that creates RPC services.
///
/// Rpc methods are specified, mirroring trait syntax:
///
/// ```
/// # #![feature(await_macro, pin, arbitrary_self_types, async_await, futures_api, proc_macro_hygiene)]
/// # fn main() {}
/// # tarpc::service! {
/// /// Say hello
/// rpc hello(name: String) -> String;
/// # }
/// ```
///
/// Attributes can be attached to each rpc. These attributes
/// will then be attached to the generated service traits'
/// corresponding `fn`s, as well as to the client stubs' RPCs.
///
/// The following items are expanded in the enclosing module:
///
/// * `trait Service` -- defines the RPC service.
/// * `fn serve` -- turns a service impl into a request handler.
/// * `Client` -- a client stub with a fn for each RPC.
/// * `fn new_stub` -- creates a new Client stub.
///
#[macro_export]
macro_rules! service {
// Entry point
(
$(
$(#[$attr:meta])*
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) $(-> $out:ty)*;
)*
) => {
$crate::service! {{
$(
$(#[$attr])*
rpc $fn_name( $( $arg : $in_ ),* ) $(-> $out)*;
)*
}}
};
// Pattern for when the next rpc has an implicit unit return type.
(
{
$(#[$attr:meta])*
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* );
$( $unexpanded:tt )*
}
$( $expanded:tt )*
) => {
$crate::service! {
{ $( $unexpanded )* }
$( $expanded )*
$(#[$attr])*
rpc $fn_name( $( $arg : $in_ ),* ) -> ();
}
};
// Pattern for when the next rpc has an explicit return type.
(
{
$(#[$attr:meta])*
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty;
$( $unexpanded:tt )*
}
$( $expanded:tt )*
) => {
$crate::service! {
{ $( $unexpanded )* }
$( $expanded )*
$(#[$attr])*
rpc $fn_name( $( $arg : $in_ ),* ) -> $out;
}
};
// Pattern for when all return types have been expanded
(
{ } // none left to expand
$(
$(#[$attr:meta])*
rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty;
)*
) => {
$crate::add_serde_if_enabled! {
#[derive(Debug)]
#[doc(hidden)]
#[allow(non_camel_case_types, unused)]
--
pub enum Request__ {
$(
$fn_name{ $($arg: $in_,)* }
),*
}
}
$crate::add_serde_if_enabled! {
#[derive(Debug)]
#[doc(hidden)]
#[allow(non_camel_case_types, unused)]
--
pub enum Response__ {
$(
$fn_name($out)
),*
}
}
// TODO: proc_macro can't currently parse $crate, so this needs to be imported for the
// usage of snake_to_camel! to work.
use $crate::futures::Future as Future__;
/// Defines the RPC service. The additional trait bounds are required so that services can
/// multiplex requests across multiple tasks, potentially on multiple threads.
pub trait Service: Clone + Send + 'static {
$(
$crate::snake_to_camel! {
/// The type of future returned by `{}`.
type $fn_name: Future__<Output = $out> + Send;
}
$(#[$attr])*
fn $fn_name(&self, ctx: $crate::context::Context, $($arg:$in_),*) -> $crate::ty_snake_to_camel!(Self::$fn_name);
)*
}
// TODO: use an existential type instead of this when existential types work.
#[allow(non_camel_case_types)]
pub enum Response<S: Service> {
$(
$fn_name($crate::ty_snake_to_camel!(<S as Service>::$fn_name)),
)*
}
impl<S: Service> ::std::fmt::Debug for Response<S> {
fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
fmt.debug_struct("Response").finish()
}
}
impl<S: Service> ::std::future::Future for Response<S> {
type Output = ::std::io::Result<Response__>;
fn poll(self: ::std::pin::Pin<&mut Self>, waker: &::std::task::LocalWaker)
-> ::std::task::Poll<::std::io::Result<Response__>>
{
unsafe {
match ::std::pin::Pin::get_mut_unchecked(self) {
$(
Response::$fn_name(resp) =>
::std::pin::Pin::new_unchecked(resp)
.poll(waker)
.map(Response__::$fn_name)
.map(Ok),
)*
}
}
}
}
/// Returns a serving function to use with rpc::server::Server.
pub fn serve<S: Service>(service: S)
-> impl FnMut($crate::context::Context, Request__) -> Response<S> + Send + 'static + Clone {
move |ctx, req| {
match req {
$(
Request__::$fn_name{ $($arg,)* } => {
let resp = Service::$fn_name(&mut service.clone(), ctx, $($arg),*);
Response::$fn_name(resp)
}
)*
}
}
}
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
pub struct Client($crate::client::Client<Request__, Response__>);
/// Returns a new client stub that sends requests over the given transport.
pub async fn new_stub<T>(config: $crate::client::Config, transport: T)
-> ::std::io::Result<Client>
where
T: $crate::Transport<
Item = $crate::Response<Response__>,
SinkItem = $crate::ClientMessage<Request__>> + Send,
{
Ok(Client(await!($crate::client::Client::new(config, transport))?))
}
impl Client {
$(
#[allow(unused)]
$(#[$attr])*
pub fn $fn_name(&mut self, ctx: $crate::context::Context, $($arg: $in_),*)
-> impl ::std::future::Future<Output = ::std::io::Result<$out>> + '_ {
let request__ = Request__::$fn_name { $($arg,)* };
let resp = self.0.call(ctx, request__);
async move {
match await!(resp)? {
Response__::$fn_name(msg__) => ::std::result::Result::Ok(msg__),
_ => unreachable!(),
}
}
}
)*
}
}
}
// allow dead code; we're just testing that the macro expansion compiles
#[allow(dead_code)]
#[cfg(test)]
mod syntax_test {
service! {
#[deny(warnings)]
#[allow(non_snake_case)]
rpc TestCamelCaseDoesntConflict();
rpc hello() -> String;
#[doc="attr"]
rpc attr(s: String) -> String;
rpc no_args_no_return();
rpc no_args() -> ();
rpc one_arg(foo: String) -> i32;
rpc two_args_no_return(bar: String, baz: u64);
rpc two_args(bar: String, baz: u64) -> String;
rpc no_args_ret_error() -> i32;
rpc one_arg_ret_error(foo: String) -> String;
rpc no_arg_implicit_return_error();
#[doc="attr"]
rpc one_arg_implicit_return_error(foo: String);
}
}
#[cfg(test)]
mod functional_test {
use futures::{
compat::TokioDefaultSpawner,
future::{ready, Ready},
prelude::*,
};
use rpc::{
client, context,
server::{self, Handler},
transport::channel,
};
use std::io;
use tokio::runtime::current_thread;
service! {
rpc add(x: i32, y: i32) -> i32;
rpc hey(name: String) -> String;
}
#[derive(Clone)]
struct Server;
impl Service for Server {
type AddFut = Ready<i32>;
fn add(&self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
ready(x + y)
}
type HeyFut = Ready<String>;
fn hey(&self, _: context::Context, name: String) -> Self::HeyFut {
ready(format!("Hey, {}.", name))
}
}
#[test]
fn sequential() {
let _ = env_logger::try_init();
rpc::init(TokioDefaultSpawner);
let test = async {
let (tx, rx) = channel::unbounded();
tokio_executor::spawn(
rpc::Server::new(server::Config::default())
.incoming(stream::once(ready(Ok(rx))))
.respond_with(serve(Server))
.unit_error()
.boxed()
.compat()
);
let mut client = await!(new_stub(client::Config::default(), tx))?;
assert_eq!(3, await!(client.add(context::current(), 1, 2))?);
assert_eq!(
"Hey, Tim.",
await!(client.hey(context::current(), "Tim".to_string()))?
);
Ok::<_, io::Error>(())
}
.map_err(|e| panic!(e.to_string()));
current_thread::block_on_all(test.boxed().compat()).unwrap();
}
#[test]
fn concurrent() {
let _ = env_logger::try_init();
rpc::init(TokioDefaultSpawner);
let test = async {
let (tx, rx) = channel::unbounded();
tokio_executor::spawn(
rpc::Server::new(server::Config::default())
.incoming(stream::once(ready(Ok(rx))))
.respond_with(serve(Server))
.unit_error()
.boxed()
.compat()
);
let client = await!(new_stub(client::Config::default(), tx))?;
let mut c = client.clone();
let req1 = c.add(context::current(), 1, 2);
let mut c = client.clone();
let req2 = c.add(context::current(), 3, 4);
let mut c = client.clone();
let req3 = c.hey(context::current(), "Tim".to_string());
assert_eq!(3, await!(req1)?);
assert_eq!(7, await!(req2)?);
assert_eq!("Hey, Tim.", await!(req3)?);
Ok::<_, io::Error>(())
}
.map_err(|e| panic!("test failed: {}", e));
current_thread::block_on_all(test.boxed().compat()).unwrap();
}
}

View File

@@ -0,0 +1,896 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
context,
trace::SpanId,
util::{Compact, TimeUntil},
ClientMessage, PollIo, Request, Response, Transport,
};
use fnv::FnvHashMap;
use futures::{
channel::{mpsc, oneshot},
prelude::*,
ready,
stream::Fuse,
task::*,
};
use log::{debug, info, trace};
use pin_project::{pin_project, pinned_drop};
use std::{
io,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use super::{Config, NewClient};
/// Handles communication from the client to request dispatch.
#[derive(Debug)]
pub struct Channel<Req, Resp> {
to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
/// Channel to send a cancel message to the dispatcher.
cancellation: RequestCancellation,
/// The ID to use for the next request to stage.
next_request_id: Arc<AtomicU64>,
}
impl<Req, Resp> Clone for Channel<Req, Resp> {
fn clone(&self) -> Self {
Self {
to_dispatch: self.to_dispatch.clone(),
cancellation: self.cancellation.clone(),
next_request_id: self.next_request_id.clone(),
}
}
}
/// A future returned by [`Channel::send`] that resolves to a server response.
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
struct Send<'a, Req, Resp> {
#[pin]
fut: MapOkDispatchResponse<SendMapErrConnectionReset<'a, Req, Resp>, Resp>,
}
type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset<
futures::sink::Send<'a, mpsc::Sender<DispatchRequest<Req, Resp>>, DispatchRequest<Req, Resp>>,
>;
impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
type Output = io::Result<DispatchResponse<Resp>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.as_mut().project().fut.poll(cx)
}
}
/// A future returned by [`Channel::call`] that resolves to a server response.
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct Call<'a, Req, Resp> {
#[pin]
fut: tokio::time::Timeout<AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>>,
}
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
type Output = io::Result<Resp>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let resp = ready!(self.as_mut().project().fut.poll(cx));
Poll::Ready(match resp {
Ok(resp) => resp,
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
io::ErrorKind::TimedOut,
"Client dropped expired request.".to_string(),
)),
})
}
}
impl<Req, Resp> Channel<Req, Resp> {
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves when the request is sent (not when the response is received).
fn send(&mut self, mut ctx: context::Context, request: Req) -> Send<Req, Resp> {
// Convert the context to the call context.
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
let (response_completion, response) = oneshot::channel();
let cancellation = self.cancellation.clone();
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
Send {
fut: MapOkDispatchResponse::new(
MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest {
ctx,
request_id,
request,
response_completion,
})),
DispatchResponse {
response,
complete: false,
request_id,
cancellation,
ctx,
},
),
}
}
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves to the response.
pub fn call(&mut self, ctx: context::Context, request: Req) -> Call<Req, Resp> {
let timeout = ctx.deadline.time_until();
trace!(
"[{}] Queuing request with timeout {:?}.",
ctx.trace_id(),
timeout,
);
Call {
fut: tokio::time::timeout(timeout, AndThenIdent::new(self.send(ctx, request))),
}
}
}
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
#[pin_project(PinnedDrop)]
#[derive(Debug)]
struct DispatchResponse<Resp> {
response: oneshot::Receiver<Response<Resp>>,
ctx: context::Context,
complete: bool,
cancellation: RequestCancellation,
request_id: u64,
}
impl<Resp> Future for DispatchResponse<Resp> {
type Output = io::Result<Resp>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
let resp = ready!(self.response.poll_unpin(cx));
self.complete = true;
Poll::Ready(match resp {
Ok(resp) => Ok(resp.message?),
Err(oneshot::Canceled) => {
// The oneshot is Canceled when the dispatch task ends. In that case,
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Err(io::Error::from(io::ErrorKind::ConnectionReset))
}
})
}
}
// Cancels the request when dropped, if not already complete.
#[pinned_drop]
impl<Resp> PinnedDrop for DispatchResponse<Resp> {
fn drop(mut self: Pin<&mut Self>) {
if !self.complete {
// The receiver needs to be closed to handle the edge case that the request has not
// yet been received by the dispatch task. It is possible for the cancel message to
// arrive before the request itself, in which case the request could get stuck in the
// dispatch map forever if the server never responds (e.g. if the server dies while
// responding). Even if the server does respond, it will have unnecessarily done work
// for a client no longer waiting for a response. To avoid this, the dispatch task
// checks if the receiver is closed before inserting the request in the map. By
// closing the receiver before sending the cancel message, it is guaranteed that if the
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self.response.close();
let request_id = self.request_id;
self.cancellation.cancel(request_id);
}
}
}
/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
/// channel.
pub fn new<Req, Resp, C>(
config: Config,
transport: C,
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
let canceled_requests = canceled_requests.fuse();
NewClient {
client: Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicU64::new(0)),
},
dispatch: RequestDispatch {
config,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
pending_requests: pending_requests.fuse(),
},
}
}
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
#[pin_project]
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
/// Writes requests to the wire and reads responses off the wire.
#[pin]
transport: Fuse<C>,
/// Requests waiting to be written to the wire.
#[pin]
pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
/// Requests that were dropped.
#[pin]
canceled_requests: Fuse<CanceledRequests>,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
Poll::Ready(
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
Some(response) => {
self.complete(response);
Some(Ok(()))
}
None => None,
},
)
}
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
enum ReceiverStatus {
NotReady,
Closed,
}
let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
Poll::Ready(Some(dispatch_request)) => {
self.as_mut().write_request(dispatch_request)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
Poll::Ready(Some((context, request_id))) => {
self.as_mut().write_cancel(context, request_id)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
match (pending_requests_status, canceled_requests_status) {
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
ready!(self.as_mut().project().transport.poll_flush(cx)?);
Poll::Ready(None)
}
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self.as_mut().project().transport.poll_flush(cx)?);
// Even if we fully-flush, we return Pending, because we have no more requests
// or cancellations right now.
Poll::Pending
}
}
}
/// Yields the next pending request, if one is ready to be sent.
fn poll_next_request(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<DispatchRequest<Req, Resp>> {
if self.as_mut().project().in_flight_requests.len() >= self.config.max_in_flight_requests {
info!(
"At in-flight request capacity ({}/{}).",
self.as_mut().project().in_flight_requests.len(),
self.config.max_in_flight_requests
);
// No need to schedule a wakeup, because timers and responses are responsible
// for clearing out in-flight requests.
return Poll::Pending;
}
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
ready!(self.as_mut().project().transport.poll_flush(cx)?);
}
loop {
match ready!(self.as_mut().project().pending_requests.poll_next_unpin(cx)) {
Some(request) => {
if request.response_completion.is_canceled() {
trace!(
"[{}] Request canceled before being sent.",
request.ctx.trace_id()
);
continue;
}
return Poll::Ready(Some(Ok(request)));
}
None => return Poll::Ready(None),
}
}
}
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
fn poll_next_cancellation(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<(context::Context, u64)> {
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
ready!(self.as_mut().project().transport.poll_flush(cx)?);
}
loop {
let cancellation = self
.as_mut()
.project()
.canceled_requests
.poll_next_unpin(cx);
match ready!(cancellation) {
Some(request_id) => {
if let Some(in_flight_data) = self
.as_mut()
.project()
.in_flight_requests
.remove(&request_id)
{
self.as_mut().project().in_flight_requests.compact(0.1);
debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
}
}
None => return Poll::Ready(None),
}
}
}
fn write_request(
mut self: Pin<&mut Self>,
dispatch_request: DispatchRequest<Req, Resp>,
) -> io::Result<()> {
let request_id = dispatch_request.request_id;
let request = ClientMessage::Request(Request {
id: request_id,
message: dispatch_request.request,
context: context::Context {
deadline: dispatch_request.ctx.deadline,
trace_context: dispatch_request.ctx.trace_context,
},
});
self.as_mut().project().transport.start_send(request)?;
self.as_mut().project().in_flight_requests.insert(
request_id,
InFlightData {
ctx: dispatch_request.ctx,
response_completion: dispatch_request.response_completion,
},
);
Ok(())
}
fn write_cancel(
mut self: Pin<&mut Self>,
context: context::Context,
request_id: u64,
) -> io::Result<()> {
let trace_id = *context.trace_id();
let cancel = ClientMessage::Cancel {
trace_context: context.trace_context,
request_id,
};
self.as_mut().project().transport.start_send(cancel)?;
trace!("[{}] Cancel message sent.", trace_id);
Ok(())
}
/// Sends a server response to the client task that initiated the associated request.
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
if let Some(in_flight_data) = self
.as_mut()
.project()
.in_flight_requests
.remove(&response.request_id)
{
self.as_mut().project().in_flight_requests.compact(0.1);
trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
let _ = in_flight_data.response_completion.send(response);
return true;
}
debug!(
"No in-flight request found for request_id = {}.",
response.request_id
);
// If the response completion was absent, then the request was already canceled.
false
}
}
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
loop {
match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
(read, Poll::Ready(None)) => {
if self.as_mut().project().in_flight_requests.is_empty() {
info!("Shutdown: write half closed, and no requests in flight.");
return Poll::Ready(Ok(()));
}
info!(
"Shutdown: write half closed, and {} requests in flight.",
self.as_mut().project().in_flight_requests.len()
);
match read {
Poll::Ready(Some(())) => continue,
_ => return Poll::Pending,
}
}
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
_ => return Poll::Pending,
}
}
}
}
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
/// the lifecycle of the request.
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
ctx: context::Context,
request_id: u64,
request: Req,
response_completion: oneshot::Sender<Response<Resp>>,
}
#[derive(Debug)]
struct InFlightData<Resp> {
ctx: context::Context,
response_completion: oneshot::Sender<Response<Resp>>,
}
/// Sends request cancellation signals.
#[derive(Debug, Clone)]
struct RequestCancellation(mpsc::UnboundedSender<u64>);
/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests. Additionally, each request has a clone
// of the sender, so the bounded channel would have the same behavior,
// since it guarantees a slot.
let (tx, rx) = mpsc::unbounded();
(RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
/// Cancels the request with ID `request_id`.
fn cancel(&mut self, request_id: u64) {
let _ = self.0.unbounded_send(request_id);
}
}
impl Stream for CanceledRequests {
type Item = u64;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.0.poll_next_unpin(cx)
}
}
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
struct MapErrConnectionReset<Fut> {
#[pin]
future: Fut,
finished: Option<()>,
}
impl<Fut> MapErrConnectionReset<Fut> {
fn new(future: Fut) -> MapErrConnectionReset<Fut> {
MapErrConnectionReset {
future,
finished: Some(()),
}
}
}
impl<Fut> Future for MapErrConnectionReset<Fut>
where
Fut: TryFuture,
{
type Output = io::Result<Fut::Ok>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project().future.try_poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(result) => {
self.project().finished.take().expect(
"MapErrConnectionReset must not be polled after it returned `Poll::Ready`",
);
Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset)))
}
}
}
}
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
struct MapOkDispatchResponse<Fut, Resp> {
#[pin]
future: Fut,
response: Option<DispatchResponse<Resp>>,
}
impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
fn new(future: Fut, response: DispatchResponse<Resp>) -> MapOkDispatchResponse<Fut, Resp> {
MapOkDispatchResponse {
future,
response: Some(response),
}
}
}
impl<Fut, Resp> Future for MapOkDispatchResponse<Fut, Resp>
where
Fut: TryFuture,
{
type Output = Result<DispatchResponse<Resp>, Fut::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project().future.try_poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(result) => {
let response = self
.as_mut()
.project()
.response
.take()
.expect("MapOk must not be polled after it returned `Poll::Ready`");
Poll::Ready(result.map(|_| response))
}
}
}
}
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
struct AndThenIdent<Fut1, Fut2> {
#[pin]
try_chain: TryChain<Fut1, Fut2>,
}
impl<Fut1, Fut2> AndThenIdent<Fut1, Fut2>
where
Fut1: TryFuture<Ok = Fut2>,
Fut2: TryFuture,
{
/// Creates a new `Then`.
fn new(future: Fut1) -> AndThenIdent<Fut1, Fut2> {
AndThenIdent {
try_chain: TryChain::new(future),
}
}
}
impl<Fut1, Fut2> Future for AndThenIdent<Fut1, Fut2>
where
Fut1: TryFuture<Ok = Fut2>,
Fut2: TryFuture<Error = Fut1::Error>,
{
type Output = Result<Fut2::Ok, Fut2::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().try_chain.poll(cx, |result| match result {
Ok(ok) => TryChainAction::Future(ok),
Err(err) => TryChainAction::Output(Err(err)),
})
}
}
#[pin_project(project = TryChainProj)]
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
enum TryChain<Fut1, Fut2> {
First(#[pin] Fut1),
Second(#[pin] Fut2),
Empty,
}
enum TryChainAction<Fut2>
where
Fut2: TryFuture,
{
Future(Fut2),
Output(Result<Fut2::Ok, Fut2::Error>),
}
impl<Fut1, Fut2> TryChain<Fut1, Fut2>
where
Fut1: TryFuture<Ok = Fut2>,
Fut2: TryFuture,
{
fn new(fut1: Fut1) -> TryChain<Fut1, Fut2> {
TryChain::First(fut1)
}
fn poll<F>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
f: F,
) -> Poll<Result<Fut2::Ok, Fut2::Error>>
where
F: FnOnce(Result<Fut1::Ok, Fut1::Error>) -> TryChainAction<Fut2>,
{
let mut f = Some(f);
loop {
let output = match self.as_mut().project() {
TryChainProj::First(fut1) => {
// Poll the first future
match fut1.try_poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(output) => output,
}
}
TryChainProj::Second(fut2) => {
// Poll the second future
return fut2.try_poll(cx);
}
TryChainProj::Empty => {
panic!("future must not be polled after it returned `Poll::Ready`");
}
};
self.set(TryChain::Empty); // Drop fut1
let f = f.take().unwrap();
match f(output) {
TryChainAction::Future(fut2) => self.set(TryChain::Second(fut2)),
TryChainAction::Output(output) => return Poll::Ready(output),
}
}
}
}
#[cfg(test)]
mod tests {
use super::{
cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
RequestDispatch,
};
use crate::{
client::Config,
context,
transport::{self, channel::UnboundedChannel},
ClientMessage, Response,
};
use fnv::FnvHashMap;
use futures::{
channel::{mpsc, oneshot},
prelude::*,
task::*,
};
use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
#[tokio::test(threaded_scheduler)]
async fn dispatch_response_cancels_on_drop() {
let (cancellation, mut canceled_requests) = cancellations();
let (_, response) = oneshot::channel();
drop(DispatchResponse::<u32> {
response,
cancellation,
complete: false,
request_id: 3,
ctx: context::current(),
});
// resp's drop() is run, which should send a cancel message.
assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3));
}
#[tokio::test(threaded_scheduler)]
async fn stage_request() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
let _resp = send_request(&mut channel, "hi").await;
let req = dispatch.poll_next_request(cx).ready();
assert!(req.is_some());
let req = req.unwrap();
assert_eq!(req.request_id, 0);
assert_eq!(req.request, "hi".to_string());
}
// Regression test for https://github.com/google/tarpc/issues/220
#[tokio::test(threaded_scheduler)]
async fn stage_request_channel_dropped_doesnt_panic() {
let (mut dispatch, mut channel, mut server_channel) = set_up();
let mut dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
let _ = send_request(&mut channel, "hi").await;
drop(channel);
assert!(dispatch.as_mut().poll(cx).is_ready());
send_response(
&mut server_channel,
Response {
request_id: 0,
message: Ok("hello".into()),
},
)
.await;
dispatch.await.unwrap();
}
#[tokio::test(threaded_scheduler)]
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
let _ = send_request(&mut channel, "hi").await;
// Drop the channel so polling returns none if no requests are currently ready.
drop(channel);
// Test that a request future dropped before it's processed by dispatch will cause the request
// to not be added to the in-flight request map.
assert!(dispatch.poll_next_request(cx).ready().is_none());
}
#[tokio::test(threaded_scheduler)]
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let cx = &mut Context::from_waker(&noop_waker_ref());
let mut dispatch = Pin::new(&mut dispatch);
let req = send_request(&mut channel, "hi").await;
assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
assert!(!dispatch.as_mut().project().in_flight_requests.is_empty());
// Test that a request future dropped after it's processed by dispatch will cause the request
// to be removed from the in-flight request map.
drop(req);
if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
// ok
} else {
panic!("Expected request to be cancelled")
};
assert!(dispatch.project().in_flight_requests.is_empty());
}
#[tokio::test(threaded_scheduler)]
async fn stage_request_response_closed_skipped() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
// Test that a request future that's closed its receiver but not yet canceled its request --
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
// map.
let mut resp = send_request(&mut channel, "hi").await;
resp.response.close();
assert!(dispatch.poll_next_request(cx).is_pending());
}
fn set_up() -> (
RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
Channel<String, String>,
UnboundedChannel<ClientMessage<String>, Response<String>>,
) {
let _ = env_logger::try_init();
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancel_tx, canceled_requests) = mpsc::unbounded();
let (client_channel, server_channel) = transport::channel::unbounded();
let dispatch = RequestDispatch::<String, String, _> {
transport: client_channel.fuse(),
pending_requests: pending_requests.fuse(),
canceled_requests: CanceledRequests(canceled_requests).fuse(),
in_flight_requests: FnvHashMap::default(),
config: Config::default(),
};
let cancellation = RequestCancellation(cancel_tx);
let channel = Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicU64::new(0)),
};
(dispatch, channel, server_channel)
}
async fn send_request(
channel: &mut Channel<String, String>,
request: &str,
) -> DispatchResponse<String> {
channel
.send(context::current(), request.to_string())
.await
.unwrap()
}
async fn send_response(
channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
response: Response<String>,
) {
channel.send(response).await.unwrap();
}
trait PollTest {
type T;
fn unwrap(self) -> Poll<Self::T>;
fn ready(self) -> Self::T;
}
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
where
E: ::std::fmt::Display,
{
type T = Option<T>;
fn unwrap(self) -> Poll<Option<T>> {
match self {
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
Poll::Pending => Poll::Pending,
}
}
fn ready(self) -> Option<T> {
match self {
Poll::Ready(Some(Ok(t))) => Some(t),
Poll::Ready(None) => None,
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
Poll::Pending => panic!("Pending"),
}
}
}
}

153
tarpc/src/rpc/client/mod.rs Normal file
View File

@@ -0,0 +1,153 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a client that connects to a server and sends multiplexed requests.
use crate::context;
use futures::prelude::*;
use std::io;
/// Provides a [`Client`] backed by a transport.
pub mod channel;
pub use channel::{new, Channel};
/// Sends multiplexed requests to, and receives responses from, a server.
pub trait Client<'a, Req> {
/// The response type.
type Response;
/// The future response.
type Future: Future<Output = io::Result<Self::Response>> + 'a;
/// Initiates a request, sending it to the dispatch task.
///
/// Returns a [`Future`] that resolves to this client and the future response
/// once the request is successfully enqueued.
///
/// [`Future`]: futures::Future
fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future;
/// Returns a Client that applies a post-processing function to the returned response.
fn map_response<F, R>(self, f: F) -> MapResponse<Self, F>
where
F: FnMut(Self::Response) -> R,
Self: Sized,
{
MapResponse { inner: self, f }
}
/// Returns a Client that applies a pre-processing function to the request.
fn with_request<F, Req2>(self, f: F) -> WithRequest<Self, F>
where
F: FnMut(Req2) -> Req,
Self: Sized,
{
WithRequest { inner: self, f }
}
}
/// A Client that applies a function to the returned response.
#[derive(Clone, Debug)]
pub struct MapResponse<C, F> {
inner: C,
f: F,
}
impl<'a, C, F, Req, Resp, Resp2> Client<'a, Req> for MapResponse<C, F>
where
C: Client<'a, Req, Response = Resp>,
F: FnMut(Resp) -> Resp2 + 'a,
{
type Response = Resp2;
type Future = futures::future::MapOk<<C as Client<'a, Req>>::Future, &'a mut F>;
fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future {
self.inner.call(ctx, request).map_ok(&mut self.f)
}
}
/// A Client that applies a pre-processing function to the request.
#[derive(Clone, Debug)]
pub struct WithRequest<C, F> {
inner: C,
f: F,
}
impl<'a, C, F, Req, Req2, Resp> Client<'a, Req2> for WithRequest<C, F>
where
C: Client<'a, Req, Response = Resp>,
F: FnMut(Req2) -> Req,
{
type Response = Resp;
type Future = <C as Client<'a, Req>>::Future;
fn call(&'a mut self, ctx: context::Context, request: Req2) -> Self::Future {
self.inner.call(ctx, (self.f)(request))
}
}
impl<'a, Req, Resp> Client<'a, Req> for Channel<Req, Resp>
where
Req: 'a,
Resp: 'a,
{
type Response = Resp;
type Future = channel::Call<'a, Req, Resp>;
fn call(&'a mut self, ctx: context::Context, request: Req) -> channel::Call<'a, Req, Resp> {
self.call(ctx, request)
}
}
/// Settings that control the behavior of the client.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct Config {
/// The number of requests that can be in flight at once.
/// `max_in_flight_requests` controls the size of the map used by the client
/// for storing pending requests.
pub max_in_flight_requests: usize,
/// The number of requests that can be buffered client-side before being sent.
/// `pending_requests_buffer` controls the size of the channel clients use
/// to communicate with the request dispatch task.
pub pending_request_buffer: usize,
}
impl Default for Config {
fn default() -> Self {
Config {
max_in_flight_requests: 1_000,
pending_request_buffer: 100,
}
}
}
/// A channel and dispatch pair. The dispatch drives the sending and receiving of requests
/// and must be polled continuously or spawned.
#[derive(Debug)]
pub struct NewClient<C, D> {
/// The new client.
pub client: C,
/// The client's dispatch.
pub dispatch: D,
}
impl<C, D> NewClient<C, D>
where
D: Future<Output = io::Result<()>> + Send + 'static,
{
/// Helper method to spawn the dispatch on the default executor.
#[cfg(feature = "tokio1")]
pub fn spawn(self) -> io::Result<C> {
use log::error;
let dispatch = self
.dispatch
.unwrap_or_else(move |e| error!("Connection broken: {}", e));
tokio::spawn(dispatch);
Ok(self.client)
}
}

View File

@@ -7,8 +7,8 @@
//! Provides a request context that carries a deadline and trace context. This context is sent from
//! client to server and is used by the server to enforce response deadlines.
use crate::trace::{self, TraceId};
use std::time::{Duration, SystemTime};
use trace::{self, TraceId};
/// A request context that carries request-scoped information like deadlines and trace information.
/// It is sent from client to server and is used by the server to enforce response deadlines.
@@ -17,9 +17,19 @@ use trace::{self, TraceId};
/// be different for each request in scope.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Context {
/// When the client expects the request to be complete by. The server should cancel the request
/// if it is not complete by this time.
#[cfg_attr(
feature = "serde1",
serde(serialize_with = "crate::util::serde::serialize_epoch_secs")
)]
#[cfg_attr(
feature = "serde1",
serde(deserialize_with = "crate::util::serde::deserialize_epoch_secs")
)]
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
pub deadline: SystemTime,
/// Uniquely identifies requests originating from the same source.
/// When a service handles a request by making requests itself, those requests should
@@ -28,6 +38,11 @@ pub struct Context {
pub trace_context: trace::Context,
}
#[cfg(feature = "serde1")]
fn ten_seconds_from_now() -> SystemTime {
SystemTime::now() + Duration::from_secs(10)
}
/// Returns the context for the current request, or a default Context if no request is active.
// TODO: populate Context with request-scoped data, with default fallbacks.
pub fn current() -> Context {

View File

@@ -4,23 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
const_fn,
non_exhaustive,
integer_atomics,
try_trait,
nll,
futures_api,
pin,
arbitrary_self_types,
await_macro,
async_await,
generators,
optin_builtin_traits,
generator_trait,
gen_future,
decl_macro,
)]
#![deny(missing_docs, missing_debug_implementations)]
//! An RPC framework providing client and server.
@@ -47,34 +30,16 @@ pub mod server;
pub mod transport;
pub(crate) mod util;
pub use crate::{client::Client, server::Server, transport::Transport};
pub use crate::{client::Client, server::Server, trace, transport::sealed::Transport};
use futures::{Future, task::{Spawn, SpawnExt, SpawnError}};
use std::{cell::RefCell, io, sync::Once, time::SystemTime};
use futures::task::*;
use std::{io, time::SystemTime};
/// A message from a client to a server.
#[derive(Debug)]
#[cfg_attr(
feature = "serde1",
derive(serde::Serialize, serde::Deserialize)
)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct ClientMessage<T> {
/// The trace context associates the message with a specific chain of causally-related actions,
/// possibly orchestrated across many distributed systems.
pub trace_context: trace::Context,
/// The message payload.
pub message: ClientMessageKind<T>,
}
/// Different messages that can be sent from a client to a server.
#[derive(Debug)]
#[cfg_attr(
feature = "serde1",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub enum ClientMessageKind<T> {
pub enum ClientMessage<T> {
/// A request initiated by a user. The server responds to a request by invoking a
/// service-provided request handler. The handler completes with a [`response`](Response), which
/// the server sends back to the client.
@@ -87,43 +52,32 @@ pub enum ClientMessageKind<T> {
/// not be canceled, because the framework layer does not
/// know about them.
Cancel {
/// The trace context associates the message with a specific chain of causally-related actions,
/// possibly orchestrated across many distributed systems.
#[cfg_attr(feature = "serde", serde(default))]
trace_context: trace::Context,
/// The ID of the request to cancel.
request_id: u64,
},
}
/// A request from a client to a server.
#[derive(Debug)]
#[cfg_attr(
feature = "serde1",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Request<T> {
/// Trace context, deadline, and other cross-cutting concerns.
pub context: context::Context,
/// Uniquely identifies the request across all requests sent over a single channel.
pub id: u64,
/// The request body.
pub message: T,
/// When the client expects the request to be complete by. The server will cancel the request
/// if it is not complete by this time.
#[cfg_attr(
feature = "serde1",
serde(serialize_with = "util::serde::serialize_epoch_secs")
)]
#[cfg_attr(
feature = "serde1",
serde(deserialize_with = "util::serde::deserialize_epoch_secs")
)]
pub deadline: SystemTime,
}
/// A response from a server to a client.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde1",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Response<T> {
/// The ID of the request being responded to.
pub request_id: u64,
@@ -132,12 +86,9 @@ pub struct Response<T> {
}
/// An error response from a server to a client.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde1",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct ServerError {
#[cfg_attr(
feature = "serde1",
@@ -162,54 +113,8 @@ impl From<ServerError> for io::Error {
impl<T> Request<T> {
/// Returns the deadline for this request.
pub fn deadline(&self) -> &SystemTime {
&self.deadline
&self.context.deadline
}
}
static INIT: Once = Once::new();
static mut SEED_SPAWN: Option<Box<dyn CloneSpawn>> = None;
thread_local! {
static SPAWN: RefCell<Box<dyn CloneSpawn>> = {
unsafe {
// INIT must always be called before accessing SPAWN.
// Otherwise, accessing SPAWN can trigger undefined behavior due to race conditions.
INIT.call_once(|| {});
RefCell::new(SEED_SPAWN.clone().expect("init() must be called."))
}
};
}
/// Initializes the RPC library with a mechanism to spawn futures on the user's runtime.
/// Client stubs and servers both use the initialized spawn.
///
/// Init only has an effect the first time it is called. If called previously, successive calls to
/// init are noops.
pub fn init(spawn: impl Spawn + Clone + 'static) {
unsafe {
INIT.call_once(|| {
SEED_SPAWN = Some(Box::new(spawn));
});
}
}
pub(crate) fn spawn(future: impl Future<Output = ()> + Send + 'static) -> Result<(), SpawnError> {
SPAWN.with(|spawn| {
spawn.borrow_mut().spawn(future)
})
}
trait CloneSpawn: Spawn {
fn box_clone(&self) -> Box<dyn CloneSpawn>;
}
impl Clone for Box<dyn CloneSpawn> {
fn clone(&self) -> Self {
self.box_clone()
}
}
impl<S: Spawn + Clone + 'static> CloneSpawn for S {
fn box_clone(&self) -> Box<dyn CloneSpawn> {
Box::new(self.clone())
}
}
pub(crate) type PollIo<T> = Poll<Option<io::Result<T>>>;

View File

@@ -0,0 +1,471 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
server::{self, Channel},
util::Compact,
};
use fnv::FnvHashMap;
use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
use log::{debug, info, trace};
use pin_project::pin_project;
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
};
/// A single-threaded filter that drops channels based on per-key limits.
#[pin_project]
#[derive(Debug)]
pub struct ChannelFilter<S, K, F>
where
K: Eq + Hash,
{
#[pin]
listener: Fuse<S>,
channels_per_key: u32,
#[pin]
dropped_keys: mpsc::UnboundedReceiver<K>,
#[pin]
dropped_keys_tx: mpsc::UnboundedSender<K>,
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
keymaker: F,
}
/// A channel that is tracked by a ChannelFilter.
#[pin_project]
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
#[pin]
inner: C,
tracker: Arc<Tracker<K>>,
}
#[derive(Debug)]
struct Tracker<K> {
key: Option<K>,
dropped_keys: mpsc::UnboundedSender<K>,
}
impl<K> Drop for Tracker<K> {
fn drop(&mut self) {
// Don't care if the listener is dropped.
let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
}
}
impl<C, K> Stream for TrackedChannel<C, K>
where
C: Stream,
{
type Item = <C as Stream>::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.channel().poll_next(cx)
}
}
impl<C, I, K> Sink<I> for TrackedChannel<C, K>
where
C: Sink<I>,
{
type Error = C::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
self.channel().start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_close(cx)
}
}
impl<C, K> AsRef<C> for TrackedChannel<C, K> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C, K> Channel for TrackedChannel<C, K>
where
C: Channel,
{
type Req = C::Req;
type Resp = C::Resp;
fn config(&self) -> &server::Config {
self.inner.config()
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.project().inner.in_flight_requests()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.project().inner.start_request(request_id)
}
}
impl<C, K> TrackedChannel<C, K> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
/// Returns the pinned inner channel.
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
self.project().inner
}
}
impl<S, K, F> ChannelFilter<S, K, F>
where
K: Eq + Hash,
S: Stream,
F: Fn(&S::Item) -> K,
{
/// Sheds new channels to stay under configured limits.
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded();
ChannelFilter {
listener: listener.fuse(),
channels_per_key,
dropped_keys,
dropped_keys_tx,
key_counts: FnvHashMap::default(),
keymaker,
}
}
}
impl<S, K, F> ChannelFilter<S, K, F>
where
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&S::Item) -> K,
{
fn handle_new_channel(
mut self: Pin<&mut Self>,
stream: S::Item,
) -> Result<TrackedChannel<S::Item, K>, K> {
let key = (self.as_mut().keymaker)(&stream);
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
trace!(
"[{}] Opening channel ({}/{}) channels for key.",
key,
Arc::strong_count(&tracker),
self.as_mut().project().channels_per_key
);
Ok(TrackedChannel {
tracker,
inner: stream,
})
}
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
let channels_per_key = self.channels_per_key;
let dropped_keys = self.dropped_keys_tx.clone();
let key_counts = &mut self.as_mut().project().key_counts;
match key_counts.entry(key.clone()) {
Entry::Vacant(vacant) => {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
});
vacant.insert(Arc::downgrade(&tracker));
Ok(tracker)
}
Entry::Occupied(mut o) => {
let count = o.get().strong_count();
if count >= channels_per_key.try_into().unwrap() {
info!(
"[{}] Opened max channels from key ({}/{}).",
key, count, channels_per_key
);
Err(key)
} else {
Ok(o.get().upgrade().unwrap_or_else(|| {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
});
*o.get_mut() = Arc::downgrade(&tracker);
tracker
}))
}
}
}
}
fn poll_listener(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) {
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
None => Poll::Ready(None),
}
}
fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) {
Some(key) => {
debug!("All channels dropped for key [{}]", key);
self.as_mut().project().key_counts.remove(&key);
self.as_mut().project().key_counts.compact(0.1);
Poll::Ready(())
}
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
}
}
}
impl<S, K, F> Stream for ChannelFilter<S, K, F>
where
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&S::Item) -> K,
{
type Item = TrackedChannel<S::Item, K>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<TrackedChannel<S::Item, K>>> {
loop {
match (
self.as_mut().poll_listener(cx),
self.as_mut().poll_closed_channels(cx),
) {
(Poll::Ready(Some(Ok(channel))), _) => {
return Poll::Ready(Some(channel));
}
(Poll::Ready(Some(Err(_))), _) => {
continue;
}
(_, Poll::Ready(())) => continue,
(Poll::Pending, Poll::Pending) => return Poll::Pending,
(Poll::Ready(None), Poll::Pending) => {
trace!("Shutting down listener.");
return Poll::Ready(None);
}
}
}
}
}
#[cfg(test)]
fn ctx() -> Context<'static> {
use futures::task::*;
Context::from_waker(&noop_waker_ref())
}
#[test]
fn tracker_drop() {
use assert_matches::assert_matches;
let (tx, mut rx) = mpsc::unbounded();
Tracker {
key: Some(1),
dropped_keys: tx,
};
assert_matches!(rx.try_next(), Ok(Some(1)));
}
#[test]
fn tracked_channel_stream() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
let (chan_tx, chan) = mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded();
let channel = TrackedChannel {
inner: chan,
tracker: Arc::new(Tracker {
key: Some(1),
dropped_keys,
}),
};
chan_tx.unbounded_send("test").unwrap();
pin_mut!(channel);
assert_matches!(channel.poll_next(&mut ctx()), Poll::Ready(Some("test")));
}
#[test]
fn tracked_channel_sink() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
let (chan, mut chan_rx) = mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded();
let channel = TrackedChannel {
inner: chan,
tracker: Arc::new(Tracker {
key: Some(1),
dropped_keys,
}),
};
pin_mut!(channel);
assert_matches!(channel.as_mut().poll_ready(&mut ctx()), Poll::Ready(Ok(())));
assert_matches!(channel.as_mut().start_send("test"), Ok(()));
assert_matches!(channel.as_mut().poll_flush(&mut ctx()), Poll::Ready(Ok(())));
assert_matches!(chan_rx.try_next(), Ok(Some("test")));
}
#[test]
fn channel_filter_increment_channels_for_key() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
struct TestChannel {
key: &'static str,
}
let (_, listener) = mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 1);
let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 2);
assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
drop(tracker2);
assert_eq!(Arc::strong_count(&tracker1), 1);
}
#[test]
fn channel_filter_handle_new_channel() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (_, listener) = mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let channel1 = filter
.as_mut()
.handle_new_channel(TestChannel { key: "key" })
.unwrap();
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
let channel2 = filter
.as_mut()
.handle_new_channel(TestChannel { key: "key" })
.unwrap();
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
assert_matches!(
filter.handle_new_channel(TestChannel { key: "key" }),
Err("key")
);
drop(channel2);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
}
#[test]
fn channel_filter_poll_listener() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (new_channels, listener) = mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let channel1 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let _channel2 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let key =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
assert_eq!(key, "key");
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
}
#[test]
fn channel_filter_poll_closed_channels() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (new_channels, listener) = mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let channel =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(filter.key_counts.len(), 1);
drop(channel);
assert_matches!(
filter.as_mut().poll_closed_channels(&mut ctx()),
Poll::Ready(())
);
assert!(filter.key_counts.is_empty());
}
#[test]
fn channel_filter_stream() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (new_channels, listener) = mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let channel = assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Ready(Some(c)) => c);
assert_eq!(filter.key_counts.len(), 1);
drop(channel);
assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Pending);
assert!(filter.key_counts.is_empty());
}

697
tarpc/src/rpc/server/mod.rs Normal file
View File

@@ -0,0 +1,697 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a server that concurrently handles many connections sending multiplexed requests.
use crate::{
context, trace, util::Compact, util::TimeUntil, ClientMessage, PollIo, Request, Response,
ServerError, Transport,
};
use fnv::FnvHashMap;
use futures::{
channel::mpsc,
future::{AbortHandle, AbortRegistration, Abortable},
prelude::*,
ready,
stream::Fuse,
task::*,
};
use humantime::format_rfc3339;
use log::{debug, trace};
use pin_project::pin_project;
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
use tokio::time::Timeout;
mod filter;
#[cfg(test)]
mod testing;
mod throttle;
pub use self::{
filter::ChannelFilter,
throttle::{Throttler, ThrottlerStream},
};
/// Manages clients, serving multiplexed requests over each connection.
#[derive(Debug)]
pub struct Server<Req, Resp> {
config: Config,
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp> Default for Server<Req, Resp> {
fn default() -> Self {
new(Config::default())
}
}
/// Settings that control the behavior of the server.
#[derive(Clone, Debug)]
pub struct Config {
/// The number of responses per client that can be buffered server-side before being sent.
/// `pending_response_buffer` controls the buffer size of the channel that a server's
/// response tasks use to send responses to the client handler task.
pub pending_response_buffer: usize,
}
impl Default for Config {
fn default() -> Self {
Config {
pending_response_buffer: 100,
}
}
}
impl Config {
/// Returns a channel backed by `transport` and configured with `self`.
pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
BaseChannel::new(self, transport)
}
}
/// Returns a new server with configuration specified `config`.
pub fn new<Req, Resp>(config: Config) -> Server<Req, Resp> {
Server {
config,
ghost: PhantomData,
}
}
impl<Req, Resp> Server<Req, Resp> {
/// Returns the config for this server.
pub fn config(&self) -> &Config {
&self.config
}
/// Returns a stream of server channels.
pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
where
S: Stream<Item = T>,
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
listener.map(move |t| BaseChannel::new(self.config.clone(), t))
}
}
/// Basically a Fn(Req) -> impl Future<Output = Resp>;
pub trait Serve<Req>: Sized + Clone {
/// Type of response.
type Resp;
/// Type of response future.
type Fut: Future<Output = Self::Resp>;
/// Responds to a single request.
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
}
impl<Req, Resp, Fut, F> Serve<Req> for F
where
F: FnOnce(context::Context, Req) -> Fut + Clone,
Fut: Future<Output = Resp>,
{
type Resp = Resp;
type Fut = Fut;
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
self(ctx, req)
}
}
/// A utility trait enabling a stream to fluently chain a request handler.
pub trait Handler<C>
where
Self: Sized + Stream<Item = C>,
C: Channel,
{
/// Enforces channel per-key limits.
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
where
K: fmt::Display + Eq + Hash + Clone + Unpin,
KF: Fn(&C) -> K,
{
ChannelFilter::new(self, n, keymaker)
}
/// Caps the number of concurrent requests per channel.
fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
ThrottlerStream::new(self, n)
}
/// Responds to all requests with `server`.
#[cfg(feature = "tokio1")]
fn respond_with<S>(self, server: S) -> Running<Self, S>
where
S: Serve<C::Req, Resp = C::Resp>,
{
Running {
incoming: self,
server,
}
}
}
impl<S, C> Handler<C> for S
where
S: Sized + Stream<Item = C>,
C: Channel,
{
}
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
#[pin_project]
#[derive(Debug)]
pub struct BaseChannel<Req, Resp, T> {
config: Config,
/// Writes responses to the wire and reads requests off the wire.
#[pin]
transport: Fuse<T>,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
/// Types the request and response.
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
/// Creates a new channel backed by `transport` and configured with `config`.
pub fn new(config: Config, transport: T) -> Self {
BaseChannel {
config,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
ghost: PhantomData,
}
}
/// Creates a new channel backed by `transport` and configured with the defaults.
pub fn with_defaults(transport: T) -> Self {
Self::new(Config::default(), transport)
}
/// Returns the inner transport.
pub fn get_ref(&self) -> &T {
self.transport.get_ref()
}
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self
.as_mut()
.project()
.in_flight_requests
.remove(&request_id)
{
self.as_mut().project().in_flight_requests.compact(0.1);
cancel_handle.abort();
let remaining = self.as_mut().project().in_flight_requests.len();
trace!(
"[{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
remaining,
);
} else {
trace!(
"[{}] Received cancellation, but response handler \
is already complete.",
trace_context.trace_id,
);
}
}
}
/// The server end of an open connection with a client, streaming in requests from, and sinking
/// responses to, the client.
///
/// Channels are free to somewhat rely on the assumption that all in-flight requests are eventually
/// either [cancelled](BaseChannel::cancel_request) or [responded to](Sink::start_send). Safety cannot
/// rely on this assumption, but it is best for `Channel` users to always account for all outstanding
/// requests.
pub trait Channel
where
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
{
/// Type of request item.
type Req;
/// Type of response sink item.
type Resp;
/// Configuration of the channel.
fn config(&self) -> &Config;
/// Returns the number of in-flight requests over this channel.
fn in_flight_requests(self: Pin<&mut Self>) -> usize;
/// Caps the number of concurrent requests.
fn max_concurrent_requests(self, n: usize) -> Throttler<Self>
where
Self: Sized,
{
Throttler::new(self, n)
}
/// Tells the Channel that request with ID `request_id` is being handled.
/// The request will be tracked until a response with the same ID is sent
/// to the Channel.
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
/// responses and resolves when the connection is closed.
fn respond_with<S>(self, server: S) -> ClientHandler<Self, S>
where
S: Serve<Self::Req, Resp = Self::Resp>,
Self: Sized,
{
let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
let responses = responses.fuse();
ClientHandler {
channel: self,
server,
pending_responses: responses,
responses_tx,
}
}
}
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
type Item = io::Result<Request<Req>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
loop {
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
Some(message) => match message {
ClientMessage::Request(request) => {
return Poll::Ready(Some(Ok(request)));
}
ClientMessage::Cancel {
trace_context,
request_id,
} => {
self.as_mut().cancel_request(&trace_context, request_id);
}
},
None => return Poll::Ready(None),
}
}
}
}
impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().transport.poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
if self
.as_mut()
.project()
.in_flight_requests
.remove(&response.request_id)
.is_some()
{
self.as_mut().project().in_flight_requests.compact(0.1);
}
self.project().transport.start_send(response)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().transport.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().transport.poll_close(cx)
}
}
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
fn as_ref(&self) -> &T {
self.transport.get_ref()
}
}
impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
&self.config
}
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
self.as_mut().project().in_flight_requests.len()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
let (abort_handle, abort_registration) = AbortHandle::new_pair();
assert!(self
.project()
.in_flight_requests
.insert(request_id, abort_handle)
.is_none());
abort_registration
}
}
/// A running handler serving all requests coming over a channel.
#[pin_project]
#[derive(Debug)]
pub struct ClientHandler<C, S>
where
C: Channel,
{
#[pin]
channel: C,
/// Responses waiting to be written to the wire.
#[pin]
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
/// Handed out to request handlers to fan in responses.
#[pin]
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
/// Server
server: S,
}
impl<C, S> ClientHandler<C, S>
where
C: Channel,
S: Serve<C::Req, Resp = C::Resp>,
{
fn pump_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
None => Poll::Ready(None),
}
}
fn pump_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
read_half_closed: bool,
) -> PollIo<()> {
match self.as_mut().poll_next_response(cx)? {
Poll::Ready(Some((ctx, response))) => {
trace!(
"[{}] Staging response. In-flight requests = {}.",
ctx.trace_id(),
self.as_mut().project().channel.in_flight_requests(),
);
self.as_mut().project().channel.start_send(response)?;
Poll::Ready(Some(Ok(())))
}
Poll::Ready(None) => {
// Shutdown can't be done before we finish pumping out remaining responses.
ready!(self.as_mut().project().channel.poll_flush(cx)?);
Poll::Ready(None)
}
Poll::Pending => {
// No more requests to process, so flush any requests buffered in the transport.
ready!(self.as_mut().project().channel.poll_flush(cx)?);
// Being here means there are no staged requests and all written responses are
// fully flushed. So, if the read half is closed and there are no in-flight
// requests, then we can close the write half.
if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 {
Poll::Ready(None)
} else {
Poll::Pending
}
}
}
}
fn poll_next_response(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<(context::Context, Response<C::Resp>)> {
// Ensure there's room to write a response.
while let Poll::Pending = self.as_mut().project().channel.poll_ready(cx)? {
ready!(self.as_mut().project().channel.poll_flush(cx)?);
}
match ready!(self.as_mut().project().pending_responses.poll_next(cx)) {
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
None => {
// This branch likely won't happen, since the ClientHandler is holding a Sender.
Poll::Ready(None)
}
}
}
fn handle_request(
mut self: Pin<&mut Self>,
request: Request<C::Req>,
) -> RequestHandler<S::Fut, C::Resp> {
let request_id = request.id;
let deadline = request.context.deadline;
let timeout = deadline.time_until();
trace!(
"[{}] Received request with deadline {} (timeout {:?}).",
request.context.trace_id(),
format_rfc3339(deadline),
timeout,
);
let ctx = request.context;
let request = request.message;
let response = self.as_mut().project().server.clone().serve(ctx, request);
let response = Resp {
state: RespState::PollResp,
request_id,
ctx,
deadline,
f: tokio::time::timeout(timeout, response),
response: None,
response_tx: self.as_mut().project().responses_tx.clone(),
};
let abort_registration = self.as_mut().project().channel.start_request(request_id);
RequestHandler {
resp: Abortable::new(response, abort_registration),
}
}
}
/// A future fulfilling a single client request.
#[pin_project]
#[derive(Debug)]
pub struct RequestHandler<F, R> {
#[pin]
resp: Abortable<Resp<F, R>>,
}
impl<F, R> Future for RequestHandler<F, R>
where
F: Future<Output = R>,
{
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let _ = ready!(self.project().resp.poll(cx));
Poll::Ready(())
}
}
#[pin_project]
#[derive(Debug)]
struct Resp<F, R> {
state: RespState,
request_id: u64,
ctx: context::Context,
deadline: SystemTime,
#[pin]
f: Timeout<F>,
response: Option<Response<R>>,
#[pin]
response_tx: mpsc::Sender<(context::Context, Response<R>)>,
}
#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
enum RespState {
PollResp,
PollReady,
PollFlush,
}
impl<F, R> Future for Resp<F, R>
where
F: Future<Output = R>,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
loop {
match self.as_mut().project().state {
RespState::PollResp => {
let result = ready!(self.as_mut().project().f.poll(cx));
*self.as_mut().project().response = Some(Response {
request_id: self.request_id,
message: match result {
Ok(message) => Ok(message),
Err(tokio::time::Elapsed { .. }) => {
debug!(
"[{}] Response did not complete before deadline of {}s.",
self.ctx.trace_id(),
format_rfc3339(self.deadline)
);
// No point in responding, since the client will have dropped the
// request.
Err(ServerError {
kind: io::ErrorKind::TimedOut,
detail: Some(format!(
"Response did not complete before deadline of {}s.",
format_rfc3339(self.deadline)
)),
})
}
},
});
*self.as_mut().project().state = RespState::PollReady;
}
RespState::PollReady => {
let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx));
if ready.is_err() {
return Poll::Ready(());
}
let resp = (self.ctx, self.as_mut().project().response.take().unwrap());
if self
.as_mut()
.project()
.response_tx
.start_send(resp)
.is_err()
{
return Poll::Ready(());
}
*self.as_mut().project().state = RespState::PollFlush;
}
RespState::PollFlush => {
let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx));
if ready.is_err() {
return Poll::Ready(());
}
return Poll::Ready(());
}
}
}
}
}
impl<C, S> Stream for ClientHandler<C, S>
where
C: Channel,
S: Serve<C::Req, Resp = C::Resp>,
{
type Item = io::Result<RequestHandler<S::Fut, C::Resp>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
let read = self.as_mut().pump_read(cx)?;
let read_closed = if let Poll::Ready(None) = read {
true
} else {
false
};
match (read, self.as_mut().pump_write(cx, read_closed)?) {
(Poll::Ready(None), Poll::Ready(None)) => {
return Poll::Ready(None);
}
(Poll::Ready(Some(request_handler)), _) => {
return Poll::Ready(Some(Ok(request_handler)));
}
(_, Poll::Ready(Some(()))) => {}
_ => {
return Poll::Pending;
}
}
}
}
}
// Send + 'static execution helper methods.
impl<C, S> ClientHandler<C, S>
where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
S::Fut: Send + 'static,
{
/// Runs the client handler until completion by spawning each
/// request handler onto the default executor.
#[cfg(feature = "tokio1")]
pub fn execute(self) -> impl Future<Output = ()> {
use log::info;
self.try_for_each(|request_handler| async {
tokio::spawn(request_handler);
Ok(())
})
.unwrap_or_else(|e| info!("ClientHandler errored out: {}", e))
}
}
/// A future that drives the server by spawning channels and request handlers on the default
/// executor.
#[pin_project]
#[derive(Debug)]
#[cfg(feature = "tokio1")]
pub struct Running<St, Se> {
#[pin]
incoming: St,
server: Se,
}
#[cfg(feature = "tokio1")]
impl<St, C, Se> Future for Running<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send + 'static,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
use log::info;
while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) {
tokio::spawn(
channel
.respond_with(self.as_mut().project().server.clone())
.execute(),
);
}
info!("Server shutting down.");
Poll::Ready(())
}
}

View File

@@ -0,0 +1,123 @@
use crate::server::{Channel, Config};
use crate::{context, Request, Response};
use fnv::FnvHashSet;
use futures::{
future::{AbortHandle, AbortRegistration},
task::*,
Sink, Stream,
};
use pin_project::pin_project;
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::time::SystemTime;
#[pin_project]
pub(crate) struct FakeChannel<In, Out> {
#[pin]
pub stream: VecDeque<In>,
#[pin]
pub sink: VecDeque<Out>,
pub config: Config,
pub in_flight_requests: FnvHashSet<u64>,
}
impl<In, Out> Stream for FakeChannel<In, Out>
where
In: Unpin,
{
type Item = In;
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Option<Self::Item>> {
Poll::Ready(self.project().stream.pop_front())
}
}
impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().sink.poll_ready(cx).map_err(|e| match e {})
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
self.as_mut()
.project()
.in_flight_requests
.remove(&response.request_id);
self.project()
.sink
.start_send(response)
.map_err(|e| match e {})
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().sink.poll_flush(cx).map_err(|e| match e {})
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().sink.poll_close(cx).map_err(|e| match e {})
}
}
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
where
Req: Unpin,
{
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
&self.config
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.in_flight_requests.len()
}
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
self.project().in_flight_requests.insert(id);
AbortHandle::new_pair().1
}
}
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
pub fn push_req(&mut self, id: u64, message: Req) {
self.stream.push_back(Ok(Request {
context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
},
id,
message,
}));
}
}
impl FakeChannel<(), ()> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
FakeChannel {
stream: Default::default(),
sink: Default::default(),
config: Default::default(),
in_flight_requests: Default::default(),
}
}
}
pub trait PollExt {
fn is_done(&self) -> bool;
}
impl<T> PollExt for Poll<Option<T>> {
fn is_done(&self) -> bool {
match self {
Poll::Ready(None) => true,
_ => false,
}
}
}
pub fn cx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
}

View File

@@ -0,0 +1,322 @@
use super::{Channel, Config};
use crate::{Response, ServerError};
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
use log::debug;
use pin_project::pin_project;
use std::{io, pin::Pin};
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
#[pin_project]
#[derive(Debug)]
pub struct Throttler<C> {
max_in_flight_requests: usize,
#[pin]
inner: C,
}
impl<C> Throttler<C> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
}
impl<C> Throttler<C>
where
C: Channel,
{
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
/// `max_in_flight_requests`.
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
Throttler {
inner,
max_in_flight_requests,
}
}
}
impl<C> Stream for Throttler<C>
where
C: Channel,
{
type Item = <C as Stream>::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
{
ready!(self.as_mut().project().inner.poll_ready(cx)?);
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
Some(request) => {
debug!(
"[{}] Client has reached in-flight request limit ({}/{}).",
request.context.trace_id(),
self.as_mut().in_flight_requests(),
self.as_mut().project().max_in_flight_requests,
);
self.as_mut().start_send(Response {
request_id: request.id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
}),
})?;
}
None => return Poll::Ready(None),
}
}
self.project().inner.poll_next(cx)
}
}
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
where
C: Channel,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.project().inner.poll_close(cx)
}
}
impl<C> AsRef<C> for Throttler<C> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C> Channel for Throttler<C>
where
C: Channel,
{
type Req = <C as Channel>::Req;
type Resp = <C as Channel>::Resp;
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.project().inner.in_flight_requests()
}
fn config(&self) -> &Config {
self.inner.config()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.project().inner.start_request(request_id)
}
}
/// A stream of throttling channels.
#[pin_project]
#[derive(Debug)]
pub struct ThrottlerStream<S> {
#[pin]
inner: S,
max_in_flight_requests: usize,
}
impl<S> ThrottlerStream<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
Self {
inner,
max_in_flight_requests,
}
}
}
impl<S> Stream for ThrottlerStream<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
type Item = Throttler<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().project().inner.poll_next(cx)) {
Some(channel) => Poll::Ready(Some(Throttler::new(
channel,
*self.project().max_in_flight_requests,
))),
None => Poll::Ready(None),
}
}
}
#[cfg(test)]
use super::testing::{self, FakeChannel, PollExt};
#[cfg(test)]
use crate::Request;
#[cfg(test)]
use pin_utils::pin_mut;
#[cfg(test)]
use std::marker::PhantomData;
#[test]
fn throttler_in_flight_requests() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
for i in 0..5 {
throttler.inner.in_flight_requests.insert(i);
}
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
#[test]
fn throttler_start_request() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.as_mut().start_request(1);
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
}
#[test]
fn throttler_poll_next_done() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
}
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
let throttler = Throttler {
max_in_flight_requests: 1,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(0, 1);
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
assert_eq!(
throttler
.as_mut()
.poll_next(&mut testing::cx())?
.map(|r| r.map(|r| (r.id, r.message))),
Poll::Ready(Some((0, 1)))
);
Ok(())
}
#[test]
fn throttler_poll_next_throttled() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(1, 1);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
assert_eq!(throttler.inner.sink.len(), 1);
let resp = throttler.inner.sink.get(0).unwrap();
assert_eq!(resp.request_id, 1);
assert!(resp.message.is_err());
}
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: PendingSink::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
struct PendingSink<In, Out> {
ghost: PhantomData<fn(Out) -> In>,
}
impl PendingSink<(), ()> {
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
PendingSink { ghost: PhantomData }
}
}
impl<In, Out> Stream for PendingSink<In, Out> {
type Item = In;
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
unimplemented!()
}
}
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
}
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
unimplemented!()
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
0
}
fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
unimplemented!()
}
}
}
#[test]
fn throttler_start_send() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.in_flight_requests.insert(0);
throttler
.as_mut()
.start_send(Response {
request_id: 0,
message: Ok(1),
})
.unwrap();
assert!(throttler.inner.in_flight_requests.is_empty());
assert_eq!(
throttler.inner.sink.get(0),
Some(&Response {
request_id: 0,
message: Ok(1),
})
);
}

View File

@@ -0,0 +1,123 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Transports backed by in-memory channels.
use crate::PollIo;
use futures::{channel::mpsc, task::*, Sink, Stream};
use pin_project::pin_project;
use std::io;
use std::pin::Pin;
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
/// [`Sink`].
pub fn unbounded<SinkItem, Item>() -> (
UnboundedChannel<SinkItem, Item>,
UnboundedChannel<Item, SinkItem>,
) {
let (tx1, rx2) = mpsc::unbounded();
let (tx2, rx1) = mpsc::unbounded();
(
UnboundedChannel { tx: tx1, rx: rx1 },
UnboundedChannel { tx: tx2, rx: rx2 },
)
}
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
#[pin_project]
#[derive(Debug)]
pub struct UnboundedChannel<Item, SinkItem> {
#[pin]
rx: mpsc::UnboundedReceiver<Item>,
#[pin]
tx: mpsc::UnboundedSender<SinkItem>,
}
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
type Item = Result<Item, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
}
}
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project()
.tx
.poll_ready(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.project()
.tx
.start_send(item)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_flush(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project()
.tx
.poll_close(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
}
#[cfg(test)]
mod tests {
use crate::{
client, context,
server::{Handler, Server},
transport,
};
use assert_matches::assert_matches;
use futures::{prelude::*, stream};
use log::trace;
use std::io;
#[cfg(feature = "tokio1")]
#[tokio::test(threaded_scheduler)]
async fn integration() -> io::Result<()> {
let _ = env_logger::try_init();
let (client_channel, server_channel) = transport::channel::unbounded();
tokio::spawn(
Server::default()
.incoming(stream::once(future::ready(server_channel)))
.respond_with(|_ctx, request: String| {
future::ready(request.parse::<u64>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{:?} is not an int", request),
)
}))
}),
);
let mut client = client::new(client::Config::default(), client_channel).spawn()?;
let response1 = client.call(context::current(), "123".into()).await?;
let response2 = client.call(context::current(), "abc".into()).await?;
trace!("response1: {:?}, response2: {:?}", response1, response2);
assert_matches!(response1, Ok(123));
assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput);
Ok(())
}
}

View File

@@ -0,0 +1,30 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a [`Transport`] trait as well as implementations.
//!
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`]
//! can be plugged in, using whatever protocol it wants.
use futures::prelude::*;
use std::io;
pub mod channel;
pub(crate) mod sealed {
use super::*;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport<SinkItem, Item>:
Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>
{
}
impl<T, SinkItem, Item> Transport<SinkItem, Item> for T where
T: Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error> + ?Sized
{
}
}

View File

@@ -10,18 +10,17 @@ use std::{
time::{Duration, SystemTime},
};
pub mod deadline_compat;
#[cfg(feature = "serde")]
pub mod serde;
/// Types that can be represented by a [`Duration`].
pub trait AsDuration {
fn as_duration(&self) -> Duration;
/// Extension trait for [SystemTimes](SystemTime) in the future, i.e. deadlines.
pub trait TimeUntil {
/// How much time from now until this time is reached.
fn time_until(&self) -> Duration;
}
impl AsDuration for SystemTime {
/// Duration of 0 if self is earlier than [`SystemTime::now`].
fn as_duration(&self) -> Duration {
impl TimeUntil for SystemTime {
fn time_until(&self) -> Duration {
self.duration_since(SystemTime::now()).unwrap_or_default()
}
}
@@ -38,9 +37,11 @@ where
H: BuildHasher,
{
fn compact(&mut self, usage_ratio_threshold: f64) {
let usage_ratio = self.len() as f64 / self.capacity() as f64;
if usage_ratio < usage_ratio_threshold {
self.shrink_to_fit();
if self.capacity() > 1000 {
let usage_ratio = self.len() as f64 / self.capacity() as f64;
if usage_ratio < usage_ratio_threshold {
self.shrink_to_fit();
}
}
}
}

View File

@@ -31,6 +31,7 @@ where
}
/// Serializes [`io::ErrorKind`] as a `u32`.
#[allow(clippy::trivially_copy_pass_by_ref)] // Exact fn signature required by serde derive
pub fn serialize_io_error_kind_as_u32<S>(
kind: &io::ErrorKind,
serializer: S,
@@ -59,7 +60,8 @@ where
Other => 16,
UnexpectedEof => 17,
_ => 16,
}.serialize(serializer)
}
.serialize(serializer)
}
/// Deserializes [`io::ErrorKind`] from a `u32`.

View File

@@ -0,0 +1,326 @@
// Copyright 2019 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! A generic Serde-based `Transport` that can serialize anything supported by `tokio-serde` via any medium that implements `AsyncRead` and `AsyncWrite`.
#![deny(missing_docs)]
use futures::{prelude::*, task::*};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use std::{error::Error, io, pin::Pin};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_serde::{Framed as SerdeFramed, *};
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
/// A transport that serializes to, and deserializes from, a [`TcpStream`].
#[pin_project]
pub struct Transport<S, Item, SinkItem, Codec> {
#[pin]
inner: SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>,
}
impl<S, Item, SinkItem, Codec, CodecError> Stream for Transport<S, Item, SinkItem, Codec>
where
S: AsyncWrite + AsyncRead,
Item: for<'a> Deserialize<'a>,
Codec: Deserializer<Item>,
CodecError: Into<Box<dyn std::error::Error + Send + Sync>>,
SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>:
Stream<Item = Result<Item, CodecError>>,
{
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
match self.project().inner.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok::<_, CodecError>(next))) => Poll::Ready(Some(Ok(next))),
Poll::Ready(Some(Err::<_, CodecError>(e))) => {
Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
}
}
}
}
impl<S, Item, SinkItem, Codec, CodecError> Sink<SinkItem> for Transport<S, Item, SinkItem, Codec>
where
S: AsyncWrite,
SinkItem: Serialize,
Codec: Serializer<SinkItem>,
CodecError: Into<Box<dyn Error + Send + Sync>>,
SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>:
Sink<SinkItem, Error = CodecError>,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_ready(cx))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.project()
.inner
.start_send(item)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_flush(cx))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_close(cx))
}
}
fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
poll: Poll<Result<(), E>>,
) -> Poll<io::Result<()>> {
poll.map(|ready| ready.map_err(|e| io::Error::new(io::ErrorKind::Other, e)))
}
impl<S, Item, SinkItem, Codec> From<(S, Codec)> for Transport<S, Item, SinkItem, Codec>
where
S: AsyncWrite + AsyncRead,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
{
fn from((inner, codec): (S, Codec)) -> Self {
Transport {
inner: SerdeFramed::new(Framed::new(inner, LengthDelimitedCodec::new()), codec),
}
}
}
#[cfg(feature = "tcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
/// TCP support for generic transport using Tokio.
pub mod tcp {
use {
super::*,
futures::ready,
std::{marker::PhantomData, net::SocketAddr},
tokio::net::{TcpListener, TcpStream, ToSocketAddrs},
};
mod private {
use super::*;
pub trait Sealed {}
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
}
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
/// Returns the peer address of the underlying TcpStream.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().peer_addr()
}
/// Returns the local address of the underlying TcpStream.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().local_addr()
}
}
/// Returns a new JSON transport that reads from and writes to `io`.
pub fn new<Item, SinkItem, Codec>(
io: TcpStream,
codec: Codec,
) -> Transport<TcpStream, Item, SinkItem, Codec>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
{
Transport::from((io, codec))
}
/// Connects to `addr`, wrapping the connection in a JSON transport.
pub async fn connect<A, Item, SinkItem, Codec>(
addr: A,
codec: Codec,
) -> io::Result<Transport<TcpStream, Item, SinkItem, Codec>>
where
A: ToSocketAddrs,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
{
Ok(new(TcpStream::connect(addr).await?, codec))
}
/// Listens on `addr`, wrapping accepted connections in JSON transports.
pub async fn listen<A, Item, SinkItem, Codec, CodecFn>(
addr: A,
codec_fn: CodecFn,
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
where
A: ToSocketAddrs,
Item: for<'de> Deserialize<'de>,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
Ok(Incoming {
listener,
codec_fn,
local_addr,
ghost: PhantomData,
})
}
/// A [`TcpListener`] that wraps connections in JSON transports.
#[pin_project]
#[derive(Debug)]
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
listener: TcpListener,
local_addr: SocketAddr,
codec_fn: CodecFn,
ghost: PhantomData<(Item, SinkItem, Codec)>,
}
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
/// Returns the address being listened on.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
}
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
type Item = io::Result<Transport<TcpStream, Item, SinkItem, Codec>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let next =
ready!(Pin::new(&mut self.as_mut().project().listener.incoming()).poll_next(cx)?);
Poll::Ready(next.map(|conn| Ok(new(conn, (self.codec_fn)()))))
}
}
}
#[cfg(test)]
mod tests {
use super::Transport;
use assert_matches::assert_matches;
use futures::{task::*, Sink, Stream};
use pin_utils::pin_mut;
use std::{
io::{self, Cursor},
pin::Pin,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_serde::formats::SymmetricalJson;
fn ctx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
}
#[test]
fn test_stream() {
struct TestIo(Cursor<&'static [u8]>);
impl AsyncRead for TestIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf)
}
}
impl AsyncWrite for TestIo {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
unreachable!()
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
}
let data = b"\x00\x00\x00\x18\"Test one, check check.\"";
let transport = Transport::from((
TestIo(Cursor::new(data)),
SymmetricalJson::<String>::default(),
));
pin_mut!(transport);
assert_matches!(
transport.poll_next(&mut ctx()),
Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
}
#[test]
fn test_sink() {
struct TestIo<'a>(&'a mut Vec<u8>);
impl<'a> AsyncRead for TestIo<'a> {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut [u8],
) -> Poll<io::Result<usize>> {
unreachable!()
}
}
impl<'a> AsyncWrite for TestIo<'a> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx)
}
}
let mut writer = vec![];
let transport =
Transport::from((TestIo(&mut writer), SymmetricalJson::<String>::default()));
pin_mut!(transport);
assert_matches!(
transport.as_mut().poll_ready(&mut ctx()),
Poll::Ready(Ok(()))
);
assert_matches!(
transport
.as_mut()
.start_send("Test one, check check.".into()),
Ok(())
);
assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(())));
assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\"");
}
}

View File

@@ -26,11 +26,8 @@ use std::{
///
/// Consists of a span identifying an event, an optional parent span identifying a causal event
/// that triggered the current span, and a trace with which all related spans are associated.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Context {
/// An identifier of the trace associated with the current context. A trace ID is typically
/// created at a root span and passed along through all causal events.
@@ -49,19 +46,13 @@ pub struct Context {
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
/// same trace ID.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TraceId(u128);
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpanId(u64);
impl Context {
@@ -80,7 +71,7 @@ impl TraceId {
/// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
/// actually-random numbers.
pub fn random<R: Rng>(rng: &mut R) -> Self {
TraceId((rng.next_u64() as u128) << mem::size_of::<u64>() | rng.next_u64() as u128)
TraceId(u128::from(rng.next_u64()) << mem::size_of::<u64>() | u128::from(rng.next_u64()))
}
}

View File

@@ -1,131 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
test,
arbitrary_self_types,
pin,
integer_atomics,
futures_api,
generators,
await_macro,
async_await,
proc_macro_hygiene,
)]
extern crate test;
use self::test::stats::Stats;
use futures::{compat::TokioDefaultSpawner, future, prelude::*};
use rpc::{
client, context,
server::{self, Handler, Server},
};
use std::{
io,
time::{Duration, Instant},
};
mod ack {
tarpc::service! {
rpc ack();
}
}
#[derive(Clone)]
struct Serve;
impl ack::Service for Serve {
type AckFut = future::Ready<()>;
fn ack(&self, _: context::Context) -> Self::AckFut {
future::ready(())
}
}
async fn bench() -> io::Result<()> {
let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
tokio_executor::spawn(
Server::new(server::Config::default())
.incoming(listener)
.take(1)
.respond_with(ack::serve(Serve))
.unit_error()
.boxed()
.compat()
);
let conn = await!(bincode_transport::connect(&addr))?;
let mut client = await!(ack::new_stub(client::Config::default(), conn))?;
let total = 10_000usize;
let mut successful = 0u32;
let mut unsuccessful = 0u32;
let mut durations = vec![];
for _ in 1..=total {
let now = Instant::now();
let response = await!(client.ack(context::current()));
let elapsed = now.elapsed();
match response {
Ok(_) => successful += 1,
Err(_) => unsuccessful += 1,
};
durations.push(elapsed);
}
let durations_nanos = durations
.iter()
.map(|duration| duration.as_secs() as f64 * 1E9 + duration.subsec_nanos() as f64)
.collect::<Vec<_>>();
let (lower, median, upper) = durations_nanos.quartiles();
println!("Of {:?} runs:", durations_nanos.len());
println!("\tSuccessful: {:?}", successful);
println!("\tUnsuccessful: {:?}", unsuccessful);
println!(
"\tMean: {:?}",
Duration::from_nanos(durations_nanos.mean() as u64)
);
println!("\tMedian: {:?}", Duration::from_nanos(median as u64));
println!(
"\tStd Dev: {:?}",
Duration::from_nanos(durations_nanos.std_dev() as u64)
);
println!(
"\tMin: {:?}",
Duration::from_nanos(durations_nanos.min() as u64)
);
println!(
"\tMax: {:?}",
Duration::from_nanos(durations_nanos.max() as u64)
);
println!(
"\tQuartiles: ({:?}, {:?}, {:?})",
Duration::from_nanos(lower as u64),
Duration::from_nanos(median as u64),
Duration::from_nanos(upper as u64)
);
println!("done");
Ok(())
}
#[test]
fn bench_small_packet() {
env_logger::init();
tarpc::init(TokioDefaultSpawner);
tokio::run(
bench()
.map_err(|e| panic!(e.to_string()))
.boxed()
.compat(),
)
}

View File

@@ -0,0 +1,112 @@
use assert_matches::assert_matches;
use futures::{
future::{ready, Ready},
prelude::*,
};
use std::io;
use tarpc::{
client::{self},
context, serde_transport,
server::{self, BaseChannel, Channel, Handler},
transport::channel,
};
use tokio_serde::formats::Json;
#[tarpc_plugins::service]
trait Service {
async fn add(x: i32, y: i32) -> i32;
async fn hey(name: String) -> String;
}
#[derive(Clone)]
struct Server;
impl Service for Server {
type AddFut = Ready<i32>;
fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
ready(x + y)
}
type HeyFut = Ready<String>;
fn hey(self, _: context::Context, name: String) -> Self::HeyFut {
ready(format!("Hey, {}.", name))
}
}
#[tokio::test(threaded_scheduler)]
async fn sequential() -> io::Result<()> {
let _ = env_logger::try_init();
let (tx, rx) = channel::unbounded();
tokio::spawn(
BaseChannel::new(server::Config::default(), rx)
.respond_with(Server.serve())
.execute(),
);
let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?;
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
client.hey(context::current(), "Tim".into()).await,
Ok(ref s) if s == "Hey, Tim.");
Ok(())
}
#[cfg(feature = "serde1")]
#[tokio::test(threaded_scheduler)]
async fn serde() -> io::Result<()> {
let _ = env_logger::try_init();
let transport = serde_transport::tcp::listen("localhost:56789", Json::default).await?;
let addr = transport.local_addr();
tokio::spawn(
tarpc::Server::default()
.incoming(transport.take(1).filter_map(|r| async { r.ok() }))
.respond_with(Server.serve()),
);
let transport = serde_transport::tcp::connect(addr, Json::default()).await?;
let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?;
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
client.hey(context::current(), "Tim".to_string()).await,
Ok(ref s) if s == "Hey, Tim."
);
Ok(())
}
#[tokio::test(threaded_scheduler)]
async fn concurrent() -> io::Result<()> {
let _ = env_logger::try_init();
let (tx, rx) = channel::unbounded();
tokio::spawn(
tarpc::Server::default()
.incoming(stream::once(ready(rx)))
.respond_with(Server.serve()),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
let mut c = client.clone();
let req1 = c.add(context::current(), 1, 2);
let mut c = client.clone();
let req2 = c.add(context::current(), 3, 4);
let mut c = client.clone();
let req3 = c.hey(context::current(), "Tim".to_string());
assert_matches!(req1.await, Ok(3));
assert_matches!(req2.await, Ok(7));
assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim.");
Ok(())
}

View File

@@ -1,21 +0,0 @@
[package]
name = "tarpc-trace"
version = "0.1.0"
authors = ["tikue <tikue@google.com>"]
edition = '2018'
license = "MIT"
documentation = "https://docs.rs/tarpc-trace"
homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "server", "api", "tls"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
description = "foundations for tracing in tarpc"
[dependencies]
rand = "0.5"
[dependencies.serde]
version = "1.0"
optional = true
features = ["derive"]

View File

@@ -1 +0,0 @@
edition = "Edition2018"