76 Commits

Author SHA1 Message Date
Tim Kuehn
bed85e2827 Prepare release of v0.32.0 2023-03-24 15:04:06 -07:00
Bruno
93f3880025 Return transport errors to the caller (#399)
* Make client::InFlightRequests generic over result.

Previously, InFlightRequests required the client response type to be a
server response. However, this prevented injection of non-server
responses: for example, if the client fails to send a request, it should
complete the request with an IO error rather than a server error.

* Gracefully handle client-side send errors.

Previously, a client channel would immediately disconnect when
encountering an error in Transport::try_send. One kind of error that can
occur in try_send is message validation, e.g. validating a message is
not larger than a configured frame size. The problem with shutting down
the client immediately is that debuggability suffers: it can be hard to
understand what caused the client to fail. Also, these errors are not
always fatal, as with frame size limits, so complete shutdown was
extreme.

By bubbling up errors, it's now possible for the caller to
programmatically handle them. For example, the error could be walked
via anyhow::Error:

```
    2023-01-10T02:49:32.528939Z  WARN client: the client failed to send the request

    Caused by:
        0: could not write to the transport
        1: frame size too big
```

* Some follow-up work: right now, read errors will bubble up to all pending RPCs. However, on the write side, only `start_send` bubbles up. `poll_ready`, `poll_flush`, and `poll_close` do not propagate back to pending RPCs. This is probably okay in most circumstances, because fatal write errors likely coincide with fatal read errors, which *do* propagate back to clients. But it might still be worth unifying this logic.

---------

Co-authored-by: Tim Kuehn <tikue@google.com>
2023-03-24 14:31:25 -07:00
cguentherTUChemnitz
878f594d5b Feature/tls over tcp example (#398)
Example tarpc service that encodes messages with bincode written to a TLS-over-TCP transport.

Certs were generated with openssl 3 using https://github.com/rustls/rustls/tree/main/test-ca

New dependencies:
- tokio-rustls to set up the TLS connections
- rustls-pemfile to load certs from .pem files
2023-03-22 10:35:21 -07:00
Tim Kuehn
aa9bbad109 Fix compile_fail tests for formatting changes on stable 2023-03-17 10:07:54 -07:00
Tim Kuehn
7e872ce925 Remove bad mem::forget usage.
mem::forget is a dangerous tool, and it was being used carelessly for
things that have safer alternatives. There was at least one bug where a
cloned tokio::sync::mpsc::UnboundedSender used for request cancellation
was being leaked on every successful server response, so its refcounts
were never decremented. Because these are atomic refcounts, they'll wrap
around rather than overflow when reaching the maximum value, so I don't
believe this could lead to panics or unsoundness.
2022-11-23 18:01:12 -08:00
Tim Kuehn
62541b709d Replace actions-rs 2022-11-23 16:39:29 -08:00
Tim Kuehn
8c43f94fb6 Remove unused Sealed trait 2022-11-17 00:57:56 -08:00
Tim Kuehn
7fa4e5064d Ignore clippy false positive 2022-11-13 00:25:07 -08:00
Tim Kuehn
94db7610bb Require a static lifetime for request_name. 2022-11-05 11:43:03 -07:00
Tim Kuehn
0c08d5e8ca Prepare release of v0.31.0 2022-11-03 13:29:46 -07:00
Tim Kuehn
75b15fe2aa Address clippy lint 2022-10-07 10:51:45 -07:00
Tim Kuehn
863a08d87e In example-service, print the port the server is listened on.
This is helpful when passing starting the server with --port 0.
2022-10-06 20:58:54 -07:00
Tim Kuehn
49ba8f8b1b Zero-pad the random number suffix of TempPathBufs.
This way, the hex number is always 16 digits, which is helpful for test
verification as well as simple consistency.
2022-10-03 18:50:50 -07:00
Kevin K
d832209da3 feat: Unix domain sockets with serde transports (#380)
* adds support for Unix Domain Socket generic transports
* adds a TempPathBuf that lives in temp and is removed on drop
2022-10-03 18:07:29 -07:00
royrustdev
584426d414 fix clippy warnings #378 2022-09-19 23:26:07 -07:00
royrustdev
50eb80c883 reference latest tarpc version in readme 2022-09-19 21:58:21 -07:00
royrustdev
1f0c80d8c9 bump github actions 2022-09-15 11:17:58 -07:00
Tim Kuehn
99bf3e62a3 Prepare release of 0.30.0 2022-08-12 16:08:33 -07:00
Tim Kuehn
68863e3db0 Remove Channel::request_cancellation.
This trait fn returns a private type, which means it's useless for
anyone using the Channel.

Instead, add an inert (now-public) ResponseGuard to TrackedRequest that,
when taken out of the ManuallyDrop, ensures a Channel's request state is
cleaned up. It's preferable to make ResponseGuard public instead of
RequestCancellations because it's a smaller API surface (no public
methods, just a Drop fn) and harder to misuse, because it is already
associated with the correct request ID to cancel.
2022-08-12 16:08:33 -07:00
Tim Kuehn
453ba1c074 Lower log level of log in the RPC callpath 2022-08-12 09:04:47 -07:00
Makro
e3eac1b4f5 Add LICENSE files to crates (#372) 2022-08-10 17:11:50 -07:00
kkharji
0e102288a5 feat: re-export used packages (#371)
## Problem
Library users might get stuck with or ran into issues while using tarpc because of incompatible third party libraries. in particular, tokio_serde and tokio_util.

## Solution
This PR does the following:

1. re-export tokio_serde as part of feature serde-transport, because the end user imports it to use some serde-transport APIs.
2. Update third library packages to latest release and fix resulting issues from that.

## Important Notes
tokio_util 7.3 DelayQueue::poll_expired API changed [0] therefore, InFlightRequests::poll_expired now returns Poll<Option<u64>>

[0] https://docs.rs/tokio-util/latest/tokio_util/time/delay_queue/struct.DelayQueue.html#method.poll_expired
2022-07-15 10:14:49 -07:00
Tim Kuehn
4c8ba41b2f #[allow(unstable_name_collisions)] for .ready()
.ready() is being added to std, but in the meantime, I don't want to stop using PollTest::ready.
2022-06-07 01:29:14 -07:00
Tim Kuehn
946c627579 Remove unused field 2022-06-07 01:29:14 -07:00
Tim Kuehn
104dd71bba Clean up Channel request data more reliably.
When an InFlightRequest is dropped before response completion, request
data in the Channel persists until either the request expires or the
client cancels the request. In rare cases, requests with very large
deadlines could clog up the Channel long after request processing
ceases.

This commit adds a drop hook to InFlightRequest so that if it is dropped
before execution completes, a cancellation message is sent to the
Channel so that it can clean up the associated request data.

This only works for when using `InFlightRequest::execute` or
`Channel::execute`. However, users of raw `Channel` have access
to the `RequestCancellation` handle via `Channel::request_cancellation`,
so they can implement a similar method if they wish to manually clean up
request data.

Note that once a Channel's request data is cleaned up, that request can
never be responded to, even if a response is produced afterward.

Fixes https://github.com/google/tarpc/issues/314
2022-06-07 01:29:04 -07:00
Tim Kuehn
012c481861 Move cancellation types into a dedicated module.
Cancellation utilities could be useful for both client and server code.
2022-06-05 18:54:52 -07:00
Tim Kuehn
dc12bd09aa Annotate types that impl Future with #[must_use].
These types do nothing unless polled / .awaited.
Annotating them with #[must_use] helps prevent a common class of coding errors.

Fixes https://github.com/google/tarpc/issues/368.
2022-06-05 18:54:52 -07:00
Tim Kuehn
2594ea8ce9 Prepare release of 0.29.0 2022-06-05 15:26:33 -07:00
Tim Kuehn
839b87e394 Serialize RPC deadline as a Duration.
Duration was previously serialized as SystemTime. However, absolute
times run into problems with clock skew: if the remote machine's clock
is too far in the future, the RPC deadline will be exceeded before
request processing can begin. Conversely, if the remote machine's clock
is too far in the past, the RPC deadline will not be enforced.

By converting the absolute deadline to a relative duration, clock skew
is no longer relevant, as the remote machine will convert the deadline
into a time relative to its own clock. This mirrors how the gRPC HTTP2
protocol includes a Timeout in the request headers [0] but the SDK uses
timestamps [1]. Keeping the absolute time in the core APIs maintains all
the benefits of today, namely, natural deadline propagation between RPC
hops when using the current context.

This serialization strategy means that, generally, the remote machine's
deadline will be slightly in the future compared to the local machine.
Depending on network transfer latencies, this could be microseconds to
milliseconds, or worse in the worst case. Because the deadline is not
intended for high-precision scenarios, I don't view this is as
problematic.

Because this change only affects the serialization layer, local
transports that bypass serialization are not affected.

[0] https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
[1] https://grpc.io/blog/deadlines/#setting-a-deadline
2022-05-26 15:18:49 -07:00
Tim Kuehn
57d0638a99 Add rpc.deadline tag to Opentelemetry traces. 2022-05-26 15:18:49 -07:00
Tim Kuehn
a3a6404a30 Prepare release of 0.28.0 2022-04-06 22:07:07 -07:00
Tim Kuehn
b36eac80b1 Bump minimum rust version to 1.58.0 2022-04-06 21:53:56 -07:00
Bruno
d7070e4bc3 Update opentelemetry and related dependencies (#362) 2022-04-03 14:09:14 -07:00
Tim Kuehn
b5d1828308 Use captured identifiers in format strings.
This was stabilized in Rust 1.58.0: https://blog.rust-lang.org/2022/01/13/Rust-1.58.0.html
2022-01-13 15:00:44 -08:00
Zak Cutner
92cfe63c4f Use single-threaded Tokio runtime (#360) 2022-01-06 20:59:51 -08:00
Tim Kuehn
839a2f067c Update example to latest version of Clap 2021-12-27 22:56:17 -08:00
David Kleingeld
b5d593488c Derive more traits for RpcError (#359)
Makes RpcError derive Clone, PartialEq, Eq, Hash, Serialize and Deserialize.
2021-12-27 22:00:58 -08:00
Tim Kuehn
eea38b8bf4 Simplify code with const assert!.
The code that prevents compilation on systems where usize is larger than
u64 previously used a const index-out-of-bounds trick. That code can now
be replaced with assert!, as const panic! has landed in 1.57.0 stable.
2021-12-03 15:20:33 -08:00
Shi Yan
70493c15f4 Fix a compiling issue of the official example (#358)
Fix a compiling issue of the official example because of the following error :

```
error[E0599]: the method `execute` exists for struct `BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>`, but its trait bounds were not satisfied
  --> src/main.rs:39:25
   |
39 |     tokio::spawn(server.execute(HelloServer.serve()));
   |                         ^^^^^^^ method cannot be called on `BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>` due to unsatisfied trait bounds
   |
   = note: the following trait bounds were not satisfied:
           `<&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>> as futures::Stream>::Item = _`
           which is required by `&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>: tarpc::server::incoming::Incoming<_>`
           `&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>: futures::Stream`
           which is required by `&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>: tarpc::server::incoming::Incoming<_>`
   = help: items from traits can only be used if the trait is in scope
help: the following trait is implemented but not in scope; perhaps add a `use` for it:
   |
1  | use tarpc::server::Channel;
   |
```

See https://github.com/google/tarpc/pull/358#issuecomment-981953193 for the root cause.
2021-11-29 17:01:16 -08:00
baptiste0928
f7c5d6a7c3 Fix example-service (#355)
Fixes the compilation of the example-service crate (the Clap trait has been renamed Parser in clap-rs/clap@d840d56).
2021-11-15 08:39:04 -08:00
Scott Kirkpatrick
98c5d2a18b Re-add typo fixes (#353)
The typo fixes that were added by commit b5d9aaa
were accidentally reverted by commit 1e680e3, this
will add them back
2021-11-08 10:07:21 -08:00
Tim Kuehn
46b534f7c6 Use HashMap::shrink_to in impl of Comapct::compact. 2021-10-21 17:03:57 -07:00
Tim Kuehn
42b4fc52b1 Set rust-version to 1.56 2021-10-21 16:08:15 -07:00
Tim Kuehn
350dbcdad0 Upgrade to Rust 2021! 2021-10-21 14:10:21 -07:00
Tim Kuehn
b1b4461d89 Prepare release of 0.27.2 2021-10-08 22:31:56 -07:00
Tim Kuehn
f694b7573a Close TcpStream when client disconnects.
An attempt at a clean shutdown helps the server to drop its connections
more quickly.

Testing this uncovered a latent bug in DelayQueue wherein `poll_expired`
yields `Pending` when empty. A workaround was added to
`InFlightRequests::poll_expired`: check if there are actually any
outstanding requests before calling `DelayQueue::poll_expired`.
2021-10-08 22:13:24 -07:00
Tim Kuehn
1e680e3a5a Fix typos in docs.
Fixes https://github.com/google/tarpc/issues/352.
2021-10-08 19:19:50 -07:00
Tim Kuehn
2591d21e94 Update release notes to mention io::Error = 2021-09-23 13:57:43 -07:00
Tim Kuehn
6632f68d95 Prepare for 0.27 release 2021-09-22 15:41:34 -07:00
Dmitry Kakurin
25985ad56a Update README.md (#350)
Fixed 2 typos
2021-09-01 17:58:49 -07:00
Tim Kuehn
d6a24e9420 Address Clippy lint 2021-08-24 12:40:18 -07:00
Tim Kuehn
281a78f3c7 Add tokio-serde-bincode feature 2021-08-24 12:37:57 -07:00
Julian Tescher
a0787d0091 Update to opentelemetry 0.16.x (#349) 2021-08-17 00:00:07 -04:00
Frederik-Baetens
d2acba0e8a add serde-transport-json feature flag (#346)
In general, it should be possible to use, or at least import all functionality of a library, when having only that library in your cargo.toml.
2021-05-06 08:41:57 -07:00
Tim Kuehn
ea7b6763c4 Refactor server module.
In the interest of the user's attention, some ancillary APIs have been
moved to new submodules:

- server::limits contains what was previously called Throttler and
  ChannelFilter. Both of those names were very generic, when the methods
  applied by these types were very specific (and also simplistic). Renames
  have occurred:
  - ThrottlerStream => MaxRequestsPerChannel
  - Throttler => MaxRequests
  - ChannelFilter => MaxChannelsPerKey
- server::incoming contains the Incoming trait.
- server::tokio contains the tokio-specific helper types.

The 5 structs and 1 enum remaining in the base server module are all
core to the functioning of the server.
2021-04-21 17:05:49 -07:00
Tim Kuehn
eb67c540b9 Use more structured errors in client. 2021-04-21 14:54:45 -07:00
Tim Kuehn
4151d0abd3 Move Span creation into BaseChannel.
It's important for Channel decorators, like Throttler, to have access to
the Span. This means that the BaseChannel becomes responsible for
starting its own requests. Actually, this simplifies the integration for
the Channel users, as they can assume any yielded requests are already
tracked.

This entails the following breaking changes:

- removed trait method Channel::start_request as it is now done
  internally.
2021-04-21 14:54:45 -07:00
Tim Kuehn
d0c11a6efa Change RPC error type from io::Error => RpcError.
Becaue tarpc is a library, not an application, it should strive to
use structured errors in its API so that users have maximal flexibility
in how they handle errors. io::Error makes that hard, because it is a
kitchen-sink error type.

RPCs in particular only have 3 classes of errors:

- The connection breaks.
- The request expires.
- The server decides not to process the request.

(Of course, RPCs themselves can have application-specific errors, but
from the perspective of the RPC library, those can be classified as
successful responsees).
2021-04-20 18:29:55 -07:00
Tim Kuehn
82c4da1743 Prepare release of v0.26.2 2021-04-20 11:28:15 -07:00
Tim Kuehn
0a15e0b75c Rustdoc: link RPC futures to their methods. 2021-04-20 11:25:26 -07:00
Tim Kuehn
0b315c29bf It's not currently possible to document the enum variants, which means
projects that #[deny(missing_docs)] wouldn't compile if using tarpc
services.
2021-04-20 09:01:39 -07:00
Tim Kuehn
56f09bf61f Fix log that's split across lines. 2021-04-17 17:15:16 -07:00
Tim Kuehn
6d82e82419 Fix formatting 2021-04-16 16:51:21 -07:00
Tim Kuehn
9bebaf814a Address clippy lint 2021-04-14 17:49:27 -07:00
Tim Kuehn
5f4d6e6008 Prepare release of v0.26.0 2021-04-14 17:08:44 -07:00
Tim Kuehn
07d07d7ba3 Remove tracing_appender as it does not support build target mipsel-unknown-linux-gnu 2021-04-01 19:37:02 -07:00
Tim Kuehn
a41bbf65b2 Use rustfmt instead of cargo fmt so that diff is only printed once 2021-04-01 17:24:34 -07:00
Tim Kuehn
21e2f7ca62 Tear out requirement that Transport's error type is io::Error. 2021-04-01 17:24:34 -07:00
Tim Kuehn
7b7c182411 Instrument tarpc with tracing.
tarpc is now instrumented with tracing primitives extended with
OpenTelemetry traces. Using a compatible tracing-opentelemetry
subscriber like Jaeger, each RPC can be traced through the client,
server, amd other dependencies downstream of the server. Even for
applications not connected to a distributed tracing collector, the
instrumentation can also be ingested by regular loggers like env_logger.

 # Breaking Changes

 ## Logging

Logged events are now structured using tracing. For applications using a
logger and not a tracing subscriber, these logs may look different or
contain information in a less consumable manner. The easiest solution is
to add a tracing subscriber that logs to stdout, such as
tracing_subscriber::fmt.

 ##  Context

- Context no longer has parent_span, which was actually never needed,
  because the context sent in an RPC is inherently the parent context.
  For purposes of distributed tracing, the client side of the RPC has all
  necessary information to link the span to its parent; the server side
  need do nothing more than export the (trace ID, span ID) tuple.
- Context has a new field, SamplingDecision, which has two variants,
  Sampled and Unsampled. This field can be used by downstream systems to
  determine whether a trace needs to be exported. If the parent span is
  sampled, the expectation is that all child spans be exported, as well;
  to do otherwise could result in lossy traces being exported. Note that
  if an Openetelemetry tracing subscriber is not installed, the fallback
  context will still be used, but the Context's sampling decision will
  always be inherited by the parent Context's sampling decision.
- Context::scope has been removed. Context propagation is now done via
  tracing's task-local spans. Spans can be propagated across tasks via
  Span::in_scope. When a service receives a request, it attaches an
  Opentelemetry context to the local Span created before request handling,
  and this context contains the request deadline. This span-local deadline
  is retrieved by Context::current, but it cannot be modified so that
  future Context::current calls contain a different deadline. However, the
  deadline in the context passed into an RPC call will override it, so
  users can retrieve the current context and then modify the deadline
  field, as has been historically possible.
- Context propgation precedence changes: when an RPC is initiated, the
  current Span's Opentelemetry context takes precedence over the trace
  context passed into the RPC method. If there is no current Span, then
  the trace context argument is used as it has been historically. Note
  that Opentelemetry context propagation requires an Opentelemetry
  tracing subscriber to be installed.

 ## Server

- The server::Channel trait now has an additional required associated
  type and method which returns the underlying transport. This makes it
  more ergonomic for users to retrieve transport-specific information,
  like IP Address. BaseChannel implements Channel::transport by returning
  the underlying transport, and channel decorators like Throttler just
  delegate to the Channel::transport method of the wrapped channel.

 # References

[1] https://github.com/tokio-rs/tracing
[2] https://opentelemetry.io
[3] https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
[4] https://github.com/env-logger-rs/env_logger
2021-04-01 17:24:34 -07:00
Ben Ludewig
db0c778ead Serialize u128 TraceId as LE bytes (#344) 2021-03-30 08:41:19 -07:00
Tim Kuehn
c3efb83ac1 Add more context to errors returned by serde transport 2021-03-28 20:03:03 -07:00
Tim Kuehn
3d7b0171fe Fix cargo fmt portion of pre-commit 2021-03-26 19:39:56 -07:00
oblique
c191ff5b2e Do not enable tokio-serde/json by default (#345) 2021-03-26 18:22:44 -07:00
Tim Kuehn
90bc7f741d Fix up imports 2021-03-17 12:44:39 -07:00
Kitsu
d3f6c01df2 Reduce required tokio features (#343)
* Move async tests behind cfg-ed mod
* Use explicit tokio features for the example
* Use only relative crate path for example dependency
2021-03-17 12:30:18 -07:00
Tim Kuehn
c6450521e6 Add method to run a future in the current context.
Previously, `Context::current` would always return a new context. Now,
it uses tokio task-local data to look for the current context. Tokio
task locals are not actually tied to a tokio executor; instead, they
provide data scoped to a future.

The basic pattern is:

```rust
let ctx = Context::new_root();
ctx.scope(async {
    let ctx2 = context::current();
    assert_eq!(ctx2.trace_context.span_id, ctx.trace_context.span_id);
});
```

`server::InFlightRequest::execute` uses `Context::scope` to set the
current context before executing a request, so calls to
`context::current` in request handlers will return the context provided
by the client. This does not propagate to new spawned tasks. To
propagate the client context to child tasks, the following pattern will
work:

```rust
tokio::spawn(context::current().scope(async { /* do work here */ }));
```

This commit also introduces a breaking change to Context serialization.
Previously, the deadline only serialized second-level precision. Now, it
provides full fidelity serialization to the nanosecond.
2021-03-13 16:05:02 -08:00
59 changed files with 3714 additions and 1838 deletions

View File

@@ -14,99 +14,57 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
profile: minimal
toolchain: stable
target: mipsel-unknown-linux-gnu
override: true
- uses: actions-rs/cargo@v1
with:
command: check
args: --all-features
- uses: actions-rs/cargo@v1
with:
command: check
args: --all-features --target mipsel-unknown-linux-gnu
targets: mipsel-unknown-linux-gnu
- run: cargo check --all-features
- run: cargo check --all-features --target mipsel-unknown-linux-gnu
test:
name: Test Suite
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: test
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features serde1
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features tokio1
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features serde-transport
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features tcp
- uses: actions-rs/cargo@v1
with:
command: test
args: --all-features
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- run: cargo test
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde1
- run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde-transport
- run: cargo test --manifest-path tarpc/Cargo.toml --features tcp
- run: cargo test --all-features
fmt:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add rustfmt
- uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
components: rustfmt
- run: cargo fmt --all -- --check
clippy:
name: Clippy
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add clippy
- uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-features -- -D warnings
components: clippy
- run: cargo clippy --all-features -- -D warnings

View File

@@ -1,7 +1,11 @@
[workspace]
resolver = "2"
members = [
"example-service",
"tarpc",
"plugins",
]
[profile.dev]
split-debuginfo = "unpacked"

View File

@@ -40,7 +40,7 @@ rather than in a separate language such as .proto. This means there's no separat
process, and no context switching between different languages.
Some other features of tarpc:
- Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
- Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
used as a transport to connect the client and server.
- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
- Cascading cancellation: dropping a request will send a cancellation message to the server.
@@ -51,6 +51,14 @@ Some other features of tarpc:
requests sent by the server that use the request context will propagate the request deadline.
For example, if a server is handling a request with a 10s deadline, does 2s of work, then
sends a request to another server, that server will see an 8s deadline.
- Distributed tracing: tarpc is instrumented with
[tracing](https://github.com/tokio-rs/tracing) primitives extended with
[OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
[Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
each RPC can be traced through the client, server, and other dependencies downstream of the
server. Even for applications not connected to a distributed tracing collector, the
instrumentation can also be ingested by regular loggers like
[env_logger](https://github.com/env-logger-rs/env_logger/).
- Serde serialization: enabling the `serde1` Cargo feature will make service requests and
responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
@@ -59,7 +67,7 @@ Some other features of tarpc:
Add to your `Cargo.toml` dependencies:
```toml
tarpc = "0.25"
tarpc = "0.32"
```
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
@@ -72,8 +80,9 @@ This example uses [tokio](https://tokio.rs), so add the following dependencies t
your `Cargo.toml`:
```toml
futures = "1.0"
tarpc = { version = "0.25", features = ["tokio1"] }
anyhow = "1.0"
futures = "0.3"
tarpc = { version = "0.31", features = ["tokio1"] }
tokio = { version = "1.0", features = ["macros"] }
```
@@ -91,9 +100,8 @@ use futures::{
};
use tarpc::{
client, context,
server::{self, Incoming},
server::{self, incoming::Incoming, Channel},
};
use std::io;
// This is the service definition. It looks a lot like a trait definition.
// It defines one RPC, hello, which takes one arg, name, and returns a String.
@@ -120,7 +128,7 @@ impl World for HelloServer {
type HelloFut = Ready<String>;
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
future::ready(format!("Hello, {name}!"))
}
}
```
@@ -132,7 +140,7 @@ available behind the `tcp` feature.
```rust
#[tokio::main]
async fn main() -> io::Result<()> {
async fn main() -> anyhow::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);
@@ -140,14 +148,14 @@ async fn main() -> io::Result<()> {
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
// that takes a config and any Transport as input.
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = client.hello(context::current(), "Stim".to_string()).await?;
println!("{}", hello);
println!("{hello}");
Ok(())
}

View File

@@ -1,3 +1,177 @@
## 0.32.0 (2023-03-24)
### Breaking Changes
- As part of a fix to return more channel errors in RPC results, a few error types have changed:
0. `client::RpcError::Disconnected` was split into the following errors:
- Shutdown: the client was shutdown, either intentionally or due to an error. If due to an
error, pending RPCs should see the more specific errors below.
- Send: an RPC message failed to send over the transport. Only the RPC that failed to be sent
will see this error.
- Receive: a fatal error occurred while receiving from the transport. All in-flight RPCs will
receive this error.
0. `client::ChannelError` and `server::ChannelError` are unified in `tarpc::ChannelError`.
Previously, server transport errors would not indicate during which activity the transport
error occurred. Now, just like the client already was, it will be specific: reading, readying,
sending, flushing, or closing.
## 0.31.0 (2022-11-03)
### New Features
This release adds Unix Domain Sockets to the `serde_transport` module.
To use it, enable the "unix" feature. See the docs for more information.
## 0.30.0 (2022-08-12)
### Breaking Changes
- Some types that impl Future are now annotated with `#[must_use]`. Code that previously created
these types but did not use them will now receive a warning. Code that disallows warnings will
receive a compilation error.
### Fixes
- Servers will more reliably clean up request state for requests with long deadlines when response
processing is aborted without sending a response.
### Other Changes
- `TrackedRequest` now contains a response guard that can be used to ensure state cleanup for
aborted requests. (This was already handled automatically by `InFlightRequests`).
- When the feature serde-transport is enabled, the crate tokio_serde is now re-exported.
## 0.29.0 (2022-05-26)
### Breaking Changes
`Context.deadline` is now serialized as a Duration. This prevents clock skew from affecting deadline
behavior. For more details see https://github.com/google/tarpc/pull/367 and its [related
issue](https://github.com/google/tarpc/issues/366).
## 0.28.0 (2022-04-06)
### Breaking Changes
- The minimum supported Rust version has increased to 1.58.0.
- The version of opentelemetry depended on by tarpc has increased to 0.17.0.
## 0.27.2 (2021-10-08)
### Fixes
Clients will now close their transport before dropping it. An attempt at a clean shutdown can help
the server drop its connections more quickly.
## 0.27.1 (2021-09-22)
### Breaking Changes
#### RPC error type is changing
RPC return types are changing from `Result<Response, io::Error>` to `Result<Response,
tarpc::client::RpcError>`.
Becaue tarpc is a library, not an application, it should strive to
use structured errors in its API so that users have maximal flexibility
in how they handle errors. io::Error makes that hard, because it is a
kitchen-sink error type.
RPCs in particular only have 3 classes of errors:
- The connection breaks.
- The request expires.
- The server decides not to process the request.
RPC responses can also contain application-specific errors, but from the
perspective of the RPC library, those are opaque to the framework, classified
as successful responsees.
### Open Telemetry
The Opentelemetry dependency is updated to version 0.16.x.
## 0.27.0 (2021-09-22)
This version was yanked due to tarpc-plugins version mismatches.
## 0.26.0 (2021-04-14)
### New Features
#### Tracing
tarpc is now instrumented with tracing primitives extended with
OpenTelemetry traces. Using a compatible tracing-opentelemetry
subscriber like Jaeger, each RPC can be traced through the client,
server, amd other dependencies downstream of the server. Even for
applications not connected to a distributed tracing collector, the
instrumentation can also be ingested by regular loggers like env_logger.
### Breaking Changes
#### Logging
Logged events are now structured using tracing. For applications using a
logger and not a tracing subscriber, these logs may look different or
contain information in a less consumable manner. The easiest solution is
to add a tracing subscriber that logs to stdout, such as
tracing_subscriber::fmt.
#### Context
- Context no longer has parent_span, which was actually never needed,
because the context sent in an RPC is inherently the parent context.
For purposes of distributed tracing, the client side of the RPC has all
necessary information to link the span to its parent; the server side
need do nothing more than export the (trace ID, span ID) tuple.
- Context has a new field, SamplingDecision, which has two variants,
Sampled and Unsampled. This field can be used by downstream systems to
determine whether a trace needs to be exported. If the parent span is
sampled, the expectation is that all child spans be exported, as well;
to do otherwise could result in lossy traces being exported. Note that
if an Openetelemetry tracing subscriber is not installed, the fallback
context will still be used, but the Context's sampling decision will
always be inherited by the parent Context's sampling decision.
- Context::scope has been removed. Context propagation is now done via
tracing's task-local spans. Spans can be propagated across tasks via
Span::in_scope. When a service receives a request, it attaches an
Opentelemetry context to the local Span created before request handling,
and this context contains the request deadline. This span-local deadline
is retrieved by Context::current, but it cannot be modified so that
future Context::current calls contain a different deadline. However, the
deadline in the context passed into an RPC call will override it, so
users can retrieve the current context and then modify the deadline
field, as has been historically possible.
- Context propgation precedence changes: when an RPC is initiated, the
current Span's Opentelemetry context takes precedence over the trace
context passed into the RPC method. If there is no current Span, then
the trace context argument is used as it has been historically. Note
that Opentelemetry context propagation requires an Opentelemetry
tracing subscriber to be installed.
#### Server
- The server::Channel trait now has an additional required associated
type and method which returns the underlying transport. This makes it
more ergonomic for users to retrieve transport-specific information,
like IP Address. BaseChannel implements Channel::transport by returning
the underlying transport, and channel decorators like Throttler just
delegate to the Channel::transport method of the wrapped channel.
#### Client
- NewClient::spawn no longer returns a result, as spawn can't fail.
### References
1. https://github.com/tokio-rs/tracing
2. https://opentelemetry.io
3. https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
4. https://github.com/env-logger-rs/env_logger
## 0.25.0 (2021-03-10)
### Breaking Changes
@@ -170,7 +344,7 @@ nameable futures and will just be boxing the return type anyway. This macro does
### Breaking Changes
- Enums had _non_exhaustive fields replaced with the #[non_exhaustive] attribute.
- Enums had `_non_exhaustive` fields replaced with the #[non_exhaustive] attribute.
### Bug Fixes

View File

@@ -1,8 +1,9 @@
[package]
name = "tarpc-example-service"
version = "0.9.0"
version = "0.14.0"
rust-version = "1.56"
authors = ["Tim Kuehn <tikue@google.com>"]
edition = "2018"
edition = "2021"
license = "MIT"
documentation = "https://docs.rs/tarpc-example-service"
homepage = "https://github.com/google/tarpc"
@@ -13,12 +14,18 @@ readme = "../README.md"
description = "An example server built on tarpc."
[dependencies]
clap = "2.33"
env_logger = "0.8"
anyhow = "1.0"
clap = { version = "3.0.0-rc.9", features = ["derive"] }
log = "0.4"
futures = "0.3"
serde = { version = "1.0" }
tarpc = { version = "0.25", path = "../tarpc", features = ["full"] }
tokio = { version = "1", features = ["full"] }
opentelemetry = { version = "0.17", features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] }
rand = "0.8"
tarpc = { version = "0.32", path = "../tarpc", features = ["full"] }
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
tracing = { version = "0.1" }
tracing-opentelemetry = "0.17"
tracing-subscriber = {version = "0.3", features = ["env-filter"]}
[lib]
name = "service"

View File

@@ -4,57 +4,53 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use clap::{App, Arg};
use std::{io, net::SocketAddr};
use clap::Parser;
use service::{init_tracing, WorldClient};
use std::{net::SocketAddr, time::Duration};
use tarpc::{client, context, tokio_serde::formats::Json};
use tokio::time::sleep;
use tracing::Instrument;
#[derive(Parser)]
struct Flags {
/// Sets the server address to connect to.
#[clap(long)]
server_addr: SocketAddr,
/// Sets the name to say hello to.
#[clap(long)]
name: String,
}
#[tokio::main]
async fn main() -> io::Result<()> {
env_logger::init();
async fn main() -> anyhow::Result<()> {
let flags = Flags::parse();
init_tracing("Tarpc Example Client")?;
let flags = App::new("Hello Client")
.version("0.1")
.author("Tim <tikue@google.com>")
.about("Say hello!")
.arg(
Arg::with_name("server_addr")
.long("server_addr")
.value_name("ADDRESS")
.help("Sets the server address to connect to.")
.required(true)
.takes_value(true),
)
.arg(
Arg::with_name("name")
.short("n")
.long("name")
.value_name("STRING")
.help("Sets the name to say hello to.")
.required(true)
.takes_value(true),
)
.get_matches();
let server_addr = flags.value_of("server_addr").unwrap();
let server_addr = server_addr
.parse::<SocketAddr>()
.unwrap_or_else(|e| panic!(r#"--server_addr value "{}" invalid: {}"#, server_addr, e));
let name = flags.value_of("name").unwrap().into();
let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
transport.config_mut().max_frame_length(usize::MAX);
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
// config and any Transport as input.
let client = service::WorldClient::new(client::Config::default(), transport.await?).spawn()?;
let client = WorldClient::new(client::Config::default(), transport.await?).spawn();
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = client.hello(context::current(), name).await?;
let hello = async move {
// Send the request twice, just to be safe! ;)
tokio::select! {
hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
}
}
.instrument(tracing::info_span!("Two Hellos"))
.await;
println!("{}", hello);
match hello {
Ok(hello) => tracing::info!("{hello:?}"),
Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)),
}
// Let the background span processor finish.
sleep(Duration::from_micros(1)).await;
opentelemetry::global::shutdown_tracer_provider();
Ok(())
}

View File

@@ -4,6 +4,9 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use std::env;
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};
/// This is the service definition. It looks a lot like a trait definition.
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
#[tarpc::service]
@@ -11,3 +14,21 @@ pub trait World {
/// Returns a greeting for name.
async fn hello(name: String) -> String;
}
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
let tracer = opentelemetry_jaeger::new_pipeline()
.with_service_name(service_name)
.with_max_packet_size(2usize.pow(13))
.install_batch(opentelemetry::runtime::Tokio)?;
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer().with_span_events(FmtSpan::NEW | FmtSpan::CLOSE))
.with(tracing_opentelemetry::layer().with_tracer(tracer))
.try_init()?;
Ok(())
}

View File

@@ -4,18 +4,30 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use clap::{App, Arg};
use clap::Parser;
use futures::{future, prelude::*};
use service::World;
use rand::{
distributions::{Distribution, Uniform},
thread_rng,
};
use service::{init_tracing, World};
use std::{
io,
net::{IpAddr, SocketAddr},
net::{IpAddr, Ipv6Addr, SocketAddr},
time::Duration,
};
use tarpc::{
context,
server::{self, Channel, Incoming},
server::{self, incoming::Incoming, Channel},
tokio_serde::formats::Json,
};
use tokio::time;
#[derive(Parser)]
struct Flags {
/// Sets the port number to listen on.
#[clap(long)]
port: u16,
}
// This is the type that implements the generated World trait. It is the business logic
// and is used to start the server.
@@ -25,51 +37,36 @@ struct HelloServer(SocketAddr);
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {}! You are connected from {:?}.", name, self.0)
let sleep_time =
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
time::sleep(sleep_time).await;
format!("Hello, {name}! You are connected from {}", self.0)
}
}
#[tokio::main]
async fn main() -> io::Result<()> {
env_logger::init();
async fn main() -> anyhow::Result<()> {
let flags = Flags::parse();
init_tracing("Tarpc Example Server")?;
let flags = App::new("Hello Server")
.version("0.1")
.author("Tim <tikue@google.com>")
.about("Say hello!")
.arg(
Arg::with_name("port")
.short("p")
.long("port")
.value_name("NUMBER")
.help("Sets the port number to listen on")
.required(true)
.takes_value(true),
)
.get_matches();
let port = flags.value_of("port").unwrap();
let port = port
.parse()
.unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e));
let server_addr = (IpAddr::from([0, 0, 0, 0]), port);
let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), flags.port);
// JSON transport is provided by the json_transport tarpc module. It makes it easy
// to start up a serde-powered json serialization strategy over TCP.
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
tracing::info!("Listening on port {}", listener.local_addr().port());
listener.config_mut().max_frame_length(usize::MAX);
listener
// Ignore accept errors.
.filter_map(|r| future::ready(r.ok()))
.map(server::BaseChannel::with_defaults)
// Limit channels to 1 per IP.
.max_channels_per_key(1, |t| t.as_ref().peer_addr().unwrap().ip())
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
// serve is generated by the service attribute. It takes as input any type implementing
// the generated World trait.
.map(|channel| {
let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
channel.requests().execute(server.serve())
let server = HelloServer(channel.transport().peer_addr().unwrap());
channel.execute(server.serve())
})
// Max 10 channels.
.buffer_unordered(10)

View File

@@ -67,7 +67,7 @@ else
fi
printf "${PREFIX} Checking for rustfmt ... "
command -v cargo fmt &>/dev/null
command -v rustfmt &>/dev/null
if [ $? == 0 ]; then
printf "${SUCCESS}\n"
else
@@ -93,19 +93,19 @@ diff=""
for file in $(git diff --name-only --cached);
do
if [ ${file: -3} == ".rs" ]; then
diff="$diff$(cargo fmt -- --check $file)"
diff="$diff$(rustfmt --edition 2018 --check $file)"
if [ $? != 0 ]; then
FMTRESULT=1
fi
fi
done
if grep --quiet "^[-+]" <<< "$diff"; then
FMTRESULT=1
fi
if [ "${TARPC_SKIP_RUSTFMT}" == 1 ]; then
printf "${SKIPPED}\n"$?
elif [ ${FMTRESULT} != 0 ]; then
FAILED=1
printf "${FAILURE}\n"
echo "$diff" | sed 's/Using rustfmt config file.*$/d/'
echo "$diff"
else
printf "${SUCCESS}\n"
fi

View File

@@ -1,8 +1,9 @@
[package]
name = "tarpc-plugins"
version = "0.10.0"
version = "0.12.0"
rust-version = "1.56"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018"
edition = "2021"
license = "MIT"
documentation = "https://docs.rs/tarpc-plugins"
homepage = "https://github.com/google/tarpc"
@@ -30,4 +31,4 @@ proc-macro = true
assert-type-eq = "0.1.0"
futures = "0.3"
serde = { version = "1.0", features = ["derive"] }
tarpc = { path = "../tarpc" }
tarpc = { path = "../tarpc", features = ["serde1"] }

9
plugins/LICENSE Normal file
View File

@@ -0,0 +1,9 @@
The MIT License (MIT)
Copyright 2016 Google Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -83,7 +83,7 @@ impl Parse for Service {
ident_errors,
syn::Error::new(
rpc.ident.span(),
format!("method name conflicts with generated fn `{}::serve`", ident)
format!("method name conflicts with generated fn `{ident}::serve`")
)
);
}
@@ -267,18 +267,25 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
None
};
let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
let request_names = methods
.iter()
.map(|m| format!("{ident}.{m}"))
.collect::<Vec<_>>();
ServiceGenerator {
response_fut_name,
service_ident: ident,
server_ident: &format_ident!("Serve{}", ident),
response_fut_ident: &Ident::new(&response_fut_name, ident.span()),
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
client_ident: &format_ident!("{}Client", ident),
request_ident: &format_ident!("{}Request", ident),
response_ident: &format_ident!("{}Response", ident),
vis,
args,
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
method_idents: &methods,
request_names: &request_names,
attrs,
rpcs,
return_types: &rpcs
@@ -299,7 +306,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.collect::<Vec<_>>(),
future_types: &camel_case_fn_names
.iter()
.map(|name| parse_str(&format!("{}Fut", name)).unwrap())
.map(|name| parse_str(&format!("{name}Fut")).unwrap())
.collect::<Vec<_>>(),
derive_serialize: derive_serialize.as_ref(),
}
@@ -399,14 +406,10 @@ fn verify_types_were_provided(
) -> syn::Result<()> {
let mut result = Ok(());
for (method, expected) in expected {
if provided
.iter()
.find(|typedecl| typedecl.ident == expected)
.is_none()
{
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
let mut e = syn::Error::new(
span,
format!("not all trait items implemented, missing: `{}`", expected),
format!("not all trait items implemented, missing: `{expected}`"),
);
let fn_span = method.sig.fn_token.span();
e.extend(syn::Error::new(
@@ -441,6 +444,7 @@ struct ServiceGenerator<'a> {
camel_case_idents: &'a [Ident],
future_types: &'a [Type],
method_idents: &'a [&'a Ident],
request_names: &'a [String],
method_attrs: &'a [&'a [Attribute]],
args: &'a [&'a [PatType]],
return_types: &'a [&'a Type],
@@ -475,7 +479,7 @@ impl<'a> ServiceGenerator<'a> {
),
output,
)| {
let ty_doc = format!("The response future returned by {}.", ident);
let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`].");
quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
@@ -524,6 +528,7 @@ impl<'a> ServiceGenerator<'a> {
camel_case_idents,
arg_pats,
method_idents,
request_names,
..
} = self;
@@ -534,6 +539,16 @@ impl<'a> ServiceGenerator<'a> {
type Resp = #response_ident;
type Fut = #response_fut_ident<S>;
fn method(&self, req: &#request_ident) -> Option<&'static str> {
Some(match req {
#(
#request_ident::#camel_case_idents{..} => {
#request_names
}
)*
})
}
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
match req {
#(
@@ -563,6 +578,7 @@ impl<'a> ServiceGenerator<'a> {
quote! {
/// The request sent over the wire from the client to the server.
#[allow(missing_docs)]
#[derive(Debug)]
#derive_serialize
#vis enum #request_ident {
@@ -583,6 +599,7 @@ impl<'a> ServiceGenerator<'a> {
quote! {
/// The response sent over the wire from the server to the client.
#[allow(missing_docs)]
#[derive(Debug)]
#derive_serialize
#vis enum #response_ident {
@@ -603,6 +620,7 @@ impl<'a> ServiceGenerator<'a> {
quote! {
/// A future resolving to a server response.
#[allow(missing_docs)]
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
}
@@ -714,6 +732,7 @@ impl<'a> ServiceGenerator<'a> {
method_attrs,
vis,
method_idents,
request_names,
args,
return_types,
arg_pats,
@@ -727,9 +746,9 @@ impl<'a> ServiceGenerator<'a> {
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = self.0.call(ctx, request);
let resp = self.0.call(ctx, #request_names, request);
async move {
match resp.await? {
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),

View File

@@ -1,26 +1,41 @@
[package]
name = "tarpc"
version = "0.25.1"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018"
version = "0.32.0"
rust-version = "1.58.0"
authors = [
"Adam Wright <adam.austin.wright@gmail.com>",
"Tim Kuehn <timothy.j.kuehn@gmail.com>",
]
edition = "2021"
license = "MIT"
documentation = "https://docs.rs/tarpc"
homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "server", "api", "microservices"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
readme = "README.md"
description = "An RPC framework for Rust with a focus on ease of use."
[features]
default = []
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
tokio1 = ["tokio/rt-multi-thread"]
serde-transport = ["serde1", "tokio1", "tokio-serde/json", "tokio-util/codec"]
tokio1 = ["tokio/rt"]
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
serde-transport-json = ["tokio-serde/json"]
serde-transport-bincode = ["tokio-serde/bincode"]
tcp = ["tokio/net"]
unix = ["tokio/net"]
full = ["serde1", "tokio1", "serde-transport", "tcp"]
full = [
"serde1",
"tokio1",
"serde-transport",
"serde-transport-json",
"serde-transport-bincode",
"tcp",
"unix",
]
[badges]
travis-ci = { repository = "google/tarpc" }
@@ -30,29 +45,41 @@ anyhow = "1.0"
fnv = "1.0"
futures = "0.3"
humantime = "2.0"
log = "0.4"
pin-project = "1.0"
rand = "0.7"
rand = "0.8"
serde = { optional = true, version = "1.0", features = ["derive"] }
static_assertions = "1.1.0"
tarpc-plugins = { path = "../plugins", version = "0.10" }
tarpc-plugins = { path = "../plugins", version = "0.12" }
thiserror = "1.0"
tokio = { version = "1", features = ["time"] }
tokio-util = { version = "0.6.3", features = ["time"] }
tokio-util = { version = "0.7.3", features = ["time"] }
tokio-serde = { optional = true, version = "0.8" }
tracing = { version = "0.1", default-features = false, features = [
"attributes",
"log",
] }
tracing-opentelemetry = { version = "0.17.2", default-features = false }
opentelemetry = { version = "0.17.0", default-features = false }
[dev-dependencies]
assert_matches = "1.4"
bincode = "1.3"
bytes = { version = "1", features = ["serde"] }
env_logger = "0.8"
flate2 = "1.0"
futures-test = "0.3"
log = "0.4"
opentelemetry = { version = "0.17.0", default-features = false, features = [
"rt-tokio",
] }
opentelemetry-jaeger = { version = "0.16.0", features = ["rt-tokio"] }
pin-utils = "0.1.0-alpha"
serde_bytes = "0.11"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio = { version = "1", features = ["full", "test-util"] }
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
trybuild = "1.0"
tokio-rustls = "0.23"
rustls-pemfile = "1.0"
[package.metadata.docs.rs]
all-features = true
@@ -63,7 +90,7 @@ name = "compression"
required-features = ["serde-transport", "tcp"]
[[example]]
name = "server_calling_server"
name = "tracing"
required-features = ["full"]
[[example]]
@@ -78,6 +105,10 @@ required-features = ["full"]
name = "custom_transport"
required-features = ["serde1", "tokio1", "serde-transport"]
[[example]]
name = "tls_over_tcp"
required-features = ["full"]
[[test]]
name = "service_functional"
required-features = ["serde-transport"]

9
tarpc/LICENSE Normal file
View File

@@ -0,0 +1,9 @@
The MIT License (MIT)
Copyright 2016 Google Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -0,0 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA
NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/
BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O
BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE
fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF
BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137
izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK
-----END CERTIFICATE-----

View File

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
-----END CERTIFICATE-----

View File

@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH
-----END PRIVATE KEY-----

View File

@@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc
RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E
AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow
RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM
EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t
ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF
9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq
amD2TBup4eNUCsQB
-----END CERTIFICATE-----

View File

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
-----END CERTIFICATE-----

View File

@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2
-----END PRIVATE KEY-----

View File

@@ -7,8 +7,8 @@ use tarpc::{
client, context,
serde_transport::tcp,
server::{BaseChannel, Channel},
tokio_serde::formats::Bincode,
};
use tokio_serde::formats::Bincode;
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
@@ -54,7 +54,7 @@ where
if algorithm != CompressionAlgorithm::Deflate {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Compression algorithm {:?} not supported", algorithm),
format!("Compression algorithm {algorithm:?} not supported"),
));
}
let mut deflater = DeflateDecoder::new(payload.as_slice());
@@ -102,7 +102,7 @@ struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hey, {}!", name)
format!("Hey, {name}!")
}
}
@@ -118,7 +118,7 @@ async fn main() -> anyhow::Result<()> {
});
let transport = tcp::connect(addr, Bincode::default).await?;
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?;
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn();
println!(
"{}",

View File

@@ -1,10 +1,9 @@
use futures::future;
use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
use tokio::net::{UnixListener, UnixStream};
use tokio_serde::formats::Bincode;
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
#[tarpc::service]
pub trait PingService {
@@ -14,16 +13,13 @@ pub trait PingService {
#[derive(Clone)]
struct Service;
#[tarpc::server]
impl PingService for Service {
type PingFut = future::Ready<()>;
fn ping(self, _: Context) -> Self::PingFut {
future::ready(())
}
async fn ping(self, _: Context) {}
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
async fn main() -> anyhow::Result<()> {
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
let _ = std::fs::remove_file(bind_addr);
@@ -44,7 +40,9 @@ async fn main() -> std::io::Result<()> {
let conn = UnixStream::connect(bind_addr).await?;
let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
PingServiceClient::new(Default::default(), transport)
.spawn()?
.spawn()
.ping(tarpc::context::current())
.await
.await?;
Ok(())
}

View File

@@ -38,10 +38,11 @@ use futures::{
future::{self, AbortHandle},
prelude::*,
};
use log::info;
use publisher::Publisher as _;
use std::{
collections::HashMap,
env,
error::Error,
io,
net::SocketAddr,
sync::{Arc, Mutex, RwLock},
@@ -51,9 +52,11 @@ use tarpc::{
client, context,
serde_transport::tcp,
server::{self, Channel},
tokio_serde::formats::Json,
};
use tokio::net::ToSocketAddrs;
use tokio_serde::formats::Json;
use tracing::info;
use tracing_subscriber::prelude::*;
pub mod subscriber {
#[tarpc::service]
@@ -83,10 +86,7 @@ impl subscriber::Subscriber for Subscriber {
}
async fn receive(self, _: context::Context, topic: String, message: String) {
info!(
"[{}] received message on topic '{}': {}",
self.local_addr, topic, message
);
info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage")
}
}
@@ -120,7 +120,7 @@ impl Subscriber {
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
tokio::spawn(async move {
match handler.await {
Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr),
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
}
});
Ok(SubscriberHandle(abort_handle))
@@ -129,7 +129,6 @@ impl Subscriber {
#[derive(Debug)]
struct Subscription {
subscriber: subscriber::SubscriberClient,
topics: Vec<String>,
}
@@ -153,13 +152,13 @@ impl Publisher {
subscriptions: self.clone().start_subscription_manager().await?,
};
info!("[{}] listening for publishers.", publisher_addrs.publisher);
info!(publisher_addr = %publisher_addrs.publisher, "listening for publishers.",);
tokio::spawn(async move {
// Because this is just an example, we know there will only be one publisher. In more
// realistic code, this would be a loop to continually accept new publisher
// connections.
let publisher = connecting_publishers.next().await.unwrap().unwrap();
info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected.");
server::BaseChannel::with_defaults(publisher)
.execute(self.serve())
@@ -174,7 +173,7 @@ impl Publisher {
.await?
.filter_map(|r| future::ready(r.ok()));
let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
info!("[{}] listening for subscribers.", new_subscriber_addr);
info!(?new_subscriber_addr, "listening for subscribers.");
tokio::spawn(async move {
while let Some(conn) = connecting_subscribers.next().await {
@@ -210,12 +209,11 @@ impl Publisher {
self.clients.lock().unwrap().insert(
subscriber_addr,
Subscription {
subscriber: subscriber.clone(),
topics: topics.clone(),
},
);
info!("[{}] subscribed to topics: {:?}", subscriber_addr, topics);
info!(%subscriber_addr, ?topics, "subscribed to new topics");
let mut subscriptions = self.subscriptions.write().unwrap();
for topic in topics {
subscriptions
@@ -226,18 +224,18 @@ impl Publisher {
}
}
fn start_subscriber_gc(
fn start_subscriber_gc<E: Error>(
self,
subscriber_addr: SocketAddr,
client_dispatch: impl Future<Output = anyhow::Result<()>> + Send + 'static,
client_dispatch: impl Future<Output = Result<(), E>> + Send + 'static,
subscriber_ready: oneshot::Receiver<()>,
) {
tokio::spawn(async move {
if let Err(e) = client_dispatch.await {
info!(
"[{}] subscriber connection broken: {:?}",
subscriber_addr, e
)
%subscriber_addr,
error = %e,
"subscriber connection broken");
}
// Don't clean up the subscriber until initialization is done.
let _ = subscriber_ready.await;
@@ -281,13 +279,29 @@ impl publisher::Publisher for Publisher {
}
}
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
let tracer = opentelemetry_jaeger::new_pipeline()
.with_service_name(service_name)
.with_max_packet_size(2usize.pow(13))
.install_batch(opentelemetry::runtime::Tokio)?;
tracing_subscriber::registry()
.with(tracing_subscriber::filter::EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer())
.with(tracing_opentelemetry::layer().with_tracer(tracer))
.try_init()?;
Ok(())
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
env_logger::init();
init_tracing("Pub/Sub")?;
let clients = Arc::new(Mutex::new(HashMap::new()));
let addrs = Publisher {
clients,
clients: Arc::new(Mutex::new(HashMap::new())),
subscriptions: Arc::new(RwLock::new(HashMap::new())),
}
.start()
@@ -309,7 +323,7 @@ async fn main() -> anyhow::Result<()> {
client::Config::default(),
tcp::connect(addrs.publisher, Json::default).await?,
)
.spawn()?;
.spawn();
publisher
.publish(context::current(), "calculus".into(), "sqrt(2)".into())
@@ -337,6 +351,7 @@ async fn main() -> anyhow::Result<()> {
)
.await?;
opentelemetry::global::shutdown_tracer_provider();
info!("done.");
Ok(())

View File

@@ -5,7 +5,6 @@
// https://opensource.org/licenses/MIT.
use futures::future::{self, Ready};
use std::io;
use tarpc::{
client, context,
server::{self, Channel},
@@ -30,12 +29,12 @@ impl World for HelloServer {
type HelloFut = Ready<String>;
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
future::ready(format!("Hello, {name}!"))
}
}
#[tokio::main]
async fn main() -> io::Result<()> {
async fn main() -> anyhow::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);
@@ -43,14 +42,14 @@ async fn main() -> io::Result<()> {
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
// that takes a config and any Transport as input.
let client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
let client = WorldClient::new(client::Config::default(), client_transport).spawn();
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = client.hello(context::current(), "Stim".to_string()).await?;
println!("{}", hello);
println!("{hello}");
Ok(())
}

View File

@@ -0,0 +1,152 @@
use rustls_pemfile::certs;
use std::io::{BufReader, Cursor};
use std::net::{IpAddr, Ipv4Addr};
use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{webpki, TlsAcceptor, TlsConnector};
use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
#[tarpc::service]
pub trait PingService {
async fn ping() -> String;
}
#[derive(Clone)]
struct Service;
#[tarpc::server]
impl PingService for Service {
async fn ping(self, _: Context) -> String {
"🔒".to_owned()
}
}
// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
// used on client-side for server tls
const END_CHAIN: &[u8] = include_bytes!("certs/eddsa/end.chain");
// used on client-side for client-auth
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
// used on server-side for server tls
const END_CERT: &str = include_str!("certs/eddsa/end.cert");
const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
// used on server-side for client-auth
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
pub fn load_private_key(key: &str) -> rustls::PrivateKey {
let mut reader = BufReader::new(Cursor::new(key));
loop {
match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key),
Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key),
Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key),
None => break,
_ => {}
}
}
panic!("no keys found in {:?} (encrypted keys not supported)", key);
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// -------------------- start here to setup tls tcp tokio stream --------------------------
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
let cert = certs(&mut BufReader::new(Cursor::new(END_CERT)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let key = load_private_key(END_PRIVATEKEY);
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
// ------------- server side client_auth cert loading start
let roots: Vec<Certificate> = certs(&mut BufReader::new(Cursor::new(CLIENT_CHAIN_CLIENT_AUTH)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let mut client_auth_roots = RootCertStore::empty();
for root in roots {
client_auth_roots.add(&root).unwrap();
}
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
// ------------- server side client_auth cert loading end
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth
.with_single_cert(cert, key)
.unwrap();
let acceptor = TlsAcceptor::from(Arc::new(config));
let listener = TcpListener::bind(&server_addr).await.unwrap();
let codec_builder = LengthDelimitedCodec::builder();
// ref ./custom_transport.rs server side
tokio::spawn(async move {
loop {
let (stream, _peer_addr) = listener.accept().await.unwrap();
let acceptor = acceptor.clone();
let tls_stream = acceptor.accept(stream).await.unwrap();
let framed = codec_builder.new_framed(tls_stream);
let transport = transport::new(framed, Bincode::default());
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
tokio::spawn(fut);
}
});
// ---------------------- client connection ---------------------
// cert loading from: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/tests/test.rs#L113
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
let chain = certs(&mut std::io::Cursor::new(END_CHAIN)).unwrap();
let mut root_store = rustls::RootCertStore::empty();
root_store.add_server_trust_anchors(chain.iter().map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
let client_auth_certs: Vec<Certificate> =
certs(&mut BufReader::new(Cursor::new(CLIENT_CERT_CLIENT_AUTH)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_single_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth
let domain = rustls::ServerName::try_from("localhost")?;
let connector = TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(server_addr).await?;
let stream = connector.connect(domain, stream).await?;
let transport = transport::new(codec_builder.new_framed(stream), Bincode::default());
let answer = PingServiceClient::new(Default::default(), transport)
.spawn()
.ping(tarpc::context::current())
.await?;
println!("ping answer: {answer}");
Ok(())
}

View File

@@ -6,12 +6,12 @@
use crate::{add::Add as AddService, double::Double as DoubleService};
use futures::{future, prelude::*};
use std::io;
use tarpc::{
client, context,
server::{BaseChannel, Incoming},
server::{incoming::Incoming, BaseChannel},
tokio_serde::formats::Json,
};
use tokio_serde::formats::Json;
use tracing_subscriber::prelude::*;
pub mod add {
#[tarpc::service]
@@ -54,9 +54,25 @@ impl DoubleService for DoubleServer {
}
}
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
let tracer = opentelemetry_jaeger::new_pipeline()
.with_service_name(service_name)
.with_auto_split_batch(true)
.with_max_packet_size(2usize.pow(13))
.install_batch(opentelemetry::runtime::Tokio)?;
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer())
.with(tracing_opentelemetry::layer().with_tracer(tracer))
.try_init()?;
Ok(())
}
#[tokio::main]
async fn main() -> io::Result<()> {
env_logger::init();
async fn main() -> anyhow::Result<()> {
init_tracing("tarpc_tracing_example")?;
let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
@@ -69,7 +85,7 @@ async fn main() -> io::Result<()> {
tokio::spawn(add_server);
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?;
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn();
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
@@ -83,10 +99,14 @@ async fn main() -> io::Result<()> {
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
let double_client =
double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;
double::DoubleClient::new(client::Config::default(), to_double_server).spawn();
for i in 1..=5 {
eprintln!("{:?}", double_client.double(context::current(), i).await?);
let ctx = context::current();
for _ in 1..=5 {
tracing::info!("{:?}", double_client.double(ctx, 1).await?);
}
opentelemetry::global::shutdown_tracer_provider();
Ok(())
}

View File

@@ -0,0 +1,49 @@
use futures::{prelude::*, task::*};
use std::pin::Pin;
use tokio::sync::mpsc;
/// Sends request cancellation signals.
#[derive(Debug, Clone)]
pub struct RequestCancellation(mpsc::UnboundedSender<u64>);
/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
pub struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
pub fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests.
let (tx, rx) = mpsc::unbounded_channel();
(RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
/// Cancels the request with ID `request_id`.
///
/// No validation is done of `request_id`. There is no way to know if the request id provided
/// corresponds to a request actually tracked by the backing channel. `RequestCancellation` is
/// a one-way communication channel.
///
/// Once request data is cleaned up, a response will never be received by the client. This is
/// useful primarily when request processing ends prematurely for requests with long deadlines
/// which would otherwise continue to be tracked by the backing channel—a kind of leak.
pub fn cancel(&self, request_id: u64) {
let _ = self.0.send(request_id);
}
}
impl CanceledRequests {
/// Polls for a cancelled request.
pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.0.poll_recv(cx)
}
}
impl Stream for CanceledRequests {
type Item = u64;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.poll_recv(cx)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,18 +1,15 @@
use crate::{
context,
util::{Compact, TimeUntil},
PollIo, Response, ServerError,
};
use fnv::FnvHashMap;
use futures::ready;
use log::{debug, trace};
use std::{
collections::hash_map,
io,
task::{Context, Poll},
};
use tokio::sync::oneshot;
use tokio_util::time::delay_queue::{self, DelayQueue};
use tracing::Span;
/// Requests already written to the wire that haven't yet received responses.
#[derive(Debug)]
@@ -31,9 +28,10 @@ impl<Resp> Default for InFlightRequests<Resp> {
}
#[derive(Debug)]
struct RequestData<Resp> {
struct RequestData<Res> {
ctx: context::Context,
response_completion: oneshot::Sender<Response<Resp>>,
span: Span,
response_completion: oneshot::Sender<Res>,
/// The key to remove the timer for the request's deadline.
deadline_key: delay_queue::Key,
}
@@ -43,7 +41,7 @@ struct RequestData<Resp> {
#[derive(Debug)]
pub struct AlreadyExistsError;
impl<Resp> InFlightRequests<Resp> {
impl<Res> InFlightRequests<Res> {
/// Returns the number of in-flight requests.
pub fn len(&self) -> usize {
self.request_data.len()
@@ -59,20 +57,16 @@ impl<Resp> InFlightRequests<Resp> {
&mut self,
request_id: u64,
ctx: context::Context,
response_completion: oneshot::Sender<Response<Resp>>,
span: Span,
response_completion: oneshot::Sender<Res>,
) -> Result<(), AlreadyExistsError> {
match self.request_data.entry(request_id) {
hash_map::Entry::Vacant(vacant) => {
let timeout = ctx.deadline.time_until();
trace!(
"[{}] Queuing request with timeout {:?}.",
ctx.trace_id(),
timeout,
);
let deadline_key = self.deadlines.insert(request_id, timeout);
vacant.insert(RequestData {
ctx,
span,
response_completion,
deadline_key,
});
@@ -83,33 +77,42 @@ impl<Resp> InFlightRequests<Resp> {
}
/// Removes a request without aborting. Returns true iff the request was found.
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
if let Some(request_data) = self.request_data.remove(&response.request_id) {
pub fn complete_request(&mut self, request_id: u64, result: Res) -> bool {
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::info!("ReceiveResponse");
self.request_data.compact(0.1);
trace!("[{}] Received response.", request_data.ctx.trace_id());
self.deadlines.remove(&request_data.deadline_key);
request_data.complete(response);
let _ = request_data.response_completion.send(result);
return true;
}
debug!(
"No in-flight request found for request_id = {}.",
response.request_id
);
tracing::debug!("No in-flight request found for request_id = {request_id}.");
// If the response completion was absent, then the request was already canceled.
false
}
/// Completes all requests using the provided function.
/// Returns Spans for all completes requests.
pub fn complete_all_requests<'a>(
&'a mut self,
mut result: impl FnMut() -> Res + 'a,
) -> impl Iterator<Item = Span> + 'a {
self.deadlines.clear();
self.request_data.drain().map(move |(_, request_data)| {
let _ = request_data.response_completion.send(result());
request_data.span
})
}
/// Cancels a request without completing (typically used when a request handle was dropped
/// before the request completed).
pub fn cancel_request(&mut self, request_id: u64) -> Option<context::Context> {
pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
if let Some(request_data) = self.request_data.remove(&request_id) {
self.request_data.compact(0.1);
trace!("[{}] Cancelling request.", request_data.ctx.trace_id());
self.deadlines.remove(&request_data.deadline_key);
Some(request_data.ctx)
Some((request_data.ctx, request_data.span))
} else {
None
}
@@ -117,46 +120,20 @@ impl<Resp> InFlightRequests<Resp> {
/// Yields a request that has expired, completing it with a TimedOut error.
/// The caller should send cancellation messages for any yielded request ID.
pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
Some(Ok(expired)) => {
let request_id = expired.into_inner();
if let Some(request_data) = self.request_data.remove(&request_id) {
self.request_data.compact(0.1);
request_data.complete(Self::deadline_exceeded_error(request_id));
}
Some(Ok(request_id))
pub fn poll_expired(
&mut self,
cx: &mut Context,
expired_error: impl Fn() -> Res,
) -> Poll<Option<u64>> {
self.deadlines.poll_expired(cx).map(|expired| {
let request_id = expired?.into_inner();
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::error!("DeadlineExceeded");
self.request_data.compact(0.1);
let _ = request_data.response_completion.send(expired_error());
}
Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
None => None,
Some(request_id)
})
}
fn deadline_exceeded_error(request_id: u64) -> Response<Resp> {
Response {
request_id,
message: Err(ServerError {
kind: io::ErrorKind::TimedOut,
detail: Some("Client dropped expired request.".to_string()),
}),
}
}
}
/// When InFlightRequests is dropped, any outstanding requests are completed with a
/// deadline-exceeded error.
impl<Resp> Drop for InFlightRequests<Resp> {
fn drop(&mut self) {
let deadlines = &mut self.deadlines;
for (_, request_data) in self.request_data.drain() {
let expired = deadlines.remove(&request_data.deadline_key);
request_data.complete(Self::deadline_exceeded_error(expired.into_inner()));
}
}
}
impl<Resp> RequestData<Resp> {
fn complete(self, response: Response<Resp>) {
let _ = self.response_completion.send(response);
}
}

View File

@@ -8,8 +8,13 @@
//! client to server and is used by the server to enforce response deadlines.
use crate::trace::{self, TraceId};
use opentelemetry::trace::TraceContextExt;
use static_assertions::assert_impl_all;
use std::time::{Duration, SystemTime};
use std::{
convert::TryFrom,
time::{Duration, SystemTime},
};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// A request context that carries request-scoped information like deadlines and trace information.
/// It is sent from client to server and is used by the server to enforce response deadlines.
@@ -22,15 +27,9 @@ use std::time::{Duration, SystemTime};
pub struct Context {
/// When the client expects the request to be complete by. The server should cancel the request
/// if it is not complete by this time.
#[cfg_attr(
feature = "serde1",
serde(serialize_with = "crate::util::serde::serialize_epoch_secs")
)]
#[cfg_attr(
feature = "serde1",
serde(deserialize_with = "crate::util::serde::deserialize_epoch_secs")
)]
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
// Serialized as a Duration to prevent clock skew issues.
#[cfg_attr(feature = "serde1", serde(with = "absolute_to_relative_time"))]
pub deadline: SystemTime,
/// Uniquely identifies requests originating from the same source.
/// When a service handles a request by making requests itself, those requests should
@@ -39,25 +38,115 @@ pub struct Context {
pub trace_context: trace::Context,
}
#[cfg(feature = "serde1")]
mod absolute_to_relative_time {
pub use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub use std::time::{Duration, SystemTime};
pub fn serialize<S>(deadline: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let deadline = deadline
.duration_since(SystemTime::now())
.unwrap_or(Duration::ZERO);
deadline.serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
where
D: Deserializer<'de>,
{
let deadline = Duration::deserialize(deserializer)?;
Ok(SystemTime::now() + deadline)
}
#[cfg(test)]
#[derive(serde::Serialize, serde::Deserialize)]
struct AbsoluteToRelative(#[serde(with = "self")] SystemTime);
#[test]
fn test_serialize() {
let now = SystemTime::now();
let deadline = now + Duration::from_secs(10);
let serialized_deadline = bincode::serialize(&AbsoluteToRelative(deadline)).unwrap();
let deserialized_deadline: Duration = bincode::deserialize(&serialized_deadline).unwrap();
// TODO: how to avoid flakiness?
assert!(deserialized_deadline > Duration::from_secs(9));
}
#[test]
fn test_deserialize() {
let deadline = Duration::from_secs(10);
let serialized_deadline = bincode::serialize(&deadline).unwrap();
let AbsoluteToRelative(deserialized_deadline) =
bincode::deserialize(&serialized_deadline).unwrap();
// TODO: how to avoid flakiness?
assert!(deserialized_deadline > SystemTime::now() + Duration::from_secs(9));
}
}
assert_impl_all!(Context: Send, Sync);
#[cfg(feature = "serde1")]
fn ten_seconds_from_now() -> SystemTime {
SystemTime::now() + Duration::from_secs(10)
}
/// Returns the context for the current request, or a default Context if no request is active.
// TODO: populate Context with request-scoped data, with default fallbacks.
pub fn current() -> Context {
Context {
deadline: SystemTime::now() + Duration::from_secs(10),
trace_context: trace::Context::new_root(),
Context::current()
}
#[derive(Clone)]
struct Deadline(SystemTime);
impl Default for Deadline {
fn default() -> Self {
Self(ten_seconds_from_now())
}
}
impl Context {
/// Returns the context for the current request, or a default Context if no request is active.
pub fn current() -> Self {
let span = tracing::Span::current();
Self {
trace_context: trace::Context::try_from(&span)
.unwrap_or_else(|_| trace::Context::default()),
deadline: span
.context()
.get::<Deadline>()
.cloned()
.unwrap_or_default()
.0,
}
}
/// Returns the ID of the request-scoped trace.
pub fn trace_id(&self) -> &TraceId {
&self.trace_context.trace_id
}
}
/// An extension trait for [`tracing::Span`] for propagating tarpc Contexts.
pub(crate) trait SpanExt {
/// Sets the given context on this span. Newly-created spans will be children of the given
/// context's trace context.
fn set_context(&self, context: &Context);
}
impl SpanExt for tracing::Span {
fn set_context(&self, context: &Context) {
self.set_parent(
opentelemetry::Context::new()
.with_remote_span_context(opentelemetry::trace::SpanContext::new(
opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
opentelemetry::trace::SpanId::from(context.trace_context.span_id),
opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision),
true,
opentelemetry::trace::TraceState::default(),
))
.with_value(Deadline(context.deadline)),
);
}
}

View File

@@ -27,7 +27,7 @@
//! process, and no context switching between different languages.
//!
//! Some other features of tarpc:
//! - Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
//! - Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
//! used as a transport to connect the client and server.
//! - `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
//! - Cascading cancellation: dropping a request will send a cancellation message to the server.
@@ -38,6 +38,14 @@
//! requests sent by the server that use the request context will propagate the request deadline.
//! For example, if a server is handling a request with a 10s deadline, does 2s of work, then
//! sends a request to another server, that server will see an 8s deadline.
//! - Distributed tracing: tarpc is instrumented with
//! [tracing](https://github.com/tokio-rs/tracing) primitives extended with
//! [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
//! [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
//! each RPC can be traced through the client, server, and other dependencies downstream of the
//! server. Even for applications not connected to a distributed tracing collector, the
//! instrumentation can also be ingested by regular loggers like
//! [env_logger](https://github.com/env-logger-rs/env_logger/).
//! - Serde serialization: enabling the `serde1` Cargo feature will make service requests and
//! responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
//! be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
@@ -46,7 +54,7 @@
//! Add to your `Cargo.toml` dependencies:
//!
//! ```toml
//! tarpc = "0.25"
//! tarpc = "0.29"
//! ```
//!
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
@@ -59,8 +67,9 @@
//! your `Cargo.toml`:
//!
//! ```toml
//! futures = "1.0"
//! tarpc = { version = "0.25", features = ["tokio1"] }
//! anyhow = "1.0"
//! futures = "0.3"
//! tarpc = { version = "0.29", features = ["tokio1"] }
//! tokio = { version = "1.0", features = ["macros"] }
//! ```
//!
@@ -79,9 +88,8 @@
//! };
//! use tarpc::{
//! client, context,
//! server::{self, Incoming},
//! server::{self, incoming::Incoming, Channel},
//! };
//! use std::io;
//!
//! // This is the service definition. It looks a lot like a trait definition.
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
@@ -103,9 +111,8 @@
//! # };
//! # use tarpc::{
//! # client, context,
//! # server::{self, Incoming},
//! # server::{self, incoming::Incoming},
//! # };
//! # use std::io;
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! # #[tarpc::service]
@@ -125,7 +132,7 @@
//! type HelloFut = Ready<String>;
//!
//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
//! future::ready(format!("Hello, {}!", name))
//! future::ready(format!("Hello, {name}!"))
//! }
//! }
//! ```
@@ -145,7 +152,6 @@
//! # client, context,
//! # server::{self, Channel},
//! # };
//! # use std::io;
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! # #[tarpc::service]
@@ -162,14 +168,14 @@
//! # // an associated type representing the future output by the fn.
//! # type HelloFut = Ready<String>;
//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
//! # future::ready(format!("Hello, {}!", name))
//! # future::ready(format!("Hello, {name}!"))
//! # }
//! # }
//! # #[cfg(not(feature = "tokio1"))]
//! # fn main() {}
//! # #[cfg(feature = "tokio1")]
//! #[tokio::main]
//! async fn main() -> io::Result<()> {
//! async fn main() -> anyhow::Result<()> {
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
//!
//! let server = server::BaseChannel::with_defaults(server_transport);
@@ -177,14 +183,14 @@
//!
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
//! // that takes a config and any Transport as input.
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
//!
//! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
//! // args as defined, with the addition of a Context, which is always the first arg. The Context
//! // specifies a deadline and trace information which can be helpful in debugging requests.
//! let hello = client.hello(context::current(), "Stim".to_string()).await?;
//!
//! println!("{}", hello);
//! println!("{hello}");
//!
//! Ok(())
//! }
@@ -199,10 +205,11 @@
#![cfg_attr(docsrs, feature(doc_cfg))]
#[cfg(feature = "serde1")]
#[doc(hidden)]
pub use serde;
#[cfg(feature = "serde-transport")]
pub use tokio_serde;
pub use {tokio_serde, tokio_util};
#[cfg(feature = "serde-transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-transport")))]
@@ -257,7 +264,7 @@ pub use tarpc_plugins::service;
/// #[tarpc::server]
/// impl World for HelloServer {
/// async fn hello(self, _: context::Context, name: String) -> String {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
/// }
/// }
/// ```
@@ -283,7 +290,7 @@ pub use tarpc_plugins::service;
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
/// + Send>> {
/// Box::pin(async move {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
/// })
/// }
/// }
@@ -293,6 +300,7 @@ pub use tarpc_plugins::service;
/// `async`, meaning that this should not break existing code.
pub use tarpc_plugins::server;
pub(crate) mod cancellations;
pub mod client;
pub mod context;
pub mod server;
@@ -303,7 +311,8 @@ pub use crate::transport::sealed::Transport;
use anyhow::Context as _;
use futures::task::*;
use std::{fmt::Display, io, time::SystemTime};
use std::sync::Arc;
use std::{error::Error, fmt::Display, io, time::SystemTime};
/// A message from a client to a server.
#[derive(Debug)]
@@ -355,8 +364,9 @@ pub struct Response<T> {
pub message: Result<T, ServerError>,
}
/// An error response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
/// An error indicating the server aborted the request early, e.g., due to request throttling.
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
#[error("{kind:?}: {detail}")]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct ServerError {
@@ -371,13 +381,30 @@ pub struct ServerError {
/// The type of error that occurred to fail the request.
pub kind: io::ErrorKind,
/// A message describing more detail about the error that occurred.
pub detail: Option<String>,
pub detail: String,
}
impl From<ServerError> for io::Error {
fn from(e: ServerError) -> io::Error {
io::Error::new(e.kind, e.detail.unwrap_or_default())
}
/// Critical errors that result in a Channel disconnecting.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// Could not read from the transport.
#[error("could not read from the transport")]
Read(#[source] Arc<E>),
/// Could not ready the transport for writes.
#[error("could not ready the transport for writes")]
Ready(#[source] E),
/// Could not write to the transport.
#[error("could not write to the transport")]
Write(#[source] E),
/// Could not flush the transport.
#[error("could not flush the transport")]
Flush(#[source] E),
/// Could not close the write end of the transport.
#[error("could not close the write end of the transport")]
Close(#[source] E),
}
impl<T> Request<T> {
@@ -387,7 +414,6 @@ impl<T> Request<T> {
}
}
pub(crate) type PollIo<T> = Poll<Option<io::Result<T>>>;
pub(crate) trait PollContext<T> {
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
where
@@ -399,7 +425,10 @@ pub(crate) trait PollContext<T> {
F: FnOnce() -> C;
}
impl<T> PollContext<T> for PollIo<T> {
impl<T, E> PollContext<T> for Poll<Option<Result<T, E>>>
where
E: Error + Send + Sync + 'static,
{
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
where
C: Display + Send + Sync + 'static,

View File

@@ -42,14 +42,10 @@ where
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
match self.project().inner.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok::<_, CodecError>(next))) => Poll::Ready(Some(Ok(next))),
Poll::Ready(Some(Err::<_, CodecError>(e))) => {
Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
}
}
self.project()
.inner
.poll_next(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}
@@ -65,7 +61,10 @@ where
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_ready(cx))
self.project()
.inner
.poll_ready(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
@@ -76,20 +75,20 @@ where
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_flush(cx))
self.project()
.inner
.poll_flush(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.project().inner.poll_close(cx))
self.project()
.inner
.poll_close(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}
fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
poll: Poll<Result<(), E>>,
) -> Poll<io::Result<()>> {
poll.map(|ready| ready.map_err(|e| io::Error::new(io::ErrorKind::Other, e)))
}
/// Constructs a new transport from a framed transport and a serialization codec.
pub fn new<S, Item, SinkItem, Codec>(
framed_io: Framed<S, LengthDelimitedCodec>,
@@ -130,14 +129,6 @@ pub mod tcp {
tokio_util::codec::length_delimited,
};
mod private {
use super::*;
pub trait Sealed {}
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
}
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
/// Returns the peer address of the underlying TcpStream.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
@@ -150,6 +141,7 @@ pub mod tcp {
}
/// A connection Future that also exposes the length-delimited framing config.
#[must_use]
#[pin_project]
pub struct Connect<T, Item, SinkItem, CodecFn> {
#[pin]
@@ -277,6 +269,270 @@ pub mod tcp {
}
}
#[cfg(all(unix, feature = "unix"))]
#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "unix"))))]
/// Unix Domain Socket support for generic transport using Tokio.
pub mod unix {
use {
super::*,
futures::ready,
std::{marker::PhantomData, path::Path},
tokio::net::{unix::SocketAddr, UnixListener, UnixStream},
tokio_util::codec::length_delimited,
};
impl<Item, SinkItem, Codec> Transport<UnixStream, Item, SinkItem, Codec> {
/// Returns the socket address of the remote half of the underlying [`UnixStream`].
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().peer_addr()
}
/// Returns the socket address of the local half of the underlying [`UnixStream`].
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().local_addr()
}
}
/// A connection Future that also exposes the length-delimited framing config.
#[must_use]
#[pin_project]
pub struct Connect<T, Item, SinkItem, CodecFn> {
#[pin]
inner: T,
codec_fn: CodecFn,
config: length_delimited::Builder,
ghost: PhantomData<(fn(SinkItem), fn() -> Item)>,
}
impl<T, Item, SinkItem, Codec, CodecFn> Future for Connect<T, Item, SinkItem, CodecFn>
where
T: Future<Output = io::Result<UnixStream>>,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
type Output = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let io = ready!(self.as_mut().project().inner.poll(cx))?;
Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)())))
}
}
impl<T, Item, SinkItem, CodecFn> Connect<T, Item, SinkItem, CodecFn> {
/// Returns an immutable reference to the length-delimited codec's config.
pub fn config(&self) -> &length_delimited::Builder {
&self.config
}
/// Returns a mutable reference to the length-delimited codec's config.
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
&mut self.config
}
}
/// Connects to socket named by `path`, wrapping the connection in a Unix Domain Socket
/// transport.
pub fn connect<P, Item, SinkItem, Codec, CodecFn>(
path: P,
codec_fn: CodecFn,
) -> Connect<impl Future<Output = io::Result<UnixStream>>, Item, SinkItem, CodecFn>
where
P: AsRef<Path>,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
Connect {
inner: UnixStream::connect(path),
codec_fn,
config: LengthDelimitedCodec::builder(),
ghost: PhantomData,
}
}
/// Listens on the socket named by `path`, wrapping accepted connections in Unix Domain Socket
/// transports.
pub async fn listen<P, Item, SinkItem, Codec, CodecFn>(
path: P,
codec_fn: CodecFn,
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
where
P: AsRef<Path>,
Item: for<'de> Deserialize<'de>,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
let listener = UnixListener::bind(path)?;
let local_addr = listener.local_addr()?;
Ok(Incoming {
listener,
codec_fn,
local_addr,
config: LengthDelimitedCodec::builder(),
ghost: PhantomData,
})
}
/// A [`UnixListener`] that wraps connections in [transports](Transport).
#[pin_project]
#[derive(Debug)]
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
listener: UnixListener,
local_addr: SocketAddr,
codec_fn: CodecFn,
config: length_delimited::Builder,
ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>,
}
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
/// Returns the the socket address being listened on.
pub fn local_addr(&self) -> &SocketAddr {
&self.local_addr
}
/// Returns an immutable reference to the length-delimited codec's config.
pub fn config(&self) -> &length_delimited::Builder {
&self.config
}
/// Returns a mutable reference to the length-delimited codec's config.
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
&mut self.config
}
}
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
type Item = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let conn: UnixStream = ready!(self.as_mut().project().listener.poll_accept(cx)?).0;
Poll::Ready(Some(Ok(new(
self.config.new_framed(conn),
(self.codec_fn)(),
))))
}
}
/// A temporary `PathBuf` that lives in `std::env::temp_dir` and is removed on drop.
pub struct TempPathBuf(std::path::PathBuf);
impl TempPathBuf {
/// A named socket that results in `<tempdir>/<name>`
pub fn new<S: AsRef<str>>(name: S) -> Self {
let mut sock = std::env::temp_dir();
sock.push(name.as_ref());
Self(sock)
}
/// Appends a random hex string to the socket name resulting in
/// `<tempdir>/<name>_<xxxxx>`
pub fn with_random<S: AsRef<str>>(name: S) -> Self {
Self::new(format!("{}_{:016x}", name.as_ref(), rand::random::<u64>()))
}
}
impl AsRef<std::path::Path> for TempPathBuf {
fn as_ref(&self) -> &std::path::Path {
self.0.as_path()
}
}
impl Drop for TempPathBuf {
fn drop(&mut self) {
// This will remove the file pointed to by this PathBuf if it exists, however Err's can
// be returned such as attempting to remove a non-existing file, or one which we don't
// have permission to remove. In these cases the Err is swallowed
let _ = std::fs::remove_file(&self.0);
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio_serde::formats::SymmetricalJson;
#[test]
fn temp_path_buf_non_random() {
let sock = TempPathBuf::new("test");
let mut good = std::env::temp_dir();
good.push("test");
assert_eq!(sock.as_ref(), good);
assert_eq!(sock.as_ref().file_name().unwrap(), "test");
}
#[test]
fn temp_path_buf_random() {
let sock = TempPathBuf::with_random("test");
let good = std::env::temp_dir();
assert!(sock.as_ref().starts_with(good));
// Since there are 16 random characters we just assert the file_name has the right name
// and starts with the correct string 'test_'
// file name: test_xxxxxxxxxxxxxxxx
// test = 4
// _ = 1
// <hex> = 16
// total = 21
let fname = sock.as_ref().file_name().unwrap().to_string_lossy();
assert!(fname.starts_with("test_"));
assert_eq!(fname.len(), 21);
}
#[test]
fn temp_path_buf_non_existing() {
let sock = TempPathBuf::with_random("test");
let sock_path = std::path::PathBuf::from(sock.as_ref());
// No actual file has been created yet
assert!(!sock_path.exists());
// Should not panic
std::mem::drop(sock);
assert!(!sock_path.exists());
}
#[test]
fn temp_path_buf_existing_file() {
let sock = TempPathBuf::with_random("test");
let sock_path = std::path::PathBuf::from(sock.as_ref());
let _file = std::fs::File::create(&sock).unwrap();
assert!(sock_path.exists());
std::mem::drop(sock);
assert!(!sock_path.exists());
}
#[test]
fn temp_path_buf_preexisting_file() {
let mut pre_existing = std::env::temp_dir();
pre_existing.push("test");
let _file = std::fs::File::create(&pre_existing).unwrap();
let sock = TempPathBuf::new("test");
let sock_path = std::path::PathBuf::from(sock.as_ref());
assert!(sock_path.exists());
std::mem::drop(sock);
assert!(!sock_path.exists());
}
#[tokio::test]
async fn temp_path_buf_for_socket() {
let sock = TempPathBuf::with_random("test");
// Save path for testing after drop
let sock_path = std::path::PathBuf::from(sock.as_ref());
// create the actual socket
let _ = listen(&sock, SymmetricalJson::<String>::default).await;
assert!(sock_path.exists());
std::mem::drop(sock);
assert!(!sock_path.exists());
}
}
}
#[cfg(test)]
mod tests {
use super::Transport;
@@ -291,7 +547,7 @@ mod tests {
use tokio_serde::formats::SymmetricalJson;
fn ctx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
Context::from_waker(noop_waker_ref())
}
struct TestIo(Cursor<Vec<u8>>);
@@ -393,4 +649,24 @@ mod tests {
assert_matches!(transport.next().await, None);
Ok(())
}
#[cfg(all(unix, feature = "unix"))]
#[tokio::test]
async fn uds() -> io::Result<()> {
use super::unix;
use super::*;
let sock = unix::TempPathBuf::with_random("uds");
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
tokio::spawn(async move {
let mut transport = listener.next().await.unwrap().unwrap();
let message = transport.next().await.unwrap().unwrap();
transport.send(message).await.unwrap();
});
let mut transport = unix::connect(&sock, SymmetricalJson::<String>::default).await?;
transport.send(String::from("test")).await?;
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
assert_matches!(transport.next().await, None);
Ok(())
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,19 +1,13 @@
use crate::{
util::{Compact, TimeUntil},
PollIo,
};
use crate::util::{Compact, TimeUntil};
use fnv::FnvHashMap;
use futures::{
future::{AbortHandle, AbortRegistration},
ready,
};
use futures::future::{AbortHandle, AbortRegistration};
use std::{
collections::hash_map,
io,
task::{Context, Poll},
time::SystemTime,
};
use tokio_util::time::delay_queue::{self, DelayQueue};
use tracing::Span;
/// A data structure that tracks in-flight requests. It aborts requests,
/// either on demand or when a request deadline expires.
@@ -23,13 +17,15 @@ pub struct InFlightRequests {
deadlines: DelayQueue<u64>,
}
#[derive(Debug)]
/// Data needed to clean up a single in-flight request.
#[derive(Debug)]
struct RequestData {
/// Aborts the response handler for the associated request.
abort_handle: AbortHandle,
/// The key to remove the timer for the request's deadline.
deadline_key: delay_queue::Key,
/// The client span.
span: Span,
}
/// An error returned when a request attempted to start with the same ID as a request already
@@ -48,6 +44,7 @@ impl InFlightRequests {
&mut self,
request_id: u64,
deadline: SystemTime,
span: Span,
) -> Result<AbortRegistration, AlreadyExistsError> {
match self.request_data.entry(request_id) {
hash_map::Entry::Vacant(vacant) => {
@@ -57,6 +54,7 @@ impl InFlightRequests {
vacant.insert(RequestData {
abort_handle,
deadline_key,
span,
});
Ok(abort_registration)
}
@@ -66,12 +64,17 @@ impl InFlightRequests {
/// Cancels an in-flight request. Returns true iff the request was found.
pub fn cancel_request(&mut self, request_id: u64) -> bool {
if let Some(request_data) = self.request_data.remove(&request_id) {
if let Some(RequestData {
span,
abort_handle,
deadline_key,
}) = self.request_data.remove(&request_id)
{
let _entered = span.enter();
self.request_data.compact(0.1);
request_data.abort_handle.abort();
self.deadlines.remove(&request_data.deadline_key);
abort_handle.abort();
self.deadlines.remove(&deadline_key);
tracing::info!("ReceiveCancel");
true
} else {
false
@@ -80,30 +83,35 @@ impl InFlightRequests {
/// Removes a request without aborting. Returns true iff the request was found.
/// This method should be used when a response is being sent.
pub fn remove_request(&mut self, request_id: u64) -> bool {
pub fn remove_request(&mut self, request_id: u64) -> Option<Span> {
if let Some(request_data) = self.request_data.remove(&request_id) {
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
true
Some(request_data.span)
} else {
false
None
}
}
/// Yields a request that has expired, aborting any ongoing processing of that request.
pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
Some(Ok(expired)) => {
if let Some(request_data) = self.request_data.remove(expired.get_ref()) {
self.request_data.compact(0.1);
request_data.abort_handle.abort();
}
Some(Ok(expired.into_inner()))
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
if self.deadlines.is_empty() {
// TODO(https://github.com/tokio-rs/tokio/issues/4161)
// This is a workaround for DelayQueue not always treating this case correctly.
return Poll::Ready(None);
}
self.deadlines.poll_expired(cx).map(|expired| {
let expired = expired?;
if let Some(RequestData {
abort_handle, span, ..
}) = self.request_data.remove(expired.get_ref())
{
let _entered = span.enter();
self.request_data.compact(0.1);
abort_handle.abort();
tracing::error!("DeadlineExceeded");
}
Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
None => None,
Some(expired.into_inner())
})
}
}
@@ -118,75 +126,96 @@ impl Drop for InFlightRequests {
}
#[cfg(test)]
use {
assert_matches::assert_matches,
futures::{
mod tests {
use super::*;
use assert_matches::assert_matches;
use futures::{
future::{pending, Abortable},
FutureExt,
},
futures_test::task::noop_context,
};
};
use futures_test::task::noop_context;
#[tokio::test]
async fn start_request_increases_len() {
let mut in_flight_requests = InFlightRequests::default();
assert_eq!(in_flight_requests.len(), 0);
in_flight_requests
.start_request(0, SystemTime::now())
.unwrap();
assert_eq!(in_flight_requests.len(), 1);
}
#[tokio::test]
async fn polling_expired_aborts() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
tokio::time::pause();
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Ready(Some(Ok(_)))
);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))
);
assert_eq!(in_flight_requests.len(), 0);
}
#[tokio::test]
async fn cancel_request_aborts() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
assert_eq!(in_flight_requests.cancel_request(0), true);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))
);
assert_eq!(in_flight_requests.len(), 0);
}
#[tokio::test]
async fn remove_request_doesnt_abort() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
assert_eq!(in_flight_requests.remove_request(0), true);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Pending
);
assert_eq!(in_flight_requests.len(), 0);
#[tokio::test]
async fn start_request_increases_len() {
let mut in_flight_requests = InFlightRequests::default();
assert_eq!(in_flight_requests.len(), 0);
in_flight_requests
.start_request(0, SystemTime::now(), Span::current())
.unwrap();
assert_eq!(in_flight_requests.len(), 1);
}
#[tokio::test]
async fn polling_expired_aborts() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now(), Span::current())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
tokio::time::pause();
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Ready(Some(_))
);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))
);
assert_eq!(in_flight_requests.len(), 0);
}
#[tokio::test]
async fn cancel_request_aborts() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now(), Span::current())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
assert!(in_flight_requests.cancel_request(0));
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))
);
assert_eq!(in_flight_requests.len(), 0);
}
#[tokio::test]
async fn remove_request_doesnt_abort() {
let mut in_flight_requests = InFlightRequests::default();
assert!(in_flight_requests.deadlines.is_empty());
let abort_registration = in_flight_requests
.start_request(
0,
SystemTime::now() + std::time::Duration::from_secs(10),
Span::current(),
)
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
// Precondition: Pending expiration
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Pending
);
assert!(!in_flight_requests.deadlines.is_empty());
assert_matches!(in_flight_requests.remove_request(0), Some(_));
// Postcondition: No pending expirations
assert!(in_flight_requests.deadlines.is_empty());
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Ready(None)
);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Pending
);
assert_eq!(in_flight_requests.len(), 0);
}
}

View File

@@ -0,0 +1,49 @@
use super::{
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
Channel,
};
use futures::prelude::*;
use std::{fmt, hash::Hash};
#[cfg(feature = "tokio1")]
use super::{tokio::TokioServerExecutor, Serve};
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
pub trait Incoming<C>
where
Self: Sized + Stream<Item = C>,
C: Channel,
{
/// Enforces channel per-key limits.
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> MaxChannelsPerKey<Self, K, KF>
where
K: fmt::Display + Eq + Hash + Clone + Unpin,
KF: Fn(&C) -> K,
{
MaxChannelsPerKey::new(self, n, keymaker)
}
/// Caps the number of concurrent requests per channel.
fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel<Self> {
MaxRequestsPerChannel::new(self, n)
}
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
/// concurrently by spawning on tokio's default executor, and each request will be also
/// be spawned on tokio's default executor.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp>,
{
TokioServerExecutor::new(self, serve)
}
}
impl<S, C> Incoming<C> for S
where
S: Sized + Stream<Item = C>,
C: Channel,
{
}

View File

@@ -0,0 +1,5 @@
/// Provides functionality to limit the number of active channels.
pub mod channels_per_key;
/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests.
pub mod requests_per_channel;

View File

@@ -9,20 +9,23 @@ use crate::{
util::Compact,
};
use fnv::FnvHashMap;
use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
use log::{debug, info, trace};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use pin_project::pin_project;
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
time::SystemTime,
};
use tokio::sync::mpsc;
use tracing::{debug, info, trace};
/// A single-threaded filter that drops channels based on per-key limits.
/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
/// per-key limits.
///
/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
/// Channel is yielded, it will not be prematurely dropped.
#[pin_project]
#[derive(Debug)]
pub struct ChannelFilter<S, K, F>
pub struct MaxChannelsPerKey<S, K, F>
where
K: Eq + Hash,
{
@@ -35,7 +38,7 @@ where
keymaker: F,
}
/// A channel that is tracked by a ChannelFilter.
/// A channel that is tracked by [`MaxChannelsPerKey`].
#[pin_project]
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
@@ -103,6 +106,7 @@ where
{
type Req = C::Req;
type Resp = C::Resp;
type Transport = C::Transport;
fn config(&self) -> &server::Config {
self.inner.config()
@@ -112,12 +116,8 @@ where
self.inner.in_flight_requests()
}
fn start_request(
mut self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.inner_pin_mut().start_request(id, deadline)
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}
}
@@ -133,7 +133,7 @@ impl<C, K> TrackedChannel<C, K> {
}
}
impl<S, K, F> ChannelFilter<S, K, F>
impl<S, K, F> MaxChannelsPerKey<S, K, F>
where
K: Eq + Hash,
S: Stream,
@@ -142,7 +142,7 @@ where
/// Sheds new channels to stay under configured limits.
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
ChannelFilter {
MaxChannelsPerKey {
listener: listener.fuse(),
channels_per_key,
dropped_keys,
@@ -153,7 +153,7 @@ where
}
}
impl<S, K, F> ChannelFilter<S, K, F>
impl<S, K, F> MaxChannelsPerKey<S, K, F>
where
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
@@ -171,11 +171,10 @@ where
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
trace!(
"[{}] Opening channel ({}/{}) channels for key.",
key,
Arc::strong_count(&tracker),
self.channels_per_key
);
channel_filter_key = %key,
open_channels = Arc::strong_count(&tracker),
max_open_channels = self.channels_per_key,
"Opening channel");
Ok(TrackedChannel {
tracker,
@@ -200,9 +199,10 @@ where
let count = o.get().strong_count();
if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
info!(
"[{}] Opened max channels from key ({}/{}).",
key, count, self_.channels_per_key
);
channel_filter_key = %key,
open_channels = count,
max_open_channels = *self_.channels_per_key,
"At open channel limit");
Err(key)
} else {
Ok(o.get().upgrade().unwrap_or_else(|| {
@@ -233,7 +233,9 @@ where
let self_ = self.project();
match ready!(self_.dropped_keys.poll_recv(cx)) {
Some(key) => {
debug!("All channels dropped for key [{}]", key);
debug!(
channel_filter_key = %key,
"All channels dropped");
self_.key_counts.remove(&key);
self_.key_counts.compact(0.1);
Poll::Ready(())
@@ -243,7 +245,7 @@ where
}
}
impl<S, K, F> Stream for ChannelFilter<S, K, F>
impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
where
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
@@ -280,7 +282,7 @@ where
fn ctx() -> Context<'static> {
use futures::task::*;
Context::from_waker(&noop_waker_ref())
Context::from_waker(noop_waker_ref())
}
#[test]
@@ -346,7 +348,7 @@ fn channel_filter_increment_channels_for_key() {
key: &'static str,
}
let (_, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 1);
@@ -367,7 +369,7 @@ fn channel_filter_handle_new_channel() {
key: &'static str,
}
let (_, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let channel1 = filter
.as_mut()
@@ -399,7 +401,7 @@ fn channel_filter_poll_listener() {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
@@ -435,7 +437,7 @@ fn channel_filter_poll_closed_channels() {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
@@ -463,7 +465,7 @@ fn channel_filter_stream() {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels

View File

@@ -0,0 +1,349 @@
// Copyright 2020 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
server::{Channel, Config},
Response, ServerError,
};
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use std::{io, pin::Pin};
/// A [`Channel`] that limits the number of concurrent requests by throttling.
///
/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
/// for the resources available to the server. For production use cases, a more advanced throttler
/// is likely needed.
#[pin_project]
#[derive(Debug)]
pub struct MaxRequests<C> {
max_in_flight_requests: usize,
#[pin]
inner: C,
}
impl<C> MaxRequests<C> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
}
impl<C> MaxRequests<C>
where
C: Channel,
{
/// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
/// `max_in_flight_requests`.
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
MaxRequests {
max_in_flight_requests,
inner,
}
}
}
impl<C> Stream for MaxRequests<C>
where
C: Channel,
{
type Item = <C as Stream>::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
{
ready!(self.as_mut().project().inner.poll_ready(cx)?);
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
Some(r) => {
let _entered = r.span.enter();
tracing::info!(
in_flight_requests = self.as_mut().in_flight_requests(),
"ThrottleRequest",
);
self.as_mut().start_send(Response {
request_id: r.request.id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: "server throttled the request.".into(),
}),
})?;
}
None => return Poll::Ready(None),
}
}
self.project().inner.poll_next(cx)
}
}
impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
where
C: Channel,
{
type Error = C::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx)
}
fn start_send(
self: Pin<&mut Self>,
item: Response<<C as Channel>::Resp>,
) -> Result<(), Self::Error> {
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx)
}
}
impl<C> AsRef<C> for MaxRequests<C> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C> Channel for MaxRequests<C>
where
C: Channel,
{
type Req = <C as Channel>::Req;
type Resp = <C as Channel>::Resp;
type Transport = <C as Channel>::Transport;
fn in_flight_requests(&self) -> usize {
self.inner.in_flight_requests()
}
fn config(&self) -> &Config {
self.inner.config()
}
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}
}
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
/// the number of in-flight requests.
#[pin_project]
#[derive(Debug)]
pub struct MaxRequestsPerChannel<S> {
#[pin]
inner: S,
max_in_flight_requests: usize,
}
impl<S> MaxRequestsPerChannel<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
Self {
inner,
max_in_flight_requests,
}
}
}
impl<S> Stream for MaxRequestsPerChannel<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
type Item = MaxRequests<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().project().inner.poll_next(cx)) {
Some(channel) => Poll::Ready(Some(MaxRequests::new(
channel,
*self.project().max_in_flight_requests,
))),
None => Poll::Ready(None),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::server::{
testing::{self, FakeChannel, PollExt},
TrackedRequest,
};
use pin_utils::pin_mut;
use std::{
marker::PhantomData,
time::{Duration, SystemTime},
};
use tracing::Span;
#[tokio::test]
async fn throttler_in_flight_requests() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
for i in 0..5 {
throttler
.inner
.in_flight_requests
.start_request(
i,
SystemTime::now() + Duration::from_secs(1),
Span::current(),
)
.unwrap();
}
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
#[test]
fn throttler_poll_next_done() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
}
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
let throttler = MaxRequests {
max_in_flight_requests: 1,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(0, 1);
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
assert_eq!(
throttler
.as_mut()
.poll_next(&mut testing::cx())?
.map(|r| r.map(|r| (r.request.id, r.request.message))),
Poll::Ready(Some((0, 1)))
);
Ok(())
}
#[test]
fn throttler_poll_next_throttled() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(1, 1);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
assert_eq!(throttler.inner.sink.len(), 1);
let resp = throttler.inner.sink.get(0).unwrap();
assert_eq!(resp.request_id, 1);
assert!(resp.message.is_err());
}
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: PendingSink::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
struct PendingSink<In, Out> {
ghost: PhantomData<fn(Out) -> In>,
}
impl PendingSink<(), ()> {
pub fn default<Req, Resp>(
) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
PendingSink { ghost: PhantomData }
}
}
impl<In, Out> Stream for PendingSink<In, Out> {
type Item = In;
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
unimplemented!()
}
}
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
}
impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
type Req = Req;
type Resp = Resp;
type Transport = ();
fn config(&self) -> &Config {
unimplemented!()
}
fn in_flight_requests(&self) -> usize {
0
}
fn transport(&self) -> &() {
&()
}
}
}
#[tokio::test]
async fn throttler_start_send() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler
.inner
.in_flight_requests
.start_request(
0,
SystemTime::now() + Duration::from_secs(1),
Span::current(),
)
.unwrap();
throttler
.as_mut()
.start_send(Response {
request_id: 0,
message: Ok(1),
})
.unwrap();
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
assert_eq!(
throttler.inner.sink.get(0),
Some(&Response {
request_id: 0,
message: Ok(1),
})
);
}
}

View File

@@ -5,16 +5,15 @@
// https://opensource.org/licenses/MIT.
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context,
server::{Channel, Config},
server::{Channel, Config, ResponseGuard, TrackedRequest},
Request, Response,
};
use futures::{future::AbortRegistration, task::*, Sink, Stream};
use futures::{task::*, Sink, Stream};
use pin_project::pin_project;
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::time::SystemTime;
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
use tracing::Span;
#[pin_project]
pub(crate) struct FakeChannel<In, Out> {
@@ -24,6 +23,8 @@ pub(crate) struct FakeChannel<In, Out> {
pub sink: VecDeque<Out>,
pub config: Config,
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
pub request_cancellation: RequestCancellation,
pub canceled_requests: CanceledRequests,
}
impl<In, Out> Stream for FakeChannel<In, Out>
@@ -64,12 +65,13 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
}
}
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
impl<Req, Resp> Channel for FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>>
where
Req: Unpin,
{
type Req = Req;
type Resp = Resp;
type Transport = ();
fn config(&self) -> &Config {
&self.config
@@ -79,37 +81,45 @@ where
self.in_flight_requests.len()
}
fn start_request(
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.project()
.in_flight_requests
.start_request(id, deadline)
fn transport(&self) -> &() {
&()
}
}
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
pub fn push_req(&mut self, id: u64, message: Req) {
self.stream.push_back(Ok(Request {
context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
let (request_cancellation, _) = cancellations();
self.stream.push_back(Ok(TrackedRequest {
request: Request {
context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
},
id,
message,
},
abort_registration,
span: Span::none(),
response_guard: ResponseGuard {
request_cancellation,
request_id: id,
cancel: false,
},
id,
message,
}));
}
}
impl FakeChannel<(), ()> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
let (request_cancellation, canceled_requests) = cancellations();
FakeChannel {
stream: Default::default(),
sink: Default::default(),
config: Default::default(),
in_flight_requests: Default::default(),
request_cancellation,
canceled_requests,
}
}
}
@@ -125,5 +135,5 @@ impl<T> PollExt for Poll<Option<T>> {
}
pub fn cx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
Context::from_waker(noop_waker_ref())
}

View File

@@ -1,347 +0,0 @@
// Copyright 2020 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use super::{Channel, Config};
use crate::{Response, ServerError};
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
use log::debug;
use pin_project::pin_project;
use std::{io, pin::Pin, time::SystemTime};
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
#[pin_project]
#[derive(Debug)]
pub struct Throttler<C> {
max_in_flight_requests: usize,
#[pin]
inner: C,
}
impl<C> Throttler<C> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
}
impl<C> Throttler<C>
where
C: Channel,
{
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
/// `max_in_flight_requests`.
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
Throttler {
max_in_flight_requests,
inner,
}
}
}
impl<C> Stream for Throttler<C>
where
C: Channel,
{
type Item = <C as Stream>::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
{
ready!(self.as_mut().project().inner.poll_ready(cx)?);
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
Some(request) => {
debug!(
"[{}] Client has reached in-flight request limit ({}/{}).",
request.context.trace_id(),
self.as_mut().in_flight_requests(),
self.as_mut().project().max_in_flight_requests,
);
self.as_mut().start_send(Response {
request_id: request.id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
}),
})?;
}
None => return Poll::Ready(None),
}
}
self.project().inner.poll_next(cx)
}
}
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
where
C: Channel,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.project().inner.poll_close(cx)
}
}
impl<C> AsRef<C> for Throttler<C> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C> Channel for Throttler<C>
where
C: Channel,
{
type Req = <C as Channel>::Req;
type Resp = <C as Channel>::Resp;
fn in_flight_requests(&self) -> usize {
self.inner.in_flight_requests()
}
fn config(&self) -> &Config {
self.inner.config()
}
fn start_request(
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.project().inner.start_request(id, deadline)
}
}
/// A stream of throttling channels.
#[pin_project]
#[derive(Debug)]
pub struct ThrottlerStream<S> {
#[pin]
inner: S,
max_in_flight_requests: usize,
}
impl<S> ThrottlerStream<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
Self {
inner,
max_in_flight_requests,
}
}
}
impl<S> Stream for ThrottlerStream<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
type Item = Throttler<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().project().inner.poll_next(cx)) {
Some(channel) => Poll::Ready(Some(Throttler::new(
channel,
*self.project().max_in_flight_requests,
))),
None => Poll::Ready(None),
}
}
}
#[cfg(test)]
use super::testing::{self, FakeChannel, PollExt};
#[cfg(test)]
use crate::Request;
#[cfg(test)]
use pin_utils::pin_mut;
#[cfg(test)]
use std::{marker::PhantomData, time::Duration};
#[tokio::test]
async fn throttler_in_flight_requests() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
for i in 0..5 {
throttler
.inner
.in_flight_requests
.start_request(i, SystemTime::now() + Duration::from_secs(1))
.unwrap();
}
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
#[tokio::test]
async fn throttler_start_request() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler
.as_mut()
.start_request(1, SystemTime::now() + Duration::from_secs(1))
.unwrap();
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
}
#[test]
fn throttler_poll_next_done() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
}
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
let throttler = Throttler {
max_in_flight_requests: 1,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(0, 1);
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
assert_eq!(
throttler
.as_mut()
.poll_next(&mut testing::cx())?
.map(|r| r.map(|r| (r.id, r.message))),
Poll::Ready(Some((0, 1)))
);
Ok(())
}
#[test]
fn throttler_poll_next_throttled() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(1, 1);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
assert_eq!(throttler.inner.sink.len(), 1);
let resp = throttler.inner.sink.get(0).unwrap();
assert_eq!(resp.request_id, 1);
assert!(resp.message.is_err());
}
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: PendingSink::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
struct PendingSink<In, Out> {
ghost: PhantomData<fn(Out) -> In>,
}
impl PendingSink<(), ()> {
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
PendingSink { ghost: PhantomData }
}
}
impl<In, Out> Stream for PendingSink<In, Out> {
type Item = In;
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
unimplemented!()
}
}
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
}
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
unimplemented!()
}
fn in_flight_requests(&self) -> usize {
0
}
fn start_request(
self: Pin<&mut Self>,
_id: u64,
_deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
unimplemented!()
}
}
}
#[tokio::test]
async fn throttler_start_send() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler
.inner
.in_flight_requests
.start_request(0, SystemTime::now() + Duration::from_secs(1))
.unwrap();
throttler
.as_mut()
.start_send(Response {
request_id: 0,
message: Ok(1),
})
.unwrap();
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
assert_eq!(
throttler.inner.sink.get(0),
Some(&Response {
request_id: 0,
message: Ok(1),
})
);
}

113
tarpc/src/server/tokio.rs Normal file
View File

@@ -0,0 +1,113 @@
use super::{Channel, Requests, Serve};
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use std::pin::Pin;
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
/// for each new channel. Returned by
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
#[must_use]
#[pin_project]
#[derive(Debug)]
pub struct TokioServerExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
pub(crate) fn new(inner: T, serve: S) -> Self {
Self { inner, serve }
}
}
/// A future that drives the server by [spawning](tokio::spawn) each [response
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
/// [`Channel::execute`](crate::server::Channel::execute).
#[must_use]
#[pin_project]
#[derive(Debug)]
pub struct TokioChannelExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
impl<T, S> TokioChannelExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
// Send + 'static execution helper methods.
impl<C> Requests<C>
where
C: Channel,
C::Req: Send + 'static,
C::Resp: Send + 'static,
{
/// Executes all requests using the given service function. Requests are handled concurrently
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
{
TokioChannelExecutor { inner: self, serve }
}
}
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
tokio::spawn(channel.execute(self.serve.clone()));
}
tracing::info!("Server shutting down.");
Poll::Ready(())
}
}
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
S::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
match response_handler {
Ok(resp) => {
let server = self.serve.clone();
tokio::spawn(async move {
resp.execute(server).await;
});
}
Err(e) => {
tracing::warn!("Requests stream errored out: {}", e);
break;
}
}
}
Poll::Ready(())
}
}

View File

@@ -16,11 +16,14 @@
//! This crate's design is based on [opencensus
//! tracing](https://opencensus.io/core-concepts/tracing/).
use opentelemetry::trace::TraceContextExt;
use rand::Rng;
use std::{
convert::TryFrom,
fmt::{self, Formatter},
mem,
num::{NonZeroU128, NonZeroU64},
};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// A context for tracing the execution of processes, distributed or otherwise.
///
@@ -36,33 +39,50 @@ pub struct Context {
/// before making an RPC, and the span ID is sent to the server. The server is free to create
/// its own spans, for which it sets the client's span as the parent span.
pub span_id: SpanId,
/// An identifier of the span that originated the current span. For example, if a server sends
/// an RPC in response to a client request that included a span, the server would create a span
/// for the RPC and set its parent to the span_id in the incoming request's context.
///
/// If `parent_id` is `None`, then this is a root context.
pub parent_id: Option<SpanId>,
/// Indicates whether a sampler has already decided whether or not to sample the trace
/// associated with the Context. If `sampling_decision` is None, then a decision has not yet
/// been made. Downstream samplers do not need to abide by "no sample" decisions--for example,
/// an upstream client may choose to never sample, which may not make sense for the client's
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace,
/// then the downstream samplers are expected to respect that decision and also sample the
/// trace. Otherwise, the full trace would not be able to be reconstructed.
pub sampling_decision: SamplingDecision,
}
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
/// same trace ID.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct TraceId(u128);
pub struct TraceId(#[cfg_attr(feature = "serde1", serde(with = "u128_serde"))] u128);
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct SpanId(u64);
/// Indicates whether a sampler has decided whether or not to sample the trace associated with the
/// Context. Downstream samplers do not need to abide by "no sample" decisions--for example, an
/// upstream client may choose to never sample, which may not make sense for the client's
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then
/// the downstream samplers are expected to respect that decision and also sample the trace.
/// Otherwise, the full trace would not be able to be reconstructed reliably.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum SamplingDecision {
/// The associated span was sampled by its creating process. Child spans must also be sampled.
Sampled,
/// The associated span was not sampled by its creating process.
Unsampled,
}
impl Context {
/// Constructs a new root context. A root context is one with no parent span.
pub fn new_root() -> Self {
let rng = &mut rand::thread_rng();
Context {
trace_id: TraceId::random(rng),
span_id: SpanId::random(rng),
parent_id: None,
/// Constructs a new context with the trace ID and sampling decision inherited from the parent.
pub(crate) fn new_child(&self) -> Self {
Self {
trace_id: self.trace_id,
span_id: SpanId::random(&mut rand::thread_rng()),
sampling_decision: self.sampling_decision,
}
}
}
@@ -71,17 +91,128 @@ impl TraceId {
/// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
/// actually-random numbers.
pub fn random<R: Rng>(rng: &mut R) -> Self {
TraceId(u128::from(rng.next_u64()) << mem::size_of::<u64>() | u128::from(rng.next_u64()))
TraceId(rng.gen::<NonZeroU128>().get())
}
/// Returns true iff the trace ID is 0.
pub fn is_none(&self) -> bool {
self.0 == 0
}
}
impl SpanId {
/// Returns a random span ID that can be assumed to be unique within a single trace.
pub fn random<R: Rng>(rng: &mut R) -> Self {
SpanId(rng.next_u64())
SpanId(rng.gen::<NonZeroU64>().get())
}
/// Returns true iff the span ID is 0.
pub fn is_none(&self) -> bool {
self.0 == 0
}
}
impl From<TraceId> for u128 {
fn from(trace_id: TraceId) -> Self {
trace_id.0
}
}
impl From<u128> for TraceId {
fn from(trace_id: u128) -> Self {
Self(trace_id)
}
}
impl From<SpanId> for u64 {
fn from(span_id: SpanId) -> Self {
span_id.0
}
}
impl From<u64> for SpanId {
fn from(span_id: u64) -> Self {
Self(span_id)
}
}
impl From<opentelemetry::trace::TraceId> for TraceId {
fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
Self::from(u128::from_be_bytes(trace_id.to_bytes()))
}
}
impl From<TraceId> for opentelemetry::trace::TraceId {
fn from(trace_id: TraceId) -> Self {
Self::from_bytes(u128::from(trace_id).to_be_bytes())
}
}
impl From<opentelemetry::trace::SpanId> for SpanId {
fn from(span_id: opentelemetry::trace::SpanId) -> Self {
Self::from(u64::from_be_bytes(span_id.to_bytes()))
}
}
impl From<SpanId> for opentelemetry::trace::SpanId {
fn from(span_id: SpanId) -> Self {
Self::from_bytes(u64::from(span_id).to_be_bytes())
}
}
impl TryFrom<&tracing::Span> for Context {
type Error = NoActiveSpan;
fn try_from(span: &tracing::Span) -> Result<Self, NoActiveSpan> {
let context = span.context();
if context.has_active_span() {
Ok(Self::from(context.span()))
} else {
Err(NoActiveSpan)
}
}
}
impl From<opentelemetry::trace::SpanRef<'_>> for Context {
fn from(span: opentelemetry::trace::SpanRef<'_>) -> Self {
let otel_ctx = span.span_context();
Self {
trace_id: TraceId::from(otel_ctx.trace_id()),
span_id: SpanId::from(otel_ctx.span_id()),
sampling_decision: SamplingDecision::from(otel_ctx),
}
}
}
impl From<SamplingDecision> for opentelemetry::trace::TraceFlags {
fn from(decision: SamplingDecision) -> Self {
match decision {
SamplingDecision::Sampled => opentelemetry::trace::TraceFlags::SAMPLED,
SamplingDecision::Unsampled => opentelemetry::trace::TraceFlags::default(),
}
}
}
impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
fn from(context: &opentelemetry::trace::SpanContext) -> Self {
if context.is_sampled() {
SamplingDecision::Sampled
} else {
SamplingDecision::Unsampled
}
}
}
impl Default for SamplingDecision {
fn default() -> Self {
Self::Unsampled
}
}
/// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span).
#[derive(Debug)]
pub struct NoActiveSpan;
impl fmt::Display for TraceId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
@@ -89,9 +220,42 @@ impl fmt::Display for TraceId {
}
}
impl fmt::Debug for TraceId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}
impl fmt::Display for SpanId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}
impl fmt::Debug for SpanId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}
#[cfg(feature = "serde1")]
mod u128_serde {
pub fn serialize<S>(u: &u128, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serde::Serialize::serialize(&u.to_le_bytes(), serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<u128, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(u128::from_le_bytes(serde::Deserialize::deserialize(
deserializer,
)?))
}
}

View File

@@ -9,22 +9,32 @@
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`](sealed::Transport)
//! can be plugged in, using whatever protocol it wants.
use futures::prelude::*;
use std::io;
pub mod channel;
pub(crate) mod sealed {
use super::*;
use futures::prelude::*;
use std::error::Error;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport<SinkItem, Item>:
Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>
pub trait Transport<SinkItem, Item>
where
Self: Stream<Item = Result<Item, <Self as Sink<SinkItem>>::Error>>,
Self: Sink<SinkItem, Error = <Self as Transport<SinkItem, Item>>::TransportError>,
<Self as Sink<SinkItem>>::Error: Error,
{
/// Associated type where clauses are not elaborated; this associated type allows users
/// bounding types by Transport to avoid having to explicitly add `T::Error: Error` to their
/// bounds.
type TransportError: Error + Send + Sync + 'static;
}
impl<T, SinkItem, Item> Transport<SinkItem, Item> for T where
T: Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error> + ?Sized
impl<T, SinkItem, Item, E> Transport<SinkItem, Item> for T
where
T: ?Sized,
T: Stream<Item = Result<Item, E>>,
T: Sink<SinkItem, Error = E>,
T::Error: Error + Send + Sync + 'static,
{
type TransportError = E;
}
}

View File

@@ -6,13 +6,19 @@
//! Transports backed by in-memory channels.
use crate::PollIo;
use futures::{task::*, Sink, Stream};
use pin_project::pin_project;
use std::io;
use std::pin::Pin;
use std::{error::Error, pin::Pin};
use tokio::sync::mpsc;
/// Errors that occur in the sending or receiving of messages over a channel.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
/// An error occurred sending over the channel.
#[error("an error occurred sending over the channel")]
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
}
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
/// [`Sink`].
pub fn unbounded<SinkItem, Item>() -> (
@@ -36,28 +42,33 @@ pub struct UnboundedChannel<Item, SinkItem> {
}
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
type Item = Result<Item, io::Error>;
type Item = Result<Item, ChannelError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, ChannelError>>> {
self.rx.poll_recv(cx).map(|option| option.map(Ok))
}
}
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
type Error = io::Error;
const CLOSED_MESSAGE: &str = "the channel is closed and cannot accept new items for sending";
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
type Error = ChannelError;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(if self.tx.is_closed() {
Err(io::Error::from(io::ErrorKind::NotConnected))
Err(ChannelError::Send(CLOSED_MESSAGE.into()))
} else {
Ok(())
})
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
self.tx
.send(item)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
.map_err(|_| ChannelError::Send(CLOSED_MESSAGE.into()))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@@ -65,7 +76,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// UnboundedSender can't initiate closure.
Poll::Ready(Ok(()))
}
@@ -93,52 +104,45 @@ pub struct Channel<Item, SinkItem> {
}
impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
type Item = Result<Item, io::Error>;
type Item = Result<Item, ChannelError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, ChannelError>>> {
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
}
}
impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
type Error = io::Error;
type Error = ChannelError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_ready(cx)
.map_err(convert_send_err_to_io)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
self.project()
.tx
.start_send(item)
.map_err(convert_send_err_to_io)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_flush(cx)
.map_err(convert_send_err_to_io)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_close(cx)
.map_err(convert_send_err_to_io)
}
}
fn convert_send_err_to_io(e: futures::channel::mpsc::SendError) -> io::Error {
if e.is_disconnected() {
io::Error::from(io::ErrorKind::NotConnected)
} else if e.is_full() {
io::Error::from(io::ErrorKind::WouldBlock)
} else {
io::Error::new(io::ErrorKind::Other, e)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
}
@@ -147,17 +151,27 @@ fn convert_send_err_to_io(e: futures::channel::mpsc::SendError) -> io::Error {
mod tests {
use crate::{
client, context,
server::{BaseChannel, Incoming},
transport,
server::{incoming::Incoming, BaseChannel},
transport::{
self,
channel::{Channel, UnboundedChannel},
},
};
use assert_matches::assert_matches;
use futures::{prelude::*, stream};
use log::trace;
use std::io;
use tracing::trace;
#[test]
fn ensure_is_transport() {
fn is_transport<SinkItem, Item, T: crate::Transport<SinkItem, Item>>() {}
is_transport::<(), (), UnboundedChannel<(), ()>>();
is_transport::<(), (), Channel<(), ()>>();
}
#[tokio::test]
async fn integration() -> io::Result<()> {
let _ = env_logger::try_init();
async fn integration() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (client_channel, server_channel) = transport::channel::unbounded();
tokio::spawn(
@@ -167,16 +181,16 @@ mod tests {
future::ready(request.parse::<u64>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{:?} is not an int", request),
format!("{request:?} is not an int"),
)
}))
}),
);
let client = client::new(client::Config::default(), client_channel).spawn()?;
let client = client::new(client::Config::default(), client_channel).spawn();
let response1 = client.call(context::current(), "123".into()).await?;
let response2 = client.call(context::current(), "abc".into()).await?;
let response1 = client.call(context::current(), "", "123".into()).await?;
let response2 = client.call(context::current(), "", "abc".into()).await?;
trace!("response1: {:?}, response2: {:?}", response1, response2);

View File

@@ -38,11 +38,34 @@ where
H: BuildHasher,
{
fn compact(&mut self, usage_ratio_threshold: f64) {
if self.capacity() > 1000 {
let usage_ratio = self.len() as f64 / self.capacity() as f64;
if usage_ratio < usage_ratio_threshold {
self.shrink_to_fit();
}
}
let usage_ratio_threshold = usage_ratio_threshold.clamp(f64::MIN_POSITIVE, 1.);
let cap = f64::max(1000., self.len() as f64 / usage_ratio_threshold);
self.shrink_to(cap as usize);
}
}
#[test]
fn test_compact() {
let mut map = HashMap::with_capacity(2048);
assert_eq!(map.capacity(), 3584);
// Make usage ratio 25%
for i in 0..896 {
map.insert(format!("k{i}"), "v");
}
map.compact(-1.0);
assert_eq!(map.capacity(), 3584);
map.compact(0.25);
assert_eq!(map.capacity(), 3584);
map.compact(0.50);
assert_eq!(map.capacity(), 1792);
map.compact(1.0);
assert_eq!(map.capacity(), 1792);
map.compact(2.0);
assert_eq!(map.capacity(), 1792);
}

View File

@@ -5,31 +5,7 @@
// https://opensource.org/licenses/MIT.
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::{
io,
time::{Duration, SystemTime},
};
/// Serializes `system_time` as a `u64` equal to the number of seconds since the epoch.
pub fn serialize_epoch_secs<S>(system_time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
const ZERO_SECS: Duration = Duration::from_secs(0);
system_time
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or(ZERO_SECS)
.as_secs() // Only care about second precision
.serialize(serializer)
}
/// Deserializes [`SystemTime`] from a `u64` equal to the number of seconds since the epoch.
pub fn deserialize_epoch_secs<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
where
D: Deserializer<'de>,
{
Ok(SystemTime::UNIX_EPOCH + Duration::from_secs(u64::deserialize(deserializer)?))
}
use std::io;
/// Serializes [`io::ErrorKind`] as a `u32`.
#[allow(clippy::trivially_copy_pass_by_ref)] // Exact fn signature required by serde derive

View File

@@ -2,4 +2,8 @@
fn ui() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/compile_fail/*.rs");
#[cfg(feature = "tokio1")]
t.compile_fail("tests/compile_fail/tokio/*.rs");
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
t.compile_fail("tests/compile_fail/serde_transport/*.rs");
}

View File

@@ -0,0 +1,15 @@
use tarpc::client;
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
fn main() {
let (client_transport, _) = tarpc::transport::channel::unbounded();
#[deny(unused_must_use)]
{
WorldClient::new(client::Config::default(), client_transport).dispatch;
}
}

View File

@@ -0,0 +1,11 @@
error: unused `RequestDispatch` that must be used
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
|
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/must_use_request_dispatch.rs:11:12
|
11 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^

View File

@@ -0,0 +1,9 @@
use tarpc::serde_transport;
use tokio_serde::formats::Json;
fn main() {
#[deny(unused_must_use)]
{
serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
}
}

View File

@@ -0,0 +1,11 @@
error: unused `tarpc::serde_transport::tcp::Connect` that must be used
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12
|
5 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^

View File

@@ -7,8 +7,8 @@ struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
fn hello(name: String) -> String {
format!("Hello, {}!", name)
fn hello(name: String) -> String {
format!("Hello, {name}!", name)
}
}

View File

@@ -7,5 +7,5 @@ error: not all trait items implemented, missing: `HelloFut`
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
--> $DIR/tarpc_server_missing_async.rs:10:5
|
10 | fn hello(name: String) -> String {
10 | fn hello(name: String) -> String {
| ^^

View File

@@ -0,0 +1,29 @@
use tarpc::{
context,
server::{self, Channel},
};
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
#[derive(Clone)]
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
fn main() {
let (_, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);
#[deny(unused_must_use)]
{
server.execute(HelloServer.serve());
}
}

View File

@@ -0,0 +1,11 @@
error: unused `TokioChannelExecutor` that must be used
--> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9
|
27 | server.execute(HelloServer.serve());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12
|
25 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^

View File

@@ -0,0 +1,30 @@
use futures::stream::once;
use tarpc::{
context,
server::{self, incoming::Incoming},
};
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
#[derive(Clone)]
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
fn main() {
let (_, server_transport) = tarpc::transport::channel::unbounded();
let server = once(async move { server::BaseChannel::with_defaults(server_transport) });
#[deny(unused_must_use)]
{
server.execute(HelloServer.serve());
}
}

View File

@@ -0,0 +1,11 @@
error: unused `TokioServerExecutor` that must be used
--> tests/compile_fail/tokio/must_use_server_executor.rs:28:9
|
28 | server.execute(HelloServer.serve());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/tokio/must_use_server_executor.rs:26:12
|
26 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^

View File

@@ -1,14 +1,13 @@
use futures::prelude::*;
use std::io;
use tarpc::serde_transport;
use tarpc::{
client, context,
server::{BaseChannel, Incoming},
server::{incoming::Incoming, BaseChannel},
};
use tokio_serde::formats::Json;
#[tarpc::derive_serde]
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Eq)]
pub enum TestData {
Black,
White,
@@ -33,7 +32,7 @@ impl ColorProtocol for ColorServer {
}
#[tokio::test]
async fn test_call() -> io::Result<()> {
async fn test_call() -> anyhow::Result<()> {
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
let addr = transport.local_addr();
tokio::spawn(
@@ -45,7 +44,7 @@ async fn test_call() -> io::Result<()> {
);
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
let client = ColorProtocolClient::new(client::Config::default(), transport).spawn()?;
let client = ColorProtocolClient::new(client::Config::default(), transport).spawn();
let color = client
.get_opposite_color(context::current(), TestData::White)

View File

@@ -3,14 +3,11 @@ use futures::{
future::{join_all, ready, Ready},
prelude::*,
};
use std::{
io,
time::{Duration, SystemTime},
};
use std::time::{Duration, SystemTime};
use tarpc::{
client::{self},
context,
server::{self, BaseChannel, Channel, Incoming},
server::{self, incoming::Incoming, BaseChannel, Channel},
transport::channel,
};
use tokio::join;
@@ -34,13 +31,13 @@ impl Service for Server {
type HeyFut = Ready<String>;
fn hey(self, _: context::Context, name: String) -> Self::HeyFut {
ready(format!("Hey, {}.", name))
ready(format!("Hey, {name}."))
}
}
#[tokio::test]
async fn sequential() -> io::Result<()> {
let _ = env_logger::try_init();
async fn sequential() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
@@ -50,7 +47,7 @@ async fn sequential() -> io::Result<()> {
.execute(Server.serve()),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
let client = ServiceClient::new(client::Config::default(), tx).spawn();
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
@@ -61,7 +58,7 @@ async fn sequential() -> io::Result<()> {
}
#[tokio::test]
async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
#[tarpc_plugins::service]
trait Loop {
async fn r#loop();
@@ -82,16 +79,14 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
}
}
let _ = env_logger::try_init();
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
// Set up a client that initiates a long-lived request.
// The request will complete in error when the server drops the connection.
tokio::spawn(async move {
let client = LoopClient::new(client::Config::default(), tx)
.spawn()
.unwrap();
let client = LoopClient::new(client::Config::default(), tx).spawn();
let mut ctx = context::current();
ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60);
@@ -113,11 +108,11 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
#[tokio::test]
async fn serde() -> io::Result<()> {
async fn serde_tcp() -> anyhow::Result<()> {
use tarpc::serde_transport;
use tokio_serde::formats::Json;
let _ = env_logger::try_init();
let _ = tracing_subscriber::fmt::try_init();
let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?;
let addr = transport.local_addr();
@@ -130,7 +125,7 @@ async fn serde() -> io::Result<()> {
);
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
let client = ServiceClient::new(client::Config::default(), transport).spawn()?;
let client = ServiceClient::new(client::Config::default(), transport).spawn();
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
@@ -141,9 +136,40 @@ async fn serde() -> io::Result<()> {
Ok(())
}
#[cfg(all(feature = "serde-transport", feature = "unix", unix))]
#[tokio::test]
async fn concurrent() -> io::Result<()> {
let _ = env_logger::try_init();
async fn serde_uds() -> anyhow::Result<()> {
use tarpc::serde_transport;
use tokio_serde::formats::Json;
let _ = tracing_subscriber::fmt::try_init();
let sock = tarpc::serde_transport::unix::TempPathBuf::with_random("uds");
let transport = tarpc::serde_transport::unix::listen(&sock, Json::default).await?;
tokio::spawn(
transport
.take(1)
.filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults)
.execute(Server.serve()),
);
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
let client = ServiceClient::new(client::Config::default(), transport).spawn();
// Save results using socket so we can clean the socket even if our test assertions fail
let res1 = client.add(context::current(), 1, 2).await;
let res2 = client.hey(context::current(), "Tim".to_string()).await;
assert_matches!(res1, Ok(3));
assert_matches!(res2, Ok(ref s) if s == "Hey, Tim.");
Ok(())
}
#[tokio::test]
async fn concurrent() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
tokio::spawn(
@@ -152,7 +178,7 @@ async fn concurrent() -> io::Result<()> {
.execute(Server.serve()),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
let client = ServiceClient::new(client::Config::default(), tx).spawn();
let req1 = client.add(context::current(), 1, 2);
let req2 = client.add(context::current(), 3, 4);
@@ -166,8 +192,8 @@ async fn concurrent() -> io::Result<()> {
}
#[tokio::test]
async fn concurrent_join() -> io::Result<()> {
let _ = env_logger::try_init();
async fn concurrent_join() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
tokio::spawn(
@@ -176,7 +202,7 @@ async fn concurrent_join() -> io::Result<()> {
.execute(Server.serve()),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
let client = ServiceClient::new(client::Config::default(), tx).spawn();
let req1 = client.add(context::current(), 1, 2);
let req2 = client.add(context::current(), 3, 4);
@@ -191,8 +217,8 @@ async fn concurrent_join() -> io::Result<()> {
}
#[tokio::test]
async fn concurrent_join_all() -> io::Result<()> {
let _ = env_logger::try_init();
async fn concurrent_join_all() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
tokio::spawn(
@@ -201,7 +227,7 @@ async fn concurrent_join_all() -> io::Result<()> {
.execute(Server.serve()),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
let client = ServiceClient::new(client::Config::default(), tx).spawn();
let req1 = client.add(context::current(), 1, 2);
let req2 = client.add(context::current(), 3, 4);
@@ -214,7 +240,7 @@ async fn concurrent_join_all() -> io::Result<()> {
}
#[tokio::test]
async fn counter() -> io::Result<()> {
async fn counter() -> anyhow::Result<()> {
#[tarpc::service]
trait Counter {
async fn count() -> u32;
@@ -241,7 +267,7 @@ async fn counter() -> io::Result<()> {
}
});
let client = CounterClient::new(client::Config::default(), tx).spawn()?;
let client = CounterClient::new(client::Config::default(), tx).spawn();
assert_matches!(client.count(context::current()).await, Ok(1));
assert_matches!(client.count(context::current()).await, Ok(2));