306 Commits

Author SHA1 Message Date
Tim Kuehn
a6758fd1f9 BeforeRequest hook chaining.
It's unintuitive that serve.before(hook1).before(hook2) executes in
reverse order, with hook2 going before hook1. With BeforeRequestList,
users can write `before().then(hook1).then(hook2).serving(serve)`, and
it will run hook1, then hook2, then the service fn.
2023-12-29 23:13:06 -08:00
Tim Kuehn
2c241cc809 Fix readme and release notes. 2023-12-29 21:03:09 -08:00
Tim Kuehn
263ef8a897 Prepare for v0.34.0 release 2023-12-29 20:52:37 -08:00
Kevin Ge
d50290a21c Add README for example app 2023-12-29 20:32:00 -08:00
Kevin Ge
26988cb833 Update OpenTelemetry packages in example app 2023-12-29 20:32:00 -08:00
Tim Kuehn
6cf18a1caf Rewrite traits to use async-fn-in-trait.
- Stub
- BeforeRequest
- AfterRequest

Also removed the last remaining usage of an unstable feature,
iter_intersperse.
2023-12-29 13:52:05 -08:00
Tim Kuehn
84932df9b4 Return span in InFlightRequests::complete_request.
Rather than returning a bool, return the Span associated with the
request. This gives RequestDispatch more flexibility to annotate the
request span.
2023-12-29 13:52:05 -08:00
Tim Kuehn
8dc3711a80 Use async fn in generated traits!!
The major breaking change is that Channel::execute no longer internally
spawns RPC handlers, because it is no longer possible to place a Send
bound on the return type of Serve::serve. Instead, Channel::execute
returns a stream of RPC handler futures.

Service authors can reproduce the old behavior by spawning each response
handler (the compiler knows whether or not the futures can be spawned;
it's just that the bounds can't be expressed generically):

    channel.execute(server.serve())
           .for_each(|rpc| { tokio::spawn(rpc); })
2023-12-29 13:52:05 -08:00
Tim Kuehn
7c5afa97bb Add request hooks to the Serve trait.
This allows plugging in horizontal functionality, such as authorization,
throttling, or latency recording, that should run before and/or after
execution of every request, regardless of the request type.

The tracing example is updated to show off both client stubs as well as
server hooks.

As part of this change, there were some changes to the Serve trait:

1. Serve's output type is now a Result<Response, ServerError>..
   Serve previously did not allow returning ServerErrors, which
   prevented using Serve for horizontal functionality like throttling or
   auth. Now, Serve's output type is Result<Resp, ServerError>, making
   Serve a more natural integration point for horizontal capabilities.
2. Serve's generic Request type changed to an associated type. The
   primary benefit of the generic type is that it allows one type to
   impl a trait multiple times (for example, u64 impls TryFrom<usize>,
   TryFrom<u128>, etc.). In the case of Serve impls, while it is
   theoretically possible to contrive a type that could serve multiple
   request types, in practice I don't expect that to be needed.  Most
   users will use the Serve impl generated by #[tarpc::service], which
   only ever serves one type of request.
2023-12-29 13:52:05 -08:00
Tim Kuehn
324df5cd15 Add back the Client trait, renamed Stub.
Also adds a Client stub trait alias for each generated service.

Now that generic associated types are stable, it's almost possible to
define a trait for Channel that works with async fns on stable. `impl
trait in type aliases` is still necessary (and unstable), but we're
getting closer.

As a proof of concept, three more implementations of Stub are implemented;

1. A load balancer that round-robins requests between different stubs.
2. A load balancer that selects a stub based on a request hash, so that
   the same requests go to the same stubs.
3. A stub that retries requests based on a configurable policy.

   The "serde/rc" feature is added to the "full" feature because the Retry
   stub wraps the request in an Arc, so that the request is reusable for
   multiple calls.

   Server implementors commonly need to operate generically across all
   services or request types. For example, a server throttler may want to
   return errors telling clients to back off, which is not specific to any
   one service.
2023-12-29 13:52:05 -08:00
Guillaume Charmetant
3264979993 Fix warnings in README's example 2023-11-16 09:54:54 -08:00
Guillaume Charmetant
dd63fb59bf Fix tokio dep in the README's example
Add missing tokio feature in the example's dependencies.
2023-11-16 09:54:54 -08:00
Tim Kuehn
f4db8cc5b4 Address clippy lints 2023-11-16 00:00:27 -08:00
Tim Kuehn
e9ba350496 Update must-use UI tests 2023-11-16 00:00:27 -08:00
Tim Kuehn
e6d779e70b Remove mipsel workflow.
Mipsel was downgraded to tier 3, which broke this workflow.
https://github.com/rust-lang/compiler-team/issues/648
2023-11-16 00:00:27 -08:00
Izumi Raine
ce5f8cfb0c Simplify TLS example (#404) 2023-04-14 12:33:22 -07:00
Tim Kuehn
4b69dc8db5 Prepare release of v0.33.0 2023-04-03 11:03:55 -07:00
Bruno
866db2a2cd Bump opentelemetry to 0.18 (#401) 2023-04-03 10:38:17 -07:00
Tim Kuehn
bed85e2827 Prepare release of v0.32.0 2023-03-24 15:04:06 -07:00
Bruno
93f3880025 Return transport errors to the caller (#399)
* Make client::InFlightRequests generic over result.

Previously, InFlightRequests required the client response type to be a
server response. However, this prevented injection of non-server
responses: for example, if the client fails to send a request, it should
complete the request with an IO error rather than a server error.

* Gracefully handle client-side send errors.

Previously, a client channel would immediately disconnect when
encountering an error in Transport::try_send. One kind of error that can
occur in try_send is message validation, e.g. validating a message is
not larger than a configured frame size. The problem with shutting down
the client immediately is that debuggability suffers: it can be hard to
understand what caused the client to fail. Also, these errors are not
always fatal, as with frame size limits, so complete shutdown was
extreme.

By bubbling up errors, it's now possible for the caller to
programmatically handle them. For example, the error could be walked
via anyhow::Error:

```
    2023-01-10T02:49:32.528939Z  WARN client: the client failed to send the request

    Caused by:
        0: could not write to the transport
        1: frame size too big
```

* Some follow-up work: right now, read errors will bubble up to all pending RPCs. However, on the write side, only `start_send` bubbles up. `poll_ready`, `poll_flush`, and `poll_close` do not propagate back to pending RPCs. This is probably okay in most circumstances, because fatal write errors likely coincide with fatal read errors, which *do* propagate back to clients. But it might still be worth unifying this logic.

---------

Co-authored-by: Tim Kuehn <tikue@google.com>
2023-03-24 14:31:25 -07:00
cguentherTUChemnitz
878f594d5b Feature/tls over tcp example (#398)
Example tarpc service that encodes messages with bincode written to a TLS-over-TCP transport.

Certs were generated with openssl 3 using https://github.com/rustls/rustls/tree/main/test-ca

New dependencies:
- tokio-rustls to set up the TLS connections
- rustls-pemfile to load certs from .pem files
2023-03-22 10:35:21 -07:00
Tim Kuehn
aa9bbad109 Fix compile_fail tests for formatting changes on stable 2023-03-17 10:07:54 -07:00
Tim Kuehn
7e872ce925 Remove bad mem::forget usage.
mem::forget is a dangerous tool, and it was being used carelessly for
things that have safer alternatives. There was at least one bug where a
cloned tokio::sync::mpsc::UnboundedSender used for request cancellation
was being leaked on every successful server response, so its refcounts
were never decremented. Because these are atomic refcounts, they'll wrap
around rather than overflow when reaching the maximum value, so I don't
believe this could lead to panics or unsoundness.
2022-11-23 18:01:12 -08:00
Tim Kuehn
62541b709d Replace actions-rs 2022-11-23 16:39:29 -08:00
Tim Kuehn
8c43f94fb6 Remove unused Sealed trait 2022-11-17 00:57:56 -08:00
Tim Kuehn
7fa4e5064d Ignore clippy false positive 2022-11-13 00:25:07 -08:00
Tim Kuehn
94db7610bb Require a static lifetime for request_name. 2022-11-05 11:43:03 -07:00
Tim Kuehn
0c08d5e8ca Prepare release of v0.31.0 2022-11-03 13:29:46 -07:00
Tim Kuehn
75b15fe2aa Address clippy lint 2022-10-07 10:51:45 -07:00
Tim Kuehn
863a08d87e In example-service, print the port the server is listened on.
This is helpful when passing starting the server with --port 0.
2022-10-06 20:58:54 -07:00
Tim Kuehn
49ba8f8b1b Zero-pad the random number suffix of TempPathBufs.
This way, the hex number is always 16 digits, which is helpful for test
verification as well as simple consistency.
2022-10-03 18:50:50 -07:00
Kevin K
d832209da3 feat: Unix domain sockets with serde transports (#380)
* adds support for Unix Domain Socket generic transports
* adds a TempPathBuf that lives in temp and is removed on drop
2022-10-03 18:07:29 -07:00
royrustdev
584426d414 fix clippy warnings #378 2022-09-19 23:26:07 -07:00
royrustdev
50eb80c883 reference latest tarpc version in readme 2022-09-19 21:58:21 -07:00
royrustdev
1f0c80d8c9 bump github actions 2022-09-15 11:17:58 -07:00
Tim Kuehn
99bf3e62a3 Prepare release of 0.30.0 2022-08-12 16:08:33 -07:00
Tim Kuehn
68863e3db0 Remove Channel::request_cancellation.
This trait fn returns a private type, which means it's useless for
anyone using the Channel.

Instead, add an inert (now-public) ResponseGuard to TrackedRequest that,
when taken out of the ManuallyDrop, ensures a Channel's request state is
cleaned up. It's preferable to make ResponseGuard public instead of
RequestCancellations because it's a smaller API surface (no public
methods, just a Drop fn) and harder to misuse, because it is already
associated with the correct request ID to cancel.
2022-08-12 16:08:33 -07:00
Tim Kuehn
453ba1c074 Lower log level of log in the RPC callpath 2022-08-12 09:04:47 -07:00
Makro
e3eac1b4f5 Add LICENSE files to crates (#372) 2022-08-10 17:11:50 -07:00
kkharji
0e102288a5 feat: re-export used packages (#371)
## Problem
Library users might get stuck with or ran into issues while using tarpc because of incompatible third party libraries. in particular, tokio_serde and tokio_util.

## Solution
This PR does the following:

1. re-export tokio_serde as part of feature serde-transport, because the end user imports it to use some serde-transport APIs.
2. Update third library packages to latest release and fix resulting issues from that.

## Important Notes
tokio_util 7.3 DelayQueue::poll_expired API changed [0] therefore, InFlightRequests::poll_expired now returns Poll<Option<u64>>

[0] https://docs.rs/tokio-util/latest/tokio_util/time/delay_queue/struct.DelayQueue.html#method.poll_expired
2022-07-15 10:14:49 -07:00
Tim Kuehn
4c8ba41b2f #[allow(unstable_name_collisions)] for .ready()
.ready() is being added to std, but in the meantime, I don't want to stop using PollTest::ready.
2022-06-07 01:29:14 -07:00
Tim Kuehn
946c627579 Remove unused field 2022-06-07 01:29:14 -07:00
Tim Kuehn
104dd71bba Clean up Channel request data more reliably.
When an InFlightRequest is dropped before response completion, request
data in the Channel persists until either the request expires or the
client cancels the request. In rare cases, requests with very large
deadlines could clog up the Channel long after request processing
ceases.

This commit adds a drop hook to InFlightRequest so that if it is dropped
before execution completes, a cancellation message is sent to the
Channel so that it can clean up the associated request data.

This only works for when using `InFlightRequest::execute` or
`Channel::execute`. However, users of raw `Channel` have access
to the `RequestCancellation` handle via `Channel::request_cancellation`,
so they can implement a similar method if they wish to manually clean up
request data.

Note that once a Channel's request data is cleaned up, that request can
never be responded to, even if a response is produced afterward.

Fixes https://github.com/google/tarpc/issues/314
2022-06-07 01:29:04 -07:00
Tim Kuehn
012c481861 Move cancellation types into a dedicated module.
Cancellation utilities could be useful for both client and server code.
2022-06-05 18:54:52 -07:00
Tim Kuehn
dc12bd09aa Annotate types that impl Future with #[must_use].
These types do nothing unless polled / .awaited.
Annotating them with #[must_use] helps prevent a common class of coding errors.

Fixes https://github.com/google/tarpc/issues/368.
2022-06-05 18:54:52 -07:00
Tim Kuehn
2594ea8ce9 Prepare release of 0.29.0 2022-06-05 15:26:33 -07:00
Tim Kuehn
839b87e394 Serialize RPC deadline as a Duration.
Duration was previously serialized as SystemTime. However, absolute
times run into problems with clock skew: if the remote machine's clock
is too far in the future, the RPC deadline will be exceeded before
request processing can begin. Conversely, if the remote machine's clock
is too far in the past, the RPC deadline will not be enforced.

By converting the absolute deadline to a relative duration, clock skew
is no longer relevant, as the remote machine will convert the deadline
into a time relative to its own clock. This mirrors how the gRPC HTTP2
protocol includes a Timeout in the request headers [0] but the SDK uses
timestamps [1]. Keeping the absolute time in the core APIs maintains all
the benefits of today, namely, natural deadline propagation between RPC
hops when using the current context.

This serialization strategy means that, generally, the remote machine's
deadline will be slightly in the future compared to the local machine.
Depending on network transfer latencies, this could be microseconds to
milliseconds, or worse in the worst case. Because the deadline is not
intended for high-precision scenarios, I don't view this is as
problematic.

Because this change only affects the serialization layer, local
transports that bypass serialization are not affected.

[0] https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
[1] https://grpc.io/blog/deadlines/#setting-a-deadline
2022-05-26 15:18:49 -07:00
Tim Kuehn
57d0638a99 Add rpc.deadline tag to Opentelemetry traces. 2022-05-26 15:18:49 -07:00
Tim Kuehn
a3a6404a30 Prepare release of 0.28.0 2022-04-06 22:07:07 -07:00
Tim Kuehn
b36eac80b1 Bump minimum rust version to 1.58.0 2022-04-06 21:53:56 -07:00
Bruno
d7070e4bc3 Update opentelemetry and related dependencies (#362) 2022-04-03 14:09:14 -07:00
Tim Kuehn
b5d1828308 Use captured identifiers in format strings.
This was stabilized in Rust 1.58.0: https://blog.rust-lang.org/2022/01/13/Rust-1.58.0.html
2022-01-13 15:00:44 -08:00
Zak Cutner
92cfe63c4f Use single-threaded Tokio runtime (#360) 2022-01-06 20:59:51 -08:00
Tim Kuehn
839a2f067c Update example to latest version of Clap 2021-12-27 22:56:17 -08:00
David Kleingeld
b5d593488c Derive more traits for RpcError (#359)
Makes RpcError derive Clone, PartialEq, Eq, Hash, Serialize and Deserialize.
2021-12-27 22:00:58 -08:00
Tim Kuehn
eea38b8bf4 Simplify code with const assert!.
The code that prevents compilation on systems where usize is larger than
u64 previously used a const index-out-of-bounds trick. That code can now
be replaced with assert!, as const panic! has landed in 1.57.0 stable.
2021-12-03 15:20:33 -08:00
Shi Yan
70493c15f4 Fix a compiling issue of the official example (#358)
Fix a compiling issue of the official example because of the following error :

```
error[E0599]: the method `execute` exists for struct `BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>`, but its trait bounds were not satisfied
  --> src/main.rs:39:25
   |
39 |     tokio::spawn(server.execute(HelloServer.serve()));
   |                         ^^^^^^^ method cannot be called on `BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>` due to unsatisfied trait bounds
   |
   = note: the following trait bounds were not satisfied:
           `<&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>> as futures::Stream>::Item = _`
           which is required by `&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>: tarpc::server::incoming::Incoming<_>`
           `&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>: futures::Stream`
           which is required by `&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>: tarpc::server::incoming::Incoming<_>`
   = help: items from traits can only be used if the trait is in scope
help: the following trait is implemented but not in scope; perhaps add a `use` for it:
   |
1  | use tarpc::server::Channel;
   |
```

See https://github.com/google/tarpc/pull/358#issuecomment-981953193 for the root cause.
2021-11-29 17:01:16 -08:00
baptiste0928
f7c5d6a7c3 Fix example-service (#355)
Fixes the compilation of the example-service crate (the Clap trait has been renamed Parser in clap-rs/clap@d840d56).
2021-11-15 08:39:04 -08:00
Scott Kirkpatrick
98c5d2a18b Re-add typo fixes (#353)
The typo fixes that were added by commit b5d9aaa
were accidentally reverted by commit 1e680e3, this
will add them back
2021-11-08 10:07:21 -08:00
Tim Kuehn
46b534f7c6 Use HashMap::shrink_to in impl of Comapct::compact. 2021-10-21 17:03:57 -07:00
Tim Kuehn
42b4fc52b1 Set rust-version to 1.56 2021-10-21 16:08:15 -07:00
Tim Kuehn
350dbcdad0 Upgrade to Rust 2021! 2021-10-21 14:10:21 -07:00
Tim Kuehn
b1b4461d89 Prepare release of 0.27.2 2021-10-08 22:31:56 -07:00
Tim Kuehn
f694b7573a Close TcpStream when client disconnects.
An attempt at a clean shutdown helps the server to drop its connections
more quickly.

Testing this uncovered a latent bug in DelayQueue wherein `poll_expired`
yields `Pending` when empty. A workaround was added to
`InFlightRequests::poll_expired`: check if there are actually any
outstanding requests before calling `DelayQueue::poll_expired`.
2021-10-08 22:13:24 -07:00
Tim Kuehn
1e680e3a5a Fix typos in docs.
Fixes https://github.com/google/tarpc/issues/352.
2021-10-08 19:19:50 -07:00
Tim Kuehn
2591d21e94 Update release notes to mention io::Error = 2021-09-23 13:57:43 -07:00
Tim Kuehn
6632f68d95 Prepare for 0.27 release 2021-09-22 15:41:34 -07:00
Dmitry Kakurin
25985ad56a Update README.md (#350)
Fixed 2 typos
2021-09-01 17:58:49 -07:00
Tim Kuehn
d6a24e9420 Address Clippy lint 2021-08-24 12:40:18 -07:00
Tim Kuehn
281a78f3c7 Add tokio-serde-bincode feature 2021-08-24 12:37:57 -07:00
Julian Tescher
a0787d0091 Update to opentelemetry 0.16.x (#349) 2021-08-17 00:00:07 -04:00
Frederik-Baetens
d2acba0e8a add serde-transport-json feature flag (#346)
In general, it should be possible to use, or at least import all functionality of a library, when having only that library in your cargo.toml.
2021-05-06 08:41:57 -07:00
Tim Kuehn
ea7b6763c4 Refactor server module.
In the interest of the user's attention, some ancillary APIs have been
moved to new submodules:

- server::limits contains what was previously called Throttler and
  ChannelFilter. Both of those names were very generic, when the methods
  applied by these types were very specific (and also simplistic). Renames
  have occurred:
  - ThrottlerStream => MaxRequestsPerChannel
  - Throttler => MaxRequests
  - ChannelFilter => MaxChannelsPerKey
- server::incoming contains the Incoming trait.
- server::tokio contains the tokio-specific helper types.

The 5 structs and 1 enum remaining in the base server module are all
core to the functioning of the server.
2021-04-21 17:05:49 -07:00
Tim Kuehn
eb67c540b9 Use more structured errors in client. 2021-04-21 14:54:45 -07:00
Tim Kuehn
4151d0abd3 Move Span creation into BaseChannel.
It's important for Channel decorators, like Throttler, to have access to
the Span. This means that the BaseChannel becomes responsible for
starting its own requests. Actually, this simplifies the integration for
the Channel users, as they can assume any yielded requests are already
tracked.

This entails the following breaking changes:

- removed trait method Channel::start_request as it is now done
  internally.
2021-04-21 14:54:45 -07:00
Tim Kuehn
d0c11a6efa Change RPC error type from io::Error => RpcError.
Becaue tarpc is a library, not an application, it should strive to
use structured errors in its API so that users have maximal flexibility
in how they handle errors. io::Error makes that hard, because it is a
kitchen-sink error type.

RPCs in particular only have 3 classes of errors:

- The connection breaks.
- The request expires.
- The server decides not to process the request.

(Of course, RPCs themselves can have application-specific errors, but
from the perspective of the RPC library, those can be classified as
successful responsees).
2021-04-20 18:29:55 -07:00
Tim Kuehn
82c4da1743 Prepare release of v0.26.2 2021-04-20 11:28:15 -07:00
Tim Kuehn
0a15e0b75c Rustdoc: link RPC futures to their methods. 2021-04-20 11:25:26 -07:00
Tim Kuehn
0b315c29bf It's not currently possible to document the enum variants, which means
projects that #[deny(missing_docs)] wouldn't compile if using tarpc
services.
2021-04-20 09:01:39 -07:00
Tim Kuehn
56f09bf61f Fix log that's split across lines. 2021-04-17 17:15:16 -07:00
Tim Kuehn
6d82e82419 Fix formatting 2021-04-16 16:51:21 -07:00
Tim Kuehn
9bebaf814a Address clippy lint 2021-04-14 17:49:27 -07:00
Tim Kuehn
5f4d6e6008 Prepare release of v0.26.0 2021-04-14 17:08:44 -07:00
Tim Kuehn
07d07d7ba3 Remove tracing_appender as it does not support build target mipsel-unknown-linux-gnu 2021-04-01 19:37:02 -07:00
Tim Kuehn
a41bbf65b2 Use rustfmt instead of cargo fmt so that diff is only printed once 2021-04-01 17:24:34 -07:00
Tim Kuehn
21e2f7ca62 Tear out requirement that Transport's error type is io::Error. 2021-04-01 17:24:34 -07:00
Tim Kuehn
7b7c182411 Instrument tarpc with tracing.
tarpc is now instrumented with tracing primitives extended with
OpenTelemetry traces. Using a compatible tracing-opentelemetry
subscriber like Jaeger, each RPC can be traced through the client,
server, amd other dependencies downstream of the server. Even for
applications not connected to a distributed tracing collector, the
instrumentation can also be ingested by regular loggers like env_logger.

 # Breaking Changes

 ## Logging

Logged events are now structured using tracing. For applications using a
logger and not a tracing subscriber, these logs may look different or
contain information in a less consumable manner. The easiest solution is
to add a tracing subscriber that logs to stdout, such as
tracing_subscriber::fmt.

 ##  Context

- Context no longer has parent_span, which was actually never needed,
  because the context sent in an RPC is inherently the parent context.
  For purposes of distributed tracing, the client side of the RPC has all
  necessary information to link the span to its parent; the server side
  need do nothing more than export the (trace ID, span ID) tuple.
- Context has a new field, SamplingDecision, which has two variants,
  Sampled and Unsampled. This field can be used by downstream systems to
  determine whether a trace needs to be exported. If the parent span is
  sampled, the expectation is that all child spans be exported, as well;
  to do otherwise could result in lossy traces being exported. Note that
  if an Openetelemetry tracing subscriber is not installed, the fallback
  context will still be used, but the Context's sampling decision will
  always be inherited by the parent Context's sampling decision.
- Context::scope has been removed. Context propagation is now done via
  tracing's task-local spans. Spans can be propagated across tasks via
  Span::in_scope. When a service receives a request, it attaches an
  Opentelemetry context to the local Span created before request handling,
  and this context contains the request deadline. This span-local deadline
  is retrieved by Context::current, but it cannot be modified so that
  future Context::current calls contain a different deadline. However, the
  deadline in the context passed into an RPC call will override it, so
  users can retrieve the current context and then modify the deadline
  field, as has been historically possible.
- Context propgation precedence changes: when an RPC is initiated, the
  current Span's Opentelemetry context takes precedence over the trace
  context passed into the RPC method. If there is no current Span, then
  the trace context argument is used as it has been historically. Note
  that Opentelemetry context propagation requires an Opentelemetry
  tracing subscriber to be installed.

 ## Server

- The server::Channel trait now has an additional required associated
  type and method which returns the underlying transport. This makes it
  more ergonomic for users to retrieve transport-specific information,
  like IP Address. BaseChannel implements Channel::transport by returning
  the underlying transport, and channel decorators like Throttler just
  delegate to the Channel::transport method of the wrapped channel.

 # References

[1] https://github.com/tokio-rs/tracing
[2] https://opentelemetry.io
[3] https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
[4] https://github.com/env-logger-rs/env_logger
2021-04-01 17:24:34 -07:00
Ben Ludewig
db0c778ead Serialize u128 TraceId as LE bytes (#344) 2021-03-30 08:41:19 -07:00
Tim Kuehn
c3efb83ac1 Add more context to errors returned by serde transport 2021-03-28 20:03:03 -07:00
Tim Kuehn
3d7b0171fe Fix cargo fmt portion of pre-commit 2021-03-26 19:39:56 -07:00
oblique
c191ff5b2e Do not enable tokio-serde/json by default (#345) 2021-03-26 18:22:44 -07:00
Tim Kuehn
90bc7f741d Fix up imports 2021-03-17 12:44:39 -07:00
Kitsu
d3f6c01df2 Reduce required tokio features (#343)
* Move async tests behind cfg-ed mod
* Use explicit tokio features for the example
* Use only relative crate path for example dependency
2021-03-17 12:30:18 -07:00
Tim Kuehn
c6450521e6 Add method to run a future in the current context.
Previously, `Context::current` would always return a new context. Now,
it uses tokio task-local data to look for the current context. Tokio
task locals are not actually tied to a tokio executor; instead, they
provide data scoped to a future.

The basic pattern is:

```rust
let ctx = Context::new_root();
ctx.scope(async {
    let ctx2 = context::current();
    assert_eq!(ctx2.trace_context.span_id, ctx.trace_context.span_id);
});
```

`server::InFlightRequest::execute` uses `Context::scope` to set the
current context before executing a request, so calls to
`context::current` in request handlers will return the context provided
by the client. This does not propagate to new spawned tasks. To
propagate the client context to child tasks, the following pattern will
work:

```rust
tokio::spawn(context::current().scope(async { /* do work here */ }));
```

This commit also introduces a breaking change to Context serialization.
Previously, the deadline only serialized second-level precision. Now, it
provides full fidelity serialization to the nanosecond.
2021-03-13 16:05:02 -08:00
Tim Kuehn
1da6bcec57 Prepare v0.25 release 2021-03-10 20:00:25 -08:00
Seth Vargo
75a5591158 Improve Actions hygiene
👋 hello there! I'm a fellow Googler who works on projects that leverage GitHub Actions for CI/CD. Recently I noticed a large increase in our queue time, and I've tracked it down to the [limit of 180 concurrent jobs](https://docs.github.com/en/actions/reference/usage-limits-billing-and-administration) for an organization. To help be better citizens, I'm proposing changes across a few repositories that will reduce GitHub Actions hours and consumption. I hope these changes are reasonable and I'm happy to talk through them in more detail.

- Only run GitHub Actions for pushes and PRs against the main branch of the repository. If your team uses a forking model, this change will not affect you. If your team pushes branches to the repository directly, this changes actions to only run against the primary branches or if you open a Pull Request against a primary branch.

- For long-running jobs (especially tests), I added the "Cancel previous" workflow. This is very helpful to prevent a large queue backlog when you are doing rapid development and pushing multiple commits. Without this, GitHub Actions' default behavior is to run all actions on all commits.

There are other changes you could make, depending on your project (but I'm not an expert):

- If you have tests that should only run when a subset of code changes, consider gating your workflow to particular file paths. For example, we have some jobs that do Terraform linting, but [they only run when Terraform files are changed](c4f59fee71/.github/workflows/terraform.yml (L3-L11)).

Hopefully these changes are not too controversial and also hopefully you can see how this would reduce actions consumption to be good citizens to fellow Googlers. If you have any questions, feel free to respond here or ping me on chat. Thank you!
2021-03-10 17:31:13 -08:00
Tim Kuehn
9462aad3bf Improve test coverage of serde_transport 2021-03-10 12:37:55 -08:00
Tim Kuehn
0964fc51ff Add transport::channel::bounded.
This is like transport::channel::unbounded but with a fixed buffer size.
2021-03-08 23:10:12 -08:00
Tim Kuehn
27aacab432 Alternate polling expired and new requests.
Previously, there were two loops:

- Expired in-flight requests are polled until Pending.
- New requests are polled until Pending.

Now there is one loop that alternates between polling expired requests
and new requests. This way, neither type of action can face starvation.
2021-03-08 23:00:39 -08:00
Tim Kuehn
3feb465ad3 Clean up some server documentation. 2021-03-08 15:43:23 -08:00
Tim Kuehn
66cdc99ae0 Factor out ensure_writeable methods.
There is some important logic that is easy to overlook in the client and
server channels: streams of data to write to the transport should not be
polled until the transport is known to be ready to buffer a message. In
the case that a transport's buffer is full, it needs to be flushed to
make room for more messages.

Without this logic, start_send() could return an error when the buffer
is full, which would cause the entire Channel to error out.

Due to the importance of this logic, it's now factored out into its own
method that's easier to understand: fn ensure_writeable. There is one in
the client module and and one in the server module.
2021-03-08 11:36:20 -08:00
Tim Kuehn
66419db6fd Don't send a deadline-exceeded response.
The deadline-exceeded response was largely redundant, because the client
shouldn't normally be waiting for such a response, anyway -- the normal
client will automatically remove the in-flight request when it reaches
the deadline.

This also allows for internalizing the expiration+cleanup logic entirely
within BaseChannel, without having it leak into the Channel trait and
requiring action taken by the Requests struct.
2021-03-07 23:49:31 -08:00
Tim Kuehn
72d5dbba89 Cleanup wrap-up.
- Remove unnecessary Sync and Clone bounds.
- Merge client and client::channel modules.
- Run cargo clippy in the pre-push hook.
- Put DispatchResponse.cancellation in an Option.  Previously, the
  cancellation logic looked to see if `complete == true`, but it's a bit
  less error prone to put the Cancellation in an Option, so that the
  request can't accidentally be cancelled.
- Remove some unnecessary pins/projections.
- Clean up docs a bit. rustdoc had some warnings that are now gone.
2021-03-07 22:29:03 -08:00
Tim Kuehn
e75193c191 Client RPCs now take &self.
This required the breaking change of removing the Client trait. The
intent of the Client trait was to facilitate the decorator pattern by
allowing users to create their own Clients that added behavior on top of
the base client. Unfortunately, this trait had become a maintenance
burden, consistently causing issues with lifetimes and the lack of
generic associated types. Specifically, it meant that Client impls could
not use async fns, which is no longer tenable today.
2021-03-07 17:41:29 -08:00
Tim Kuehn
ce4fd49161 Centralize client-side request deadline handling.
Before this commit, each request future had its own timeout and would
communicate to the client Channel when a request was no longer being
listened to. Now, instead, the Channel tracks deadlines of in-flight
requests and completes requests with deadline-exceeded errors when they
expire.

This should be functionally equivalent to the previous way. It just cuts
down on the amount of two-way processing required. Unfortunately,
dropping a response future early still requires the client to send a
cancellation message to the Channel.
2021-03-07 15:31:17 -08:00
Tim Kuehn
3c978c5bf6 Handle deadlines in BaseChannel.
Before this commit, deadlines were handled by a timeout future that
wrapped each request handler. However, request handlers can be dropped
before sending a response back to the channel, so they can't be relied
on for channel state cleanup. Additionally, clients can't be relied on
to send cancellation messages. It was therefore theoretically possible
for pathological behaviors to cause an unbounded growth in orphan
request data in the Channel.

With this change, as long as requests sent have reasonable deadlines,
then the channel will be able to clean itself up. It is still possible
for requests to be sent with very large deadlines, which would prevent
the channel from cleaning itself up.
2021-03-07 04:05:33 -08:00
Tim Kuehn
6f419e9a9a Refactor server module to be easier to understand.
1. Renames

Some of the items in this module were renamed to be less generic:

- Handler => Incoming
- ClientHandler => Requests
- ResponseHandler => InFlightRequest
- Channel::{respond_with => requests}

In the case of Handler: handler of *what*? Now it's a bit clearer that
this is a stream of Channels (aka *incoming* connections).

Similarly, ClientHandler was a stream of requests over a single
connection. Hopefully Requests better reflects that.

ResponseHandler was renamed InFlightRequest because it no longer
contains the serving function. Instead, it is just the request, plus
the response channel and an abort hook. As a result of this,
Channel::respond_with underwent a big change: it used to take the
serving function and return a ClientHandler; now it has been renamed
Channel::requests and does not take any args.

2. Execute methods

All methods thats actually result in responses being generated
have been consolidated into methods named `execute`:

- InFlightRequest::execute returns a future that completes when a
  response has been generated and sent to the server Channel.
- Requests::execute automatically spawns response handlers for all
  requests over a single channel.
- Channel::execute is a convenience for `channel.requests().execute()`.
- Incoming::execute automatically spawns response handlers for all
  requests over all channels.

3. Removal of Server.

server::Server was removed, as it provided no value over the Incoming/Channel
abstractions. Additionally, server::new was removed, since it just
returned a Server.
2021-03-06 20:20:48 -08:00
Tim Kuehn
b3eb8d0b7a Move items in the rpc module to the top level.
The rpc module doesn't carry its weight. The whole darn project is RPC related!
2021-03-06 15:05:10 -08:00
Tim Kuehn
3b422eb179 Abort all in-flight requests when dropping BaseChannel.
Fixes #341
2021-01-24 17:57:44 -08:00
Michael Zimmermann
4b513bad73 fix clippy::needless_lifetimes
warning: explicit lifetimes given in parameter types where they could be elided (or replaced with `'_` if needed by type declaration)
   --> tarpc/src/rpc/server/filter.rs:127:5
    |
127 |     fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
    |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    |
    = note: `#[warn(clippy::needless_lifetimes)]` on by default
    = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_lifetimes

warning: 1 warning emitted
2021-01-20 23:27:50 -08:00
Michael Zimmermann
e71e17866d github actions: cargo-check mipsel-unknown-linux-gnu 2021-01-20 23:27:50 -08:00
Michael Zimmermann
7e3fbec077 example-service: set max frame length to usize::MAX
I don't know what the intention was behind using u32::MAX + 1 but since the
argument's type is usize this is the only giant value that makes sense to me.
2021-01-20 23:27:50 -08:00
Michael Zimmermann
e4bc5e8e32 use AtomicUsize instead of AtomicU64
- it's more portable (some architectures like MIPS don't support AtomicU64)
- for most 64bit architectures usize should be 64bit as well
- for most users even 32bit would probably be enough because:
  - it's tied to the connection(for streaming sockets)
  - the ID wraps and by the time that happens, all previous requests would have
    timed out unless you send a lot of requests and have a ton of RAM
2021-01-20 23:27:50 -08:00
Tim Kuehn
bc982c5584 Prepare release of v0.24.1 2020-12-28 15:42:11 -08:00
Logan Magee
d440e12c19 Bump tokio to 1.0 (#337)
Co-authored-by: Artem Vorotnikov <artem@vorotnikov.me>
2020-12-23 22:49:02 -08:00
Frederik-Baetens
bc8128af69 add serde derivation alias macro (#333) 2020-11-13 14:36:59 -08:00
Tim Kuehn
1d87c14262 Fix github actions config - take 3 2020-11-12 12:33:10 -08:00
Tim Kuehn
ca929c2178 Fix github actions config - take 2 2020-11-12 12:24:46 -08:00
Tim Kuehn
569039734b Fix github actions config 2020-11-12 12:13:10 -08:00
Tim Kuehn
3d43310e6a Make 'cargo test' succeed again 2020-11-12 11:59:39 -08:00
Tim Kuehn
d21cbddb0d Cargo test should pass without features enabled 2020-11-12 11:57:08 -08:00
Frederik-Baetens
25aa857edf Reexport/tokio serde (#332)
Re-export tokio_serde when the serde-transport feature is enabled.
2020-11-09 12:56:46 -08:00
Frederik-Baetens
0bb2e2bbbe re-export serde (#330)
* re-export serde
* make serde re-export dependent on serde1 feature flag
* update missing_async compile test case
2020-11-09 11:42:28 -08:00
chansuke
dc376343d6 Remove #[derive(Debug)] from library structs (#327)
* Remove `#[derive(Debug)]` from library structs
* Add manual debug impl for backward compatibility
2020-11-04 11:24:57 -08:00
Artem Vorotnikov
2e7d1f8a88 Bump dependencies (#328) 2020-10-31 09:43:40 -07:00
Tim Kuehn
6314591c65 Add tokio's macros feature to readme example's dependencies 2020-10-30 17:29:14 -07:00
Tim Kuehn
7dd7494420 Prepare v0.23.1 release 2020-10-29 18:54:35 -07:00
Tim Kuehn
6c10e3649f Fix tokio required features 2020-10-29 18:53:04 -07:00
Tim Kuehn
4c6dee13d2 cargo fmt 2020-10-29 00:44:15 -07:00
Bernardo Meurer
e45abe953a tarpc: enable tokio's time feature (#325) 2020-10-29 00:43:38 -07:00
Tim Kuehn
dec3e491b5 Fix unused import 2020-10-27 15:52:11 -07:00
Kitsu
6ce341cf79 Add example for custom transport usage (#322) 2020-10-23 14:28:26 -07:00
Tim Kuehn
b9868250f8 Prepare release of v0.23.0 2020-10-19 11:12:43 -07:00
Urhengulas
a3f1064efe Cargo.toml: Clean + update dependencies 2020-10-18 16:03:04 -07:00
Johann Hemmann
026083d653 Bump tokio from 0.2 to 0.3 (#319)
# Bump `tokio` from 0.2 to 0.3

* `Cargo.toml`:
    * bump `tokio` from 0.2 to 0.3
    * bump `tokio-util` from 0.3 to 0.4
    * remove feature `time` from `tokio`
    * fix alphabetical order of dependencies
* `tarpc::rpc`:
    * `client, server`: `tokio::time::Elapsed` -> `tokio::time::error::Elapsed`
    * `client, transport`, `::tests`: Fix `#[tokio::test]` macro usage
* `tarpc::serde_transport`:
    * `TcpListener.incoming().poll_next(...)` -> `TcpListener.poll_accept(...)`
      -> https://github.com/tokio-rs/tokio/discussions/2983
    * Adapt `AsyncRead`, `AsynWrite` implements in tests
* `README.md`, `tarpc::lib`: Adapt tokio version in docs

# Satisfy clippy

* replace `match`-statements with `matches!(...)`-macro
2020-10-17 17:33:08 -07:00
Tim Kuehn
d27f341bde Prepare release of v0.22.0 2020-08-19 18:35:36 -07:00
Tim Kuehn
2264ebecfc Remove serde_transport::tcp::connect_with.
Instead, serde_transport::tcp::connect returns a future named Connect
that has methods to directly access the framing config. This is
consistent with how serde_transport::tcp::listen returns a future with
methods to access the framing config. In addition to this consistency,
it reduces the API surface and provides a simpler user transition from
"zero config" to "some config".
2020-08-19 17:51:53 -07:00
Tim Kuehn
3207affb4a Update pre-commit for changes to cargo fmt.
--write-mode is now --check.
2020-08-19 17:51:20 -07:00
Andre B. Reis
0602afd50c Make connect() and connect_with() take a FnOnce for the codec (#315) 2020-08-19 16:15:26 -07:00
Tim Kuehn
4343e12217 Fix incorrect documentation 2020-08-18 02:58:11 -07:00
Tim Kuehn
7fda862fb8 Run cargo fmt 2020-08-18 02:55:24 -07:00
Tim Kuehn
aa7b875b1a Expose framing config in serde_transport. 2020-08-18 02:47:41 -07:00
Tim Kuehn
54d6e0e3b6 Add license headers 2020-08-04 17:33:41 -07:00
Tim Kuehn
bea3b442aa Move mod.rs files up one directory.
It's easier in IDEs if the files aren't all named the same.
2020-08-04 17:25:53 -07:00
Tim Kuehn
954a2502e7 Remove duplicate rustdoc 2020-08-02 22:24:09 -07:00
Tim Kuehn
e3f34917c5 Prepare v0.21.1 2020-08-02 21:34:13 -07:00
Tim Kuehn
f65dd05949 Enable documentation for optional features on docs.rs 2020-08-02 20:57:21 -07:00
Tim Kuehn
240c436b34 Ensure Context is Sync. 2020-08-01 14:01:07 -07:00
Tim Kuehn
c9803688cc Ensure Context is Send. 2020-08-01 13:49:25 -07:00
Tim Kuehn
4987094483 Compression example.
Follow-up work: some extension points would be useful allow enabling compression on a per-request basis.

Fixes https://github.com/google/tarpc/issues/200
2020-08-01 13:45:16 -07:00
Tim Kuehn
ff55080193 Minor refactor 2020-07-30 13:11:13 -07:00
Tim Kuehn
258193c932 PubSub example needs to populate the subscription topics. 2020-07-30 11:14:13 -07:00
Tim Kuehn
67823ef5de Get rid of sleeps in PubSub example. 2020-07-30 01:27:31 -07:00
Tim Kuehn
a671457243 Add topics to PubSub example 2020-07-29 22:51:04 -07:00
Tim Kuehn
cf654549da Add documentation to PubSub example. 2020-07-29 18:05:35 -07:00
Tim Kuehn
6a01e32a2d Shut down client dispatch immediately when read half of transport is closed.
Clients can't receive any responses when the read half is closed, which means they can't verify if their requests were served. Therefore, there is no point in writing further requests after the read half is closed.
2020-07-29 13:50:42 -07:00
Tim Kuehn
e6597fab03 Add some error context to client dispatch.
I'm taking this opportunity to experiment with anyhow. So far, results are promising. It was a bit hard to use with Poll<Option<Result<T, E>>> types, so I added a crate-internal helper trait for that.
2020-07-29 12:07:07 -07:00
Tim Kuehn
ebd245a93d Rewrite pubsub example to have the subscriber connect to the publisher.
Fixes https://github.com/google/tarpc/issues/313
2020-07-28 22:10:17 -07:00
Tim Kuehn
3ebc3b5845 Add accessor fns.
- ClientHandler::get_pin_channel
- BaseChannel::get_pin_ref
- serde_transport::Transport::get_ref
2020-07-28 21:27:36 -07:00
Tim Kuehn
0e5973109d Make docs.rs document feature-gated public items. 2020-07-28 19:43:43 -07:00
Tim Kuehn
5f02d7383a Add tests for correct diagnostic output from proc macro-generated compiler errors. 2020-07-27 01:17:06 -07:00
Tim Kuehn
2bae148529 Address clippy lints 2020-07-27 00:04:45 -07:00
Tim Kuehn
42a2e03aab Add better diagnostics for missing 'async' in impls using #[tarpc::server] 2020-07-26 23:47:48 -07:00
Tim Kuehn
b566d0c646 Use #[tarpc::server] in example-service 2020-07-26 18:26:41 -07:00
Jon Cinque
b359f16767 Add concurrent tests using join and join_all
These tests are essentially copies of the `concurrent` test,
specifically using `join` and `join_all`.  Note that for the `join_all`
example to work, all of the `Client` clones must be created before *any*
requests are added, otherwise there will be a lifetime problem with the
second request, saying that second client, `c2`, is still borrowed when
`req1` is dropped.  It would require a larger redesign to fix this
issue.
2020-07-24 09:51:05 -07:00
Greg Fitzgerald
f8681ab134 Migrate examples to tarpc::server 2020-07-22 14:03:23 -07:00
Tim Kuehn
7e521768ab Prepare for v0.21.0 release. 2020-06-26 20:05:02 -07:00
Tim Kuehn
e9b1e7d101 Use #[non_exhaustive] in lieu of _NonExhaustive enum variant. 2020-06-26 19:47:20 -07:00
Taiki Endo
f0322fb892 Remove uses of pin_project::project attribute
pin-project will deprecate the project attribute due to some unfixable
limitations.

Refs: https://github.com/taiki-e/pin-project/issues/225
2020-06-05 20:34:44 -07:00
Patrick Elsen
617daebb88 Add tarpc::server proc-macro as syntactic sugar for async methods. (#302)
The tarpc::server proc-macro can be used to annotate implementations of
services to turn async functions into the proper declarations needed
for tarpc to be able to call them.

This uses the assert_type_eq crate to check that the transformations
applied by the tarpc::server proc macro are correct and lead to code
that compiles.
2020-05-16 10:25:25 -07:00
Tim Kuehn
a11d4fff58 Remove raii_counter 2020-04-22 02:13:02 -07:00
Tim
bf42a04d83 Move the request timeout so that it surrounds the entire call, not just the response future. (#295)
* Move the request timeout so that it surrounds the entire call, not just the response future.

This will enable the timeout earlier, so that a backlog in the outbound request buffer can not cause requests to stall indefinitely.

* Run cargo fmt
2020-02-25 14:42:40 -08:00
Tim Kuehn
06528d6953 Fix clippy lint. 2019-12-19 12:28:26 -08:00
Tim Kuehn
9f00395746 Replace _non_exhaustive fields with #[non_exhaustive] attribute.
The attribute landed on stable rust (1.40.0) today.

Fixes https://github.com/google/tarpc/issues/275
2019-12-19 12:14:34 -08:00
Tim Kuehn
e0674cd57f Make pre-push run on rust stable. 2019-12-19 12:06:06 -08:00
Tim Kuehn
7e49bd9ee7 Clean up badges a bit. 2019-12-16 13:21:00 -08:00
Tim Kuehn
8a1baa9c4e Remove usage of unsafe in rpc::client::channel.
pin_project is actually able to handle the complexities of enum Futures.
2019-12-16 11:10:57 -08:00
Oleg Nosov
31c713d188 Allow raw identifiers + fixed naming + place all code generation methods in impl (#291)
Allows defining services using raw identifiers like:

```rust
pub mod service {
    #[tarpc::service]
    pub trait r#trait {
        async fn r#fn(x: i32) -> Result<u8, String>;
    }
}
```

Also:

- Refactored names (ident -> type)
- All code generation methods placed in impl
2019-12-12 10:13:57 -08:00
Tim Kuehn
d905bc1591 Prepare for tarpc release v0.20.0 2019-12-11 20:47:56 -08:00
Tim Kuehn
7f946c7f83 Make tokio a hard dependency.
Fixes #289
2019-12-11 20:08:36 -08:00
Tim Kuehn
36cfdb6c6f Fix tokio dependency for example-service 2019-12-11 20:01:06 -08:00
Tim Kuehn
dbabe9774f Clean up proc macro code to make clippy happy.
I made a separate TokenStream-returning fn for each item in the previously-huge quote block.
The service fn now primarily performs the duty of creating idents and joining all the TokenStreams.
2019-12-11 17:20:03 -08:00
Tim Kuehn
deb041b8d3 Replace travis-ci badge with github CI workflow badge 2019-12-11 12:54:56 -08:00
Oleg Nosov
85d49477f5 Updated and simplified macros (#290)
* syn updated to latest version
* quote updated to latest version
* proc-macro-2 updated to latest version
* Performance improvements
* Don't create unnecessary TokenStreams for output types
2019-12-11 12:28:24 -08:00
Tim Kuehn
45af6ccdeb Workaround for pubsub example hanging.
The publisher client isn't being dropped when the async fn returns. It
could potentially be something strange in the ThreadPool executor.
2019-12-07 22:01:41 -08:00
Tim Kuehn
917c0c5e2d Use tokio::time::delay_for in lieu of thread::sleep. 2019-12-07 21:28:45 -08:00
Artem Vorotnikov
bbbd43e282 Unify serde transports.
This PR obsoletes the JSON and Bincode transports and instead introduces a unified transport that
is generic over any tokio-serde serialization format as well as AsyncRead + AsyncWrite medium.
This comes with a slight hit for usability (having to manually specify the underlying transport
and codec), but it can be alleviated by making custom freestanding connect and listen fns.
2019-12-07 20:58:08 -08:00
Artem Vorotnikov
f945392b5a Use tokio/stream feature for json-transport 2019-12-07 09:54:33 -08:00
Artem Vorotnikov
f4060779e4 Add GitHub workflow 2019-12-05 20:13:14 -08:00
Artem Vorotnikov
7cc8d9640b Fix clippy warnings 2019-12-05 17:39:53 -08:00
Artem Vorotnikov
7f871f03ef Improve Travis configuration (#282)
* Improve Travis configuration

* Replace 0.0.0.0 with localhost in tests
2019-11-28 14:06:35 -08:00
Artem Vorotnikov
709b966150 Update to Tokio 0.2 and futures 0.3 (#277) 2019-11-27 19:53:44 -08:00
Artem Vorotnikov
5e19b79aa4 Unite most of tarpc into a single crate 2019-11-26 13:08:18 -08:00
Tim Kuehn
6eb806907a Replace Gitter badge with Discord badge. 2019-11-22 14:28:24 -08:00
Tim Kuehn
8250ca31ff Remove --no-default-features from pre-push hook.
It seemingly doesn't work at the root of a virtual workspace. Not sure if this is new behavior or just a new explicit error message.
2019-11-15 17:19:08 -08:00
Tim Kuehn
7cd776143b Fix typo 2019-11-15 17:12:00 -08:00
Artem Vorotnikov
5f6c3d7d98 Port to pin-project 2019-10-09 14:12:24 -07:00
Artem Vorotnikov
915fe3ed4e Use the JSON transport in examples 2019-10-08 19:18:49 -07:00
Artem Vorotnikov
d8c7b9feb2 JSON transport: use Tokio resolver for connect() 2019-10-08 18:03:25 -07:00
Artem Vorotnikov
5ab3866d96 Add Unpin note 2019-10-08 17:15:17 -07:00
Artem Vorotnikov
184ea42033 Upgrade json-transport to Tokio 0.2 2019-10-08 17:15:17 -07:00
Artem Vorotnikov
014c209b8e Do not serialize _non_exhaustive field 2019-10-03 13:09:26 -07:00
Artem Vorotnikov
e91005855c Remove remaining feature flags 2019-10-02 13:07:37 -07:00
Artem Vorotnikov
46bcc0f559 tokio 0.2.0-alpha.4 2019-08-30 09:29:18 -07:00
Artem Vorotnikov
61322ebf41 Clippy fixes 2019-08-29 11:34:38 -07:00
Artem Vorotnikov
db0c9c4182 Cut type_alias_impl_trait feature flag 2019-08-29 11:34:38 -07:00
Artem Vorotnikov
9ee3011687 Update to Tokio 0.3.0-alpha.3 2019-08-29 11:34:38 -07:00
Artem Vorotnikov
5aa4a2cef6 tokio 0.2.0-alpha.2 2019-08-19 23:13:06 -07:00
Artem Vorotnikov
f38a172523 Format code with rustfmt 2019-08-19 13:20:21 -07:00
Tim Kuehn
66dbca80b2 Add missing feature, "compat", back to json-transport dependency on futures-preview. 2019-08-14 09:16:44 -07:00
Tim
61377dd4ff Fix comment in example service
It referred to bincode instead of json.
2019-08-14 08:32:49 -07:00
Tim
cd03f3ff8c Don't mention 'static optional in readme
This isn't supported by the service attribute.
2019-08-13 08:49:11 -07:00
Tim Kuehn
9479963773 Don't enable serde1 by default. I forgot it gives bad compile errors to people who don't have serde in their Cargo.toml. 2019-08-09 01:21:31 -07:00
Tim Kuehn
f974533bf7 Use real crate names rather than internal aliases. It's less confusing for people reading examples. 2019-08-09 01:16:06 -07:00
Tim Kuehn
d560ac6197 Update to the latest rustc nightly. 2019-08-09 01:08:20 -07:00
Tim Kuehn
1cdff15412 Fix needless verbosity in readme 2019-08-09 00:50:06 -07:00
Tim Kuehn
f8ba7d9f4e Make tokio1 serde1 default features 2019-08-08 22:06:09 -07:00
Tim Kuehn
41c1aafaf7 Update tokio to v0.2.0-alpha.1
As part of this, I made an optional tokio feature which, when enabled,
adds utility functions that spawn on the default tokio executor. This
allows for the removal of the runtime crate.

On the one hand, this makes the spawning utils slightly less generic. On
the other hand:

- The fns are just helpers and are easily rewritten by the user.
- Tokio is the clear dominant futures executor, so most people will just
  use these versions.
2019-08-08 21:53:36 -07:00
Tim Kuehn
75d1e877be Update README to talk about deadlines a bit more precisely. 2019-08-08 20:45:37 -07:00
Tim Kuehn
88e1cf558b Generate README.md from cargo readme 2019-08-08 20:31:04 -07:00
Tim Kuehn
50879d2acb Don't bake in Send + 'static.
Send + 'static was baked in to make it possible to spawn futures onto
the default executor. We can accomplish the same thing by offering
helper fns that do the spawning while not requiring it for the rest of
the functionality.

Fixes https://github.com/google/tarpc/issues/212
2019-08-07 13:39:48 -07:00
Tim
13cb14a119 Merge pull request #248 from tikue/service-idents
With this change, the service definitions don't need to be isolated in their own modules.

Given:

```rust
#[tarpc::service]
trait World { ... }
```

Before this would generate the following items
------
- `trait World`
- `fn serve`
- `struct Client`
- `fn new_stub`

`// Implementation details below`
- `enum Request`
- `enum Response`
- `enum ResponseFut`

And now these items
------
- `trait World {    ...    fn serve }`
- `struct WorldClient ... impl WorldClient {    ...    async fn new }`

`// Implementation details below`
- `enum WorldRequest`
- `enum WorldResponse`
- `enum WorldResponseFut`
- `struct ServeWorld` (new manual closure impl because you can't use impl Trait in trait fns)
```
2019-08-05 12:23:35 -07:00
Tim Kuehn
22ef6b7800 Choose a slightly less obvious name for Serve impl.
To hopefully avoid most collisions.
2019-07-30 21:46:16 -07:00
Tim Kuehn
e48e6dfe67 Add nice error message for ident collisions 2019-07-30 21:31:22 -07:00
Tim Kuehn
1b58914d59 Move generated functions under their corresponding items.
- fn serve -> Service::serve
- fn new_stub -> Client::new

This allows the generated function names to remain consistent across
service definitions while preventing collisions.
2019-07-30 20:45:58 -07:00
Tim Kuehn
2f24842b2d Add service name to generated items.
With this change, the service definitions don't need to be isolated in their own modules.
2019-07-30 00:52:30 -07:00
Tim Kuehn
5c485fe608 Add some tests for snake to camel case conversion. 2019-07-30 00:52:30 -07:00
Tim Kuehn
b0319e7db9 Remove macros.rs 2019-07-30 00:51:29 -07:00
Tim Kuehn
a4d9581888 Remove service_registry example 2019-07-29 23:17:08 -07:00
Tim Kuehn
fb5022b1c0 cargo fmt 2019-07-29 22:08:53 -07:00
Tim Kuehn
abb0b5b3ac Rewrite to use proc_macro_attribute 2019-07-29 22:04:04 -07:00
Artem Vorotnikov
49f2641e3c Port to runtime crate 2019-07-29 08:36:06 -07:00
Tim
650c60fe44 Merge pull request #246 from google/rustfmt
Reformat all code using rustfmt
2019-07-22 17:53:48 -07:00
Artem Vorotnikov
1d0bbcb36c Reformat all code using rustfmt 2019-07-23 03:44:16 +03:00
Tim Kuehn
c456ad7fa5 Fix typo 2019-07-22 14:15:27 -07:00
Tim Kuehn
537446a5c9 Remove use of unstable feature 'arbitrary_self_types'.
Turns out, this actually wasn't needed, with some minor refactoring.
2019-07-19 00:48:59 -07:00
Tim Kuehn
94b5b2c431 Add tests for rpc/server/filter.rs 2019-07-16 21:48:11 -07:00
Tim Kuehn
9863433fea Remove unstable feature 'async_closure' 2019-07-16 11:17:18 -07:00
Tim Kuehn
9a27465a25 Remove use of unstable feature 'try_trait' 2019-07-16 11:08:53 -07:00
Tim Kuehn
263cfe1435 Remove unused unstable feature 'integer_atomics' 2019-07-16 10:27:59 -07:00
Tim
6ae5302a70 Merge pull request #240 from tikue/filter-refactor 2019-07-15 23:04:20 -07:00
Tim Kuehn
c67b7283e7 Move bench outside crate. 2019-07-15 22:43:58 -07:00
Tim Kuehn
7b6e98da7b Replace transport integration tests with unit tests.
I want 'cargo test' to run faster.
2019-07-15 22:40:58 -07:00
Tim Kuehn
15b65fa20f Replace usage of Once and unsafe code with once_cell crate. 2019-07-15 20:05:10 -07:00
Tim Kuehn
372900173a Merge origin/master => tikue/filter-refactor 2019-07-15 19:04:56 -07:00
Tim Kuehn
1089415451 Make server methods more composable.
-- Connection Limits

The problem with having ConnectionFilter default-enabled is elaborated on in https://github.com/google/tarpc/issues/217. The gist of it is not all servers want a policy based on `SocketAddr`. This PR allows customizing the behavior of ConnectionFilter, at the cost of not having it enabled by default. However, enabling it is as simple as one line:

incoming.max_channels_per_key(10, ip_addr)

The second argument is a key function that takes the user-chosen transport and returns some hashable, equatable, cloneable key. In the above example, it returns an `IpAddr`.

This also allows the `Transport` trait to have the addr fns removed, which means it has become simply an alias for `Stream + Sink`.

-- Per-Channel Request Throttling

With respect to Channel's throttling behavior, the same argument applies. There isn't a one size fits all solution to throttling requests, and the policy applied by tarpc is just one of potentially many solutions. As such, `Channel` is now a trait that offers a few combinators, one of which is throttling:

channel.max_concurrent_requests(10).respond_with(serve(Server))

This functionality is also available on the existing `Handler` trait, which applies it to all incoming channels and can be used in tandem with connection limits:

incoming
    .max_channels_per_key(10, ip_addr)
    .max_concurrent_requests_per_channel(10).respond_with(serve(Server))

-- Global Request Throttling

I've entirely removed the overall request limit enforced across all channels. This functionality is easily gotten back via [`StreamExt::buffer_unordered`](https://rust-lang-nursery.github.io/futures-api-docs/0.3.0-alpha.1/futures/stream/trait.StreamExt.html#method.buffer_unordered), with the difference being that the previous behavior allowed you to spawn channels onto different threads, whereas `buffer_unordered ` means the `Channels` are handled on a single thread (the per-request handlers are still spawned). Considering the existing options, I don't believe that the benefit provided by this functionality held its own.
2019-07-15 19:01:46 -07:00
Tim Kuehn
8dbeeff0eb Fix some lint warnings 2019-07-15 18:21:11 -07:00
iovxw
85312d430c Update to futures-preview 0.3.0-alpha.17 (#238)
* Update to futures-preview 0.3.0-alpha.17

* Update feature gate

async_closure was moved out from async_await
2019-07-15 18:20:19 -07:00
Adam Wright
9843af9e00 Reflow some text in the readme (#239) 2019-07-15 17:53:56 -07:00
Tim Kuehn
a6bd423ef0 Remove use of external crate 'libtest'. 2019-07-15 17:52:27 -07:00
Kevin Ji
146496d08c README: Use the SVG Travis badge (#236) 2019-06-08 10:31:08 -07:00
Tim Kuehn
b562051c38 Bump tarpc-lib to 0.6.1 to fix request cancellation issue. 2019-05-22 01:33:00 -07:00
Tim Kuehn
fe164ca368 Fix bug where expired request wasn't propagating cancellation.
DispatchResponse was incorrectly marking itself as complete even when
expiring without receiving a response. This can cause a chain of
deleterious effects:

- Request cancellation won't propagate when request timers expire.
- Which causes client dispatch to have an inconsistent in-flight request
  map containing stale IDs.
- Which can cause clients to hang rather than exiting.
2019-05-22 01:29:01 -07:00
Artem Vorotnikov
950ad5187c Add JSON transport (#219) 2019-05-20 18:45:41 -07:00
Tim Kuehn
e6ab69c314 Keep commented-out code in each block so that rustdoc is happy. 2019-05-15 16:31:11 -07:00
Tim Kuehn
373dcbed57 Clarify dependencies required for README example
Fixes https://github.com/google/tarpc/issues/232
2019-05-15 15:40:25 -07:00
Tim Kuehn
ce9c057b1b Remove await!() macro from readme 2019-05-13 10:16:25 -07:00
Tim Kuehn
6745cee72c Bump tarpc to v0.18.0 2019-05-11 13:00:35 -07:00
Artem Vorotnikov
31abea18b3 Update to futures-preview 0.3.0-alpha.16 (#230) 2019-05-11 15:18:52 -04:00
Tim Kuehn
593ac135ce Remove stable features from doc examples 2019-04-30 13:18:39 -07:00
Tim Kuehn
05a924d27f Bump tarpc version to 0.17.0 2019-04-30 13:01:45 -07:00
Artem Vorotnikov
af9d71ed0d Bump futures to 0.3.0-alpha.15 (#226) 2019-04-28 20:13:06 -07:00
Tim Kuehn
9b90f6ae51 Bump to v0.16.0 2019-04-16 10:46:53 -07:00
Tim
bbfc8ac352 Merge pull request #216 from vorot93/futures-master
* Use upstream sink compat shims
* Port to new Sink trait introduced in e101c891f04aba34ee29c6a8cd8321563c7e0161
* rustfmt
* Port to std::task::Context
* Add Google license header to bincode-transport/src/compat.rs
* Remove compat for it is no longer needed
* future::join as freestanding function
* Simplify dependencies
* Depend on futures-preview 0.3.0-alpha.14
* Fix infinite recursion
2019-04-16 08:43:10 -07:00
Tim
ad86a967ba Fix infinite recursion 2019-04-16 18:27:42 +03:00
Artem Vorotnikov
58a0eced19 Depend on futures-preview 0.3.0-alpha.14 2019-04-15 21:16:20 +03:00
Artem Vorotnikov
46fffd13e7 Simplify dependencies 2019-04-15 21:14:25 +03:00
Artem Vorotnikov
6c8d4be462 future::join as freestanding function 2019-04-15 20:30:04 +03:00
Artem Vorotnikov
e3a517bf0d Remove compat and transmute for they are no longer needed 2019-04-15 20:24:09 +03:00
Artem Vorotnikov
f4e22bdc2e Port to std::task::Context 2019-04-15 20:22:15 +03:00
Artem Vorotnikov
46f56fbdc0 Add Google license header to bincode-transport/src/compat.rs 2019-04-15 20:22:15 +03:00
Artem Vorotnikov
8665655592 Fix test client breakage by 9100ea46f997f24d4bc8c1764d0fe3ff8226ad2a 2019-04-15 20:22:15 +03:00
Artem Vorotnikov
4569d26d81 rustfmt 2019-04-15 20:22:15 +03:00
Artem Vorotnikov
b8b92ddb5f Workaround for stack overflow caused by 2a95710db0e2d85094938776ebb4f270bc389c41 2019-04-15 20:16:48 +03:00
Artem Vorotnikov
8dd3390876 Port to new Sink trait introduced in e101c891f04aba34ee29c6a8cd8321563c7e0161 2019-04-15 20:16:48 +03:00
Artem Vorotnikov
06c420b60c Use upstream sink compat shims 2019-04-15 20:16:48 +03:00
Artem Vorotnikov
a7fb4d22cc Switch to master branch of futures-preview 2019-04-15 20:16:48 +03:00
Tim
b1cd5f34e5 Don't panic in pump_write when a client is dropped and there are more calls to poll. (#221)
This can happen in cases where a response is being read and the client isn't around.

Fixes #220
2019-04-15 09:42:53 -07:00
Artem Vorotnikov
088e5f8f2c Remove deprecated feature from bincode dependency (#218) 2019-04-04 10:34:11 -07:00
Tim Kuehn
4e0be5b626 Publish tarpc v0.15.0 2019-03-26 21:13:41 -07:00
Artem Vorotnikov
5516034bbc Use libtest crate (#213) 2019-03-24 22:29:01 -07:00
Artem Vorotnikov
06544faa5a Update to futures 0.3.0-alpha.13 (#211) 2019-02-26 09:32:41 -08:00
Tim Kuehn
39737b720a Cargo fmt 2019-01-17 10:37:16 -08:00
Tim Kuehn
0f36985440 Update for latest changes to futures.
Fixes #209.
2019-01-17 10:37:03 -08:00
Tyler Bindon
959bb691cd Update regex to match diffs output by cargo fmt. (#208)
It appears the header of the diffs output by cargo fmt have changed. It now says "Diff in /blah/blah/blah.rs at line 99:" Matching on lines starting with + or - should be more future-proof against changes to the surroundings.
2018-12-09 01:59:35 -08:00
Tim
2a3162c5fa Cargo feature 'rename-dependency' is stabilized 2018-11-21 11:03:41 -08:00
Tim Kuehn
0cc976b729 cargo fmt 2018-11-06 17:01:27 -08:00
Tim Kuehn
4d2d3f24c6 Address Clippy lints 2018-11-06 17:00:15 -08:00
Tim Kuehn
2c7c64841f Add symlink tarpc/README.md -> README.md 2018-10-29 16:11:01 -07:00
Tim Kuehn
4ea142d0f3 Remove coverage badge.
It hasn't been updated in over 2 years.
2018-10-29 11:40:09 -07:00
Tim Kuehn
00751d2518 external_doc doesn't work with crates.io yet :( 2018-10-29 11:05:09 -07:00
Tim Kuehn
4394a52b65 Add doc tests to .travis.yml 2018-10-29 10:55:12 -07:00
Tim Kuehn
70938501d7 Use eternal_doc for tarpc package. This will ensure our README is always up-to-date. 2018-10-29 10:53:34 -07:00
Tim Kuehn
d5f5cf4300 Bump versions. 2018-10-29 10:43:41 -07:00
Tim Kuehn
e2c4164d8c Remove unused feature enablements from tarpc 2018-10-25 11:44:38 -07:00
Tim Kuehn
78124ef7a8 Cargo fmt 2018-10-25 11:44:18 -07:00
Tim Kuehn
096d354b7e Remove unused features 2018-10-25 11:41:08 -07:00
Tim
7ad0e4b070 Service registry (#204)
# Changes

## Client is now a trait
And `Channel<Req, Resp>` implements `Client<Req, Resp>`. Previously, `Client<Req, Resp>` was a thin wrapper around `Channel<Req, Resp>`.

This was changed to allow for mapping the request and response types. For example, you can take a `channel: Channel<Req, Resp>` and do:

```rust
channel
    .with_request(|req: Req2| -> Req { ... })
    .map_response(|resp: Resp| -> Resp2 { ... })
```

...which returns a type that implements `Client<Req2, Resp2>`.

### Why would you want to map request and response types?

The main benefit of this is that it enables creating different client types backed by the same channel. For example, you could run multiple clients multiplexing requests over a single `TcpStream`. I have a demo in `tarpc/examples/service_registry.rs` showing how you might do this with a bincode transport. I am considering factoring out the service registry portion of that to an actual library, because it's doing pretty cool stuff. For this PR, though, it'll just be part of the example.

## Client::new is now client::new

This is pretty minor, but necessary because async fns can't currently exist on traits. I changed `Server::new` to match this as well.

## Macro-generated Clients are generic over the backing Client.

This is a natural consequence of the above change. However, it is transparent to the user by keeping `Channel<Req, Resp>` as the default type for the `<C: Client>` type parameter. `new_stub` returns `Client<Channel<Req, Resp>>`, and other clients can be created via the `From` trait.

## example-service/ now has two binaries, one for client and one for server.

This serves as a "realistic" example of how one might set up a service. The other examples all run the client and server in the same binary, which isn't realistic in distributed systems use cases.

## `service!` trait fns take self by value.

Services are already cloned per request, so this just passes on that flexibility to the trait implementers.

# Open Questions

In the service registry example, multiple services are running on a single port, and thus multiple clients are sending requests over a single `TcpStream`. This has implications for throttling: [`max_in_flight_requests_per_connection`](https://github.com/google/tarpc/blob/master/rpc/src/server/mod.rs#L57-L60) will set a maximum for the sum of requests for all clients sharing a single connection. I think this is reasonable behavior, but users may expect this setting to act like `max_in_flight_requests_per_client`.

Fixes #103 #153 #205
2018-10-25 11:22:55 -07:00
Tim
64755d5329 Update futures 2018-10-19 11:19:25 -07:00
Tim Kuehn
3071422132 Helper fn to create transports 2018-10-18 00:24:26 -07:00
Tim Kuehn
8847330dbe impl From<S> for bincode::Transport<S> 2018-10-18 00:24:08 -07:00
Tim Kuehn
6d396520f4 Don't allow empty service invocations 2018-10-18 00:23:34 -07:00
Tim Kuehn
79a2f7fe2f Replace tokio-serde-bincode with async-bincode 2018-10-17 20:24:31 -07:00
Tim Kuehn
af66841f68 Remove keyword 2018-10-17 11:59:09 -07:00
Tim
1ab4cfdff9 Make Request and Resonse enums' docs public, because they show up in the serve fn. 2018-10-16 23:02:52 -07:00
Tim
f7e03eeeb7 Fix up readme 2018-10-16 22:28:57 -07:00
100 changed files with 9707 additions and 4562 deletions

67
.github/workflows/main.yml vendored Normal file
View File

@@ -0,0 +1,67 @@
on:
push:
branches:
- master
pull_request:
branches:
- master
name: Continuous integration
jobs:
check:
name: Check
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- run: cargo check --all-features
test:
name: Test Suite
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- run: cargo test
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde1
- run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde-transport
- run: cargo test --manifest-path tarpc/Cargo.toml --features tcp
- run: cargo test --all-features
fmt:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- run: cargo fmt --all -- --check
clippy:
name: Clippy
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
components: clippy
- run: cargo clippy --all-features -- -D warnings

View File

@@ -1,12 +0,0 @@
language: rust
rust:
- nightly
sudo: false
cache: cargo
os:
- osx
- linux
script:
- cargo test --all --all-features

View File

@@ -1,10 +1,11 @@
[workspace]
resolver = "2"
members = [
"example-service",
"rpc",
"trace",
"bincode-transport",
"tarpc",
"plugins",
]
[profile.dev]
split-debuginfo = "unpacked"

183
README.md
View File

@@ -1,9 +1,20 @@
## tarpc: Tim & Adam's RPC lib
[![Travis-CI Status](https://travis-ci.org/google/tarpc.png?branch=master)](https://travis-ci.org/google/tarpc)
[![Coverage Status](https://coveralls.io/repos/github/google/tarpc/badge.svg?branch=master)](https://coveralls.io/github/google/tarpc?branch=master)
[![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE)
[![Latest Version](https://img.shields.io/crates/v/tarpc.svg)](https://crates.io/crates/tarpc)
[![Join the chat at https://gitter.im/tarpc/Lobby](https://badges.gitter.im/tarpc/Lobby.svg)](https://gitter.im/tarpc/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Crates.io][crates-badge]][crates-url]
[![MIT licensed][mit-badge]][mit-url]
[![Build status][gh-actions-badge]][gh-actions-url]
[![Discord chat][discord-badge]][discord-url]
[crates-badge]: https://img.shields.io/crates/v/tarpc.svg
[crates-url]: https://crates.io/crates/tarpc
[mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg
[mit-url]: LICENSE
[gh-actions-badge]: https://github.com/google/tarpc/workflows/Continuous%20integration/badge.svg
[gh-actions-url]: https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22
[discord-badge]: https://img.shields.io/discord/647529123996237854.svg?logo=discord&style=flat-square
[discord-url]: https://discord.gg/gXwpdSt
# tarpc
<!-- cargo-sync-readme start -->
*Disclaimer*: This is not an official Google product.
@@ -26,113 +37,119 @@ architectures. Two well-known ones are [gRPC](http://www.grpc.io) and
tarpc differentiates itself from other RPC frameworks by defining the schema in code,
rather than in a separate language such as .proto. This means there's no separate compilation
process, and no cognitive context switching between different languages. Additionally, it
works with the community-backed library serde: any serde-serializable type can be used as
arguments to tarpc fns.
process, and no context switching between different languages.
Some other features of tarpc:
- Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
used as a transport to connect the client and server.
- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
- Cascading cancellation: dropping a request will send a cancellation message to the server.
The server will cease any unfinished work on the request, subsequently cancelling any of its
own requests, repeating for the entire chain of transitive dependencies.
- Configurable deadlines and deadline propagation: request deadlines default to 10s if
unspecified. The server will automatically cease work when the deadline has passed. Any
requests sent by the server that use the request context will propagate the request deadline.
For example, if a server is handling a request with a 10s deadline, does 2s of work, then
sends a request to another server, that server will see an 8s deadline.
- Distributed tracing: tarpc is instrumented with
[tracing](https://github.com/tokio-rs/tracing) primitives extended with
[OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
[Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
each RPC can be traced through the client, server, and other dependencies downstream of the
server. Even for applications not connected to a distributed tracing collector, the
instrumentation can also be ingested by regular loggers like
[env_logger](https://github.com/env-logger-rs/env_logger/).
- Serde serialization: enabling the `serde1` Cargo feature will make service requests and
responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
## Usage
**NB**: *this example is for master. Are you looking for other
[versions](https://docs.rs/tarpc)?*
Add to your `Cargo.toml` dependencies:
```toml
tarpc = "0.12.0"
tarpc-plugins = "0.4.0"
tarpc = "0.34"
```
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
These generated types make it easy and ergonomic to write servers with less boilerplate.
Simply implement the generated service trait, and you're off to the races!
The `service!` macro expands to a collection of items that form an
rpc service. In the above example, the macro is called within the
`hello_service` module. This module will contain a `Client` stub and `Service` trait. There is
These generated types make it easy and ergonomic to write servers without dealing with serialization
directly. Simply implement one of the generated traits, and you're off to the
races!
## Example
## Example:
This example uses [tokio](https://tokio.rs), so add the following dependencies to
your `Cargo.toml`:
Here's a small service.
```toml
anyhow = "1.0"
futures = "0.3"
tarpc = { version = "0.31", features = ["tokio1"] }
tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] }
```
In the following example, we use an in-process channel for communication between
client and server. In real code, you will likely communicate over the network.
For a more real-world example, see [example-service](example-service).
First, let's set up the dependencies and service definition.
```rust
#![feature(plugin, futures_api, pin, arbitrary_self_types, await_macro, async_await)]
#![plugin(tarpc_plugins)]
use futures::{
compat::TokioDefaultSpawner,
future::{self, Ready},
prelude::*,
spawn,
};
use tarpc::rpc::{
use futures::future::{self, Ready};
use tarpc::{
client, context,
server::{self, Handler, Server},
server::{self, Channel},
};
use std::io;
// This is the service definition. It looks a lot like a trait definition.
// It defines one RPC, hello, which takes one arg, name, and returns a String.
tarpc::service! {
rpc hello(name: String) -> String;
#[tarpc::service]
trait World {
/// Returns a greeting for name.
async fn hello(name: String) -> String;
}
```
// This is the type that implements the generated Service trait. It is the business logic
This service definition generates a trait called `World`. Next we need to
implement it for our Server struct.
```rust
// This is the type that implements the generated World trait. It is the business logic
// and is used to start the server.
#[derive(Clone)]
struct HelloServer;
impl Service for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
// an associated type representing the future output by the fn.
type HelloFut = Ready<String>;
fn hello(&self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
```
async fn run() -> io::Result<()> {
// bincode_transport is provided by the associated crate bincode-transport. It makes it easy
// to start up a serde-powered bincode serialization strategy over TCP.
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = transport.local_addr();
Lastly let's write our `main` that will start the server. While this example uses an
[in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`]
behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
available behind the `tcp` feature.
// The server is configured with the defaults.
let server = Server::new(server::Config::default())
// Server can listen on any type that implements the Transport trait.
.incoming(transport)
// Close the stream after the client connects
.take(1)
// serve is generated by the service! macro. It takes as input any type implementing
// the generated Service trait.
.respond_with(serve(HelloServer));
```rust
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
spawn!(server).unwrap();
let server = server::BaseChannel::with_defaults(server_transport);
tokio::spawn(server.execute(HelloServer.serve()));
let transport = await!(bincode_transport::connect(&addr))?;
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
// that takes a config and any Transport as input.
let client = WorldClient::new(client::Config::default(), client_transport).spawn();
// new_stub is generated by the service! macro. Like Server, it takes a config and any
// Transport as input, and returns a Client, also generated by the macro.
// by the service mcro.
let mut client = await!(new_stub(client::Config::default(), transport));
// The client has an RPC method for each RPC defined in service!. It takes the same args
// as defined, with the addition of a Context, which is always the first arg. The Context
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = await!(client.hello(context::current(), "Stim".to_string()))?;
let hello = client.hello(context::current(), "Stim".to_string()).await?;
println!("{}", hello);
println!("{hello}");
Ok(())
}
fn main() {
tokio::run(run()
.map_err(|e| eprintln!("Oh no: {}", e))
.boxed()
.compat(TokioDefaultSpawner),
);
}
```
## Service Documentation
@@ -140,12 +157,6 @@ fn main() {
Use `cargo doc` as you normally would to see the documentation created for all
items expanded by a `service!` invocation.
## Contributing
<!-- cargo-sync-readme end -->
To contribute to tarpc, please see [CONTRIBUTING](CONTRIBUTING.md).
## License
tarpc is distributed under the terms of the MIT license.
See [LICENSE](LICENSE) for details.
License: MIT

View File

@@ -1,6 +1,401 @@
## 0.34.0 (2023-12-29)
### Breaking Changes
- `#[tarpc::server]` is no more! Service traits now use async fns.
- `Channel::execute` no longer spawns request handlers. Async-fn-in-traits makes it impossible to
add a Send bound to the future returned by `Serve::serve`. Instead, `Channel::execute` returns a
stream of futures, where each future is a request handler. To achieve the former behavior:
```rust
channel.execute(server.serve())
.for_each(|rpc| { tokio::spawn(rpc); })
```
### New Features
- Request hooks are added to the serve trait, so that it's easy to hook in cross-cutting
functionality like throttling, authorization, etc.
- The Client trait is back! This makes it possible to hook in generic client functionality like load
balancing, retries, etc.
## 0.33.0 (2023-04-01)
### Breaking Changes
Opentelemetry dependency version increased to 0.18.
## 0.32.0 (2023-03-24)
### Breaking Changes
- As part of a fix to return more channel errors in RPC results, a few error types have changed:
0. `client::RpcError::Disconnected` was split into the following errors:
- Shutdown: the client was shutdown, either intentionally or due to an error. If due to an
error, pending RPCs should see the more specific errors below.
- Send: an RPC message failed to send over the transport. Only the RPC that failed to be sent
will see this error.
- Receive: a fatal error occurred while receiving from the transport. All in-flight RPCs will
receive this error.
0. `client::ChannelError` and `server::ChannelError` are unified in `tarpc::ChannelError`.
Previously, server transport errors would not indicate during which activity the transport
error occurred. Now, just like the client already was, it will be specific: reading, readying,
sending, flushing, or closing.
## 0.31.0 (2022-11-03)
### New Features
This release adds Unix Domain Sockets to the `serde_transport` module.
To use it, enable the "unix" feature. See the docs for more information.
## 0.30.0 (2022-08-12)
### Breaking Changes
- Some types that impl Future are now annotated with `#[must_use]`. Code that previously created
these types but did not use them will now receive a warning. Code that disallows warnings will
receive a compilation error.
### Fixes
- Servers will more reliably clean up request state for requests with long deadlines when response
processing is aborted without sending a response.
### Other Changes
- `TrackedRequest` now contains a response guard that can be used to ensure state cleanup for
aborted requests. (This was already handled automatically by `InFlightRequests`).
- When the feature serde-transport is enabled, the crate tokio_serde is now re-exported.
## 0.29.0 (2022-05-26)
### Breaking Changes
`Context.deadline` is now serialized as a Duration. This prevents clock skew from affecting deadline
behavior. For more details see https://github.com/google/tarpc/pull/367 and its [related
issue](https://github.com/google/tarpc/issues/366).
## 0.28.0 (2022-04-06)
### Breaking Changes
- The minimum supported Rust version has increased to 1.58.0.
- The version of opentelemetry depended on by tarpc has increased to 0.17.0.
## 0.27.2 (2021-10-08)
### Fixes
Clients will now close their transport before dropping it. An attempt at a clean shutdown can help
the server drop its connections more quickly.
## 0.27.1 (2021-09-22)
### Breaking Changes
#### RPC error type is changing
RPC return types are changing from `Result<Response, io::Error>` to `Result<Response,
tarpc::client::RpcError>`.
Becaue tarpc is a library, not an application, it should strive to
use structured errors in its API so that users have maximal flexibility
in how they handle errors. io::Error makes that hard, because it is a
kitchen-sink error type.
RPCs in particular only have 3 classes of errors:
- The connection breaks.
- The request expires.
- The server decides not to process the request.
RPC responses can also contain application-specific errors, but from the
perspective of the RPC library, those are opaque to the framework, classified
as successful responsees.
### Open Telemetry
The Opentelemetry dependency is updated to version 0.16.x.
## 0.27.0 (2021-09-22)
This version was yanked due to tarpc-plugins version mismatches.
## 0.26.0 (2021-04-14)
### New Features
#### Tracing
tarpc is now instrumented with tracing primitives extended with
OpenTelemetry traces. Using a compatible tracing-opentelemetry
subscriber like Jaeger, each RPC can be traced through the client,
server, amd other dependencies downstream of the server. Even for
applications not connected to a distributed tracing collector, the
instrumentation can also be ingested by regular loggers like env_logger.
### Breaking Changes
#### Logging
Logged events are now structured using tracing. For applications using a
logger and not a tracing subscriber, these logs may look different or
contain information in a less consumable manner. The easiest solution is
to add a tracing subscriber that logs to stdout, such as
tracing_subscriber::fmt.
#### Context
- Context no longer has parent_span, which was actually never needed,
because the context sent in an RPC is inherently the parent context.
For purposes of distributed tracing, the client side of the RPC has all
necessary information to link the span to its parent; the server side
need do nothing more than export the (trace ID, span ID) tuple.
- Context has a new field, SamplingDecision, which has two variants,
Sampled and Unsampled. This field can be used by downstream systems to
determine whether a trace needs to be exported. If the parent span is
sampled, the expectation is that all child spans be exported, as well;
to do otherwise could result in lossy traces being exported. Note that
if an Openetelemetry tracing subscriber is not installed, the fallback
context will still be used, but the Context's sampling decision will
always be inherited by the parent Context's sampling decision.
- Context::scope has been removed. Context propagation is now done via
tracing's task-local spans. Spans can be propagated across tasks via
Span::in_scope. When a service receives a request, it attaches an
Opentelemetry context to the local Span created before request handling,
and this context contains the request deadline. This span-local deadline
is retrieved by Context::current, but it cannot be modified so that
future Context::current calls contain a different deadline. However, the
deadline in the context passed into an RPC call will override it, so
users can retrieve the current context and then modify the deadline
field, as has been historically possible.
- Context propgation precedence changes: when an RPC is initiated, the
current Span's Opentelemetry context takes precedence over the trace
context passed into the RPC method. If there is no current Span, then
the trace context argument is used as it has been historically. Note
that Opentelemetry context propagation requires an Opentelemetry
tracing subscriber to be installed.
#### Server
- The server::Channel trait now has an additional required associated
type and method which returns the underlying transport. This makes it
more ergonomic for users to retrieve transport-specific information,
like IP Address. BaseChannel implements Channel::transport by returning
the underlying transport, and channel decorators like Throttler just
delegate to the Channel::transport method of the wrapped channel.
#### Client
- NewClient::spawn no longer returns a result, as spawn can't fail.
### References
1. https://github.com/tokio-rs/tracing
2. https://opentelemetry.io
3. https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
4. https://github.com/env-logger-rs/env_logger
## 0.25.0 (2021-03-10)
### Breaking Changes
#### Major server module refactoring
1. Renames
Some of the items in this module were renamed to be less generic:
- Handler => Incoming
- ClientHandler => Requests
- ResponseHandler => InFlightRequest
- Channel::{respond_with => requests}
In the case of Handler: handler of *what*? Now it's a bit clearer that this is a stream of Channels
(aka *incoming* connections).
Similarly, ClientHandler was a stream of requests over a single connection. Hopefully Requests
better reflects that.
ResponseHandler was renamed InFlightRequest because it no longer contains the serving function.
Instead, it is just the request, plus the response channel and an abort hook. As a result of this,
Channel::respond_with underwent a big change: it used to take the serving function and return a
ClientHandler; now it has been renamed Channel::requests and does not take any args.
2. Execute methods
All methods thats actually result in responses being generated have been consolidated into methods
named `execute`:
- InFlightRequest::execute returns a future that completes when a response has been generated and
sent to the server Channel.
- Requests::execute automatically spawns response handlers for all requests over a single channel.
- Channel::execute is a convenience for `channel.requests().execute()`.
- Incoming::execute automatically spawns response handlers for all requests over all channels.
3. Removal of Server.
server::Server was removed, as it provided no value over the Incoming/Channel abstractions.
Additionally, server::new was removed, since it just returned a Server.
#### Client RPC methods now take &self
This required the breaking change of removing the Client trait. The intent of the Client trait was
to facilitate the decorator pattern by allowing users to create their own Clients that added
behavior on top of the base client. Unfortunately, this trait had become a maintenance burden,
consistently causing issues with lifetimes and the lack of generic associated types. Specifically,
it meant that Client impls could not use async fns, which is no longer tenable today, with channel
libraries moving to async fns.
#### Servers no longer send deadline-exceed responses.
The deadline-exceeded response was largely redundant, because the client
shouldn't normally be waiting for such a response, anyway -- the normal
client will automatically remove the in-flight request when it reaches
the deadline.
This also allows for internalizing the expiration+cleanup logic entirely
within BaseChannel, without having it leak into the Channel trait and
requiring action taken by the Requests struct.
#### Clients no longer send cancel messages when the request deadline is exceeded.
The server already knows when the request deadline was exceeded, so the client didn't need to inform
it.
### Fixes
- When a channel is dropped, all in-flight requests for that channel are now aborted.
## 0.24.1 (2020-12-28)
### Breaking Changes
Upgrades tokio to 1.0.
## 0.24.0 (2020-12-28)
This release was yanked.
## 0.23.0 (2020-10-19)
### Breaking Changes
Upgrades tokio to 0.3.
## 0.22.0 (2020-08-02)
This release adds some flexibility and consistency to `serde_transport`, with one new feature and
one small breaking change.
### New Features
`serde_transport::tcp` now exposes framing configuration on `connect()` and `listen()`. This is
useful if, for instance, you want to send requests or responses that are larger than the maximum
payload allowed by default:
```rust
let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
transport.config_mut().max_frame_length(4294967296);
let mut client = MyClient::new(client::Config::default(), transport.await?).spawn()?;
```
### Breaking Changes
The codec argument to `serde_transport::tcp::connect` changed from a Codec to impl Fn() -> Codec,
to be consistent with `serde_transport::tcp::listen`. While only one Codec is needed, more than one
person has been tripped up by the inconsistency between `connect` and `listen`. Unfortunately, the
compiler errors are not much help in this case, so it was decided to simply do the more intuitive
thing so that the compiler doesn't need to step in in the first place.
## 0.21.1 (2020-08-02)
### New Features
#### #[tarpc::server] diagnostics
When a service impl uses #[tarpc::server], only `async fn`s are re-written. This can lead to
confusing compiler errors about missing associated types:
```
error: not all trait items implemented, missing: `HelloFut`
--> $DIR/tarpc_server_missing_async.rs:9:1
|
9 | impl World for HelloServer {
| ^^^^
```
The proc macro now provides better diagnostics for this case:
```
error: not all trait items implemented, missing: `HelloFut`
--> $DIR/tarpc_server_missing_async.rs:9:1
|
9 | impl World for HelloServer {
| ^^^^
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
--> $DIR/tarpc_server_missing_async.rs:10:5
|
10 | fn hello(name: String) -> String {
| ^^
```
### Bug Fixes
#### Fixed client hanging when server shuts down
Previously, clients would ignore when the read half of the transport was closed, continuing to
write requests. This didn't make much sense, because without the ability to receive responses,
clients have no way to know if requests were actually processed by the server. It basically just
led to clients that would hang for a few seconds before shutting down. This has now been
corrected: clients will immediately shut down when the read-half of the transport is closed.
#### More docs.rs documentation
Previously, docs.rs only documented items enabled by default, notably leaving out documentation
for tokio and serde features. This has now been corrected: docs.rs should have documentation
for all optional features.
## 0.21.0 (2020-06-26)
### New Features
A new proc macro, `#[tarpc::server]` was added! This enables service impls to elide the boilerplate
of specifying associated types for each RPC. With the ubiquity of async-await, most code won't have
nameable futures and will just be boxing the return type anyway. This macro does that for you.
### Breaking Changes
- Enums had `_non_exhaustive` fields replaced with the #[non_exhaustive] attribute.
### Bug Fixes
- https://github.com/google/tarpc/issues/304
A race condition in code that limits number of connections per client caused occasional panics.
- https://github.com/google/tarpc/pull/295
Made request timeouts account for time spent in the outbound buffer. Previously, a large outbound
queue would lead to requests not timing out correctly.
## 0.20.0 (2019-12-11)
### Breaking Changes
1. tarpc has updated its tokio dependency to the latest 0.2 version.
2. The tarpc crates have been unified into just `tarpc`, with new Cargo features to enable
functionality.
- The bincode-transport and json-transport crates are deprecated and superseded by
the `serde_transport` module, which unifies much of the logic present in both crates.
## 0.13.0 (2018-10-16)
### Breaking Changes
### Breaking Changes
Version 0.13 marks a significant departure from previous versions of tarpc. The
API has changed significantly. The tokio-proto crate has been torn out and

View File

@@ -1,43 +0,0 @@
cargo-features = ["rename-dependency"]
[package]
name = "tarpc-bincode-transport"
version = "0.1.0"
authors = ["Tim Kuehn <tikue@google.com>"]
edition = '2018'
license = "MIT"
documentation = "https://docs.rs/tarpc-bincode-transport"
homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "bincode", "serde", "tarpc"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
description = "A bincode-based transport for tarpc services."
[dependencies]
bincode = { version = "1.0", features = ["i128"] }
bytes = "0.4"
futures_legacy = { version = "0.1", package = "futures" }
pin-utils = "0.1.0-alpha.2"
rpc = { package = "tarpc-lib", version = "0.1", path = "../rpc", features = ["serde1"] }
serde = "1.0"
tokio = "0.1"
tokio-io = "0.1"
tokio-serde-bincode = "0.1"
tokio-tcp = "0.1"
tokio-serde = "0.2"
[target.'cfg(not(test))'.dependencies]
futures-preview = { version = "0.3.0-alpha.8", features = ["compat"] }
[dev-dependencies]
futures-preview = { version = "0.3.0-alpha.8", features = ["compat", "tokio-compat"] }
env_logger = "0.5"
humantime = "1.0"
log = "0.4"
rand = "0.5"
tokio = "0.1"
tokio-executor = "0.1"
tokio-reactor = "0.1"
tokio-serde = "0.2"
tokio-timer = "0.2"

View File

@@ -1 +0,0 @@
edition = "Edition2018"

View File

@@ -1,286 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! A TCP [`Transport`] that serializes as bincode.
#![feature(
futures_api,
pin,
arbitrary_self_types,
underscore_imports,
await_macro,
async_await,
)]
#![deny(missing_docs, missing_debug_implementations)]
mod vendored;
use bytes::{Bytes, BytesMut};
use crate::vendored::tokio_serde_bincode::{IoErrorWrapper, ReadBincode, WriteBincode};
use futures::{
Poll,
compat::{Compat01As03, Future01CompatExt, Stream01CompatExt},
prelude::*,
ready, task,
};
use futures_legacy::{
executor::{
self as executor01, Notify as Notify01, NotifyHandle as NotifyHandle01,
UnsafeNotify as UnsafeNotify01,
},
sink::SinkMapErr as SinkMapErr01,
sink::With as With01,
stream::MapErr as MapErr01,
Async as Async01, AsyncSink as AsyncSink01, Sink as Sink01, Stream as Stream01,
};
use pin_utils::unsafe_pinned;
use serde::{Deserialize, Serialize};
use std::{fmt, io, marker::PhantomData, net::SocketAddr, pin::Pin, task::LocalWaker};
use tokio::codec::{Framed, LengthDelimitedCodec, length_delimited};
use tokio_tcp::{self, TcpListener, TcpStream};
/// Returns a new bincode transport that reads from and writes to `io`.
pub fn new<Item, SinkItem>(io: TcpStream) -> Transport<Item, SinkItem>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
let peer_addr = io.peer_addr();
let local_addr = io.local_addr();
let inner = length_delimited::Builder::new()
.max_frame_length(8_000_000)
.new_framed(io)
.map_err(IoErrorWrapper as _)
.sink_map_err(IoErrorWrapper as _)
.with(freeze as _);
let inner = WriteBincode::new(inner);
let inner = ReadBincode::new(inner);
Transport {
inner,
staged_item: None,
peer_addr,
local_addr,
}
}
fn freeze(bytes: BytesMut) -> Result<Bytes, IoErrorWrapper> {
Ok(bytes.freeze())
}
/// Connects to `addr`, wrapping the connection in a bincode transport.
pub async fn connect<Item, SinkItem>(addr: &SocketAddr) -> io::Result<Transport<Item, SinkItem>>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
let stream = await!(TcpStream::connect(addr).compat())?;
Ok(new(stream))
}
/// Listens on `addr`, wrapping accepted connections in bincode transports.
pub fn listen<Item, SinkItem>(addr: &SocketAddr) -> io::Result<Incoming<Item, SinkItem>>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
let listener = TcpListener::bind(addr)?;
let local_addr = listener.local_addr()?;
let incoming = listener.incoming().compat();
Ok(Incoming {
incoming,
local_addr,
ghost: PhantomData,
})
}
/// A [`TcpListener`] that wraps connections in bincode transports.
#[derive(Debug)]
pub struct Incoming<Item, SinkItem> {
incoming: Compat01As03<tokio_tcp::Incoming>,
local_addr: SocketAddr,
ghost: PhantomData<(Item, SinkItem)>,
}
impl<Item, SinkItem> Incoming<Item, SinkItem> {
unsafe_pinned!(incoming: Compat01As03<tokio_tcp::Incoming>);
/// Returns the address being listened on.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
}
impl<Item, SinkItem> Stream for Incoming<Item, SinkItem>
where
Item: for<'a> Deserialize<'a>,
SinkItem: Serialize,
{
type Item = io::Result<Transport<Item, SinkItem>>;
fn poll_next(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<Self::Item>> {
let next = ready!(self.incoming().poll_next(waker)?);
Poll::Ready(next.map(|conn| Ok(new(conn))))
}
}
/// A transport that serializes to, and deserializes from, a [`TcpStream`].
pub struct Transport<Item, SinkItem> {
inner: ReadBincode<
WriteBincode<
With01<
SinkMapErr01<
MapErr01<
Framed<tokio_tcp::TcpStream, LengthDelimitedCodec>,
fn(std::io::Error) -> IoErrorWrapper,
>,
fn(std::io::Error) -> IoErrorWrapper,
>,
BytesMut,
fn(BytesMut) -> Result<Bytes, IoErrorWrapper>,
Result<Bytes, IoErrorWrapper>
>,
SinkItem,
>,
Item,
>,
staged_item: Option<SinkItem>,
peer_addr: io::Result<SocketAddr>,
local_addr: io::Result<SocketAddr>,
}
impl<Item, SinkItem> fmt::Debug for Transport<Item, SinkItem> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Transport")
}
}
impl<Item, SinkItem> Stream for Transport<Item, SinkItem>
where
Item: for<'a> Deserialize<'a>,
{
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<io::Result<Item>>> {
unsafe {
let inner = &mut Pin::get_mut_unchecked(self).inner;
let mut compat = inner.compat();
let compat = Pin::new_unchecked(&mut compat);
match ready!(compat.poll_next(waker)) {
None => Poll::Ready(None),
Some(Ok(next)) => Poll::Ready(Some(Ok(next))),
Some(Err(e)) => Poll::Ready(Some(Err(e.0))),
}
}
}
}
impl<Item, SinkItem> Sink for Transport<Item, SinkItem>
where
SinkItem: Serialize,
{
type SinkItem = SinkItem;
type SinkError = io::Error;
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
let me = unsafe { Pin::get_mut_unchecked(self) };
assert!(me.staged_item.is_none());
me.staged_item = Some(item);
Ok(())
}
fn poll_ready(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<()>> {
let notify = &WakerToHandle(waker);
executor01::with_notify(notify, 0, move || {
let me = unsafe { Pin::get_mut_unchecked(self) };
match me.staged_item.take() {
Some(staged_item) => match me.inner.start_send(staged_item)? {
AsyncSink01::Ready => Poll::Ready(Ok(())),
AsyncSink01::NotReady(item) => {
me.staged_item = Some(item);
Poll::Pending
}
},
None => Poll::Ready(Ok(())),
}
})
}
fn poll_flush(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<()>> {
let notify = &WakerToHandle(waker);
executor01::with_notify(notify, 0, move || {
let me = unsafe { Pin::get_mut_unchecked(self) };
match me.inner.poll_complete()? {
Async01::Ready(()) => Poll::Ready(Ok(())),
Async01::NotReady => Poll::Pending,
}
})
}
fn poll_close(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<()>> {
let notify = &WakerToHandle(waker);
executor01::with_notify(notify, 0, move || {
let me = unsafe { Pin::get_mut_unchecked(self) };
match me.inner.get_mut().close()? {
Async01::Ready(()) => Poll::Ready(Ok(())),
Async01::NotReady => Poll::Pending,
}
})
}
}
impl<Item, SinkItem> rpc::Transport for Transport<Item, SinkItem>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
type Item = Item;
type SinkItem = SinkItem;
fn peer_addr(&self) -> io::Result<SocketAddr> {
// TODO: should just access from the inner transport.
// https://github.com/alexcrichton/tokio-serde-bincode/issues/4
Ok(*self.peer_addr.as_ref().unwrap())
}
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(*self.local_addr.as_ref().unwrap())
}
}
#[derive(Clone, Debug)]
struct WakerToHandle<'a>(&'a LocalWaker);
#[derive(Debug)]
struct NotifyWaker(task::Waker);
impl Notify01 for NotifyWaker {
fn notify(&self, _: usize) {
self.0.wake();
}
}
unsafe impl UnsafeNotify01 for NotifyWaker {
unsafe fn clone_raw(&self) -> NotifyHandle01 {
let ptr = Box::new(NotifyWaker(self.0.clone()));
NotifyHandle01::new(Box::into_raw(ptr))
}
unsafe fn drop_raw(&self) {
let ptr: *const dyn UnsafeNotify01 = self;
drop(Box::from_raw(ptr as *mut dyn UnsafeNotify01));
}
}
impl<'a> From<WakerToHandle<'a>> for NotifyHandle01 {
fn from(handle: WakerToHandle<'a>) -> NotifyHandle01 {
unsafe { NotifyWaker(handle.0.clone().into_waker()).clone_raw() }
}
}

View File

@@ -1,7 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
pub(crate) mod tokio_serde_bincode;

View File

@@ -1,224 +0,0 @@
//! `Stream` and `Sink` adaptors for serializing and deserializing values using
//! Bincode.
//!
//! This crate provides adaptors for going from a stream or sink of buffers
//! ([`Bytes`]) to a stream or sink of values by performing Bincode encoding or
//! decoding. It is expected that each yielded buffer contains a single
//! serialized Bincode value. The specific strategy by which this is done is left
//! up to the user. One option is to use using [`length_delimited`] from
//! [tokio-io].
//!
//! [`Bytes`]: https://docs.rs/bytes/0.4/bytes/struct.Bytes.html
//! [`length_delimited`]: http://alexcrichton.com/tokio-io/tokio_io/codec/length_delimited/index.html
//! [tokio-io]: http://github.com/alexcrichton/tokio-io
//! [examples]: https://github.com/carllerche/tokio-serde-json/tree/master/examples
#![allow(missing_debug_implementations)]
use bincode::Error;
use bytes::{Bytes, BytesMut};
use futures_legacy::{Poll, Sink, StartSend, Stream};
use serde::{Deserialize, Serialize};
use std::io;
use tokio_serde::{Deserializer, FramedRead, FramedWrite, Serializer};
use std::marker::PhantomData;
/// Adapts a stream of Bincode encoded buffers to a stream of values by
/// deserializing them.
///
/// `ReadBincode` implements `Stream` by polling the inner buffer stream and
/// deserializing the buffer as Bincode. It expects that each yielded buffer
/// represents a single Bincode value and does not contain any extra trailing
/// bytes.
pub(crate) struct ReadBincode<T, U> {
inner: FramedRead<T, U, Bincode<U>>,
}
/// Adapts a buffer sink to a value sink by serializing the values as Bincode.
///
/// `WriteBincode` implements `Sink` by serializing the submitted values to a
/// buffer. The buffer is then sent to the inner stream, which is responsible
/// for handling framing on the wire.
pub(crate) struct WriteBincode<T: Sink, U> {
inner: FramedWrite<T, U, Bincode<U>>,
}
struct Bincode<T> {
ghost: PhantomData<T>,
}
impl<T, U> ReadBincode<T, U>
where
T: Stream<Error = IoErrorWrapper>,
U: for<'de> Deserialize<'de>,
Bytes: From<T::Item>,
{
/// Creates a new `ReadBincode` with the given buffer stream.
pub fn new(inner: T) -> ReadBincode<T, U> {
let json = Bincode { ghost: PhantomData };
ReadBincode {
inner: FramedRead::new(inner, json),
}
}
}
impl<T, U> ReadBincode<T, U> {
/// Returns a mutable reference to the underlying stream wrapped by
/// `ReadBincode`.
///
/// Note that care should be taken to not tamper with the underlying stream
/// of data coming in as it may corrupt the stream of frames otherwise
/// being worked with.
pub fn get_mut(&mut self) -> &mut T {
self.inner.get_mut()
}
}
impl<T, U> Stream for ReadBincode<T, U>
where
T: Stream<Error = IoErrorWrapper>,
U: for<'de> Deserialize<'de>,
Bytes: From<T::Item>,
{
type Item = U;
type Error = <T as Stream>::Error;
fn poll(&mut self) -> Poll<Option<U>, Self::Error> {
self.inner.poll()
}
}
impl<T, U> Sink for ReadBincode<T, U>
where
T: Sink,
{
type SinkItem = T::SinkItem;
type SinkError = T::SinkError;
fn start_send(&mut self, item: T::SinkItem) -> StartSend<T::SinkItem, T::SinkError> {
self.get_mut().start_send(item)
}
fn poll_complete(&mut self) -> Poll<(), T::SinkError> {
self.get_mut().poll_complete()
}
fn close(&mut self) -> Poll<(), T::SinkError> {
self.get_mut().close()
}
}
pub(crate) struct IoErrorWrapper(pub io::Error);
impl From<Box<bincode::ErrorKind>> for IoErrorWrapper {
fn from(e: Box<bincode::ErrorKind>) -> Self {
IoErrorWrapper(match *e {
bincode::ErrorKind::Io(e) => e,
bincode::ErrorKind::InvalidUtf8Encoding(e) => {
io::Error::new(io::ErrorKind::InvalidInput, e)
}
bincode::ErrorKind::InvalidBoolEncoding(e) => {
io::Error::new(io::ErrorKind::InvalidInput, e.to_string())
}
bincode::ErrorKind::InvalidTagEncoding(e) => {
io::Error::new(io::ErrorKind::InvalidInput, e.to_string())
}
bincode::ErrorKind::InvalidCharEncoding => {
io::Error::new(io::ErrorKind::InvalidInput, "Invalid char encoding")
}
bincode::ErrorKind::DeserializeAnyNotSupported => {
io::Error::new(io::ErrorKind::InvalidInput, "Deserialize Any not supported")
}
bincode::ErrorKind::SizeLimit => {
io::Error::new(io::ErrorKind::InvalidInput, "Size limit exceeded")
}
bincode::ErrorKind::SequenceMustHaveLength => {
io::Error::new(io::ErrorKind::InvalidInput, "Sequence must have length")
}
bincode::ErrorKind::Custom(s) => io::Error::new(io::ErrorKind::Other, s),
})
}
}
impl From<IoErrorWrapper> for io::Error {
fn from(wrapper: IoErrorWrapper) -> io::Error {
wrapper.0
}
}
impl<T, U> WriteBincode<T, U>
where
T: Sink<SinkItem = BytesMut, SinkError = IoErrorWrapper>,
U: Serialize,
{
/// Creates a new `WriteBincode` with the given buffer sink.
pub fn new(inner: T) -> WriteBincode<T, U> {
let json = Bincode { ghost: PhantomData };
WriteBincode {
inner: FramedWrite::new(inner, json),
}
}
}
impl<T: Sink, U> WriteBincode<T, U> {
/// Returns a mutable reference to the underlying sink wrapped by
/// `WriteBincode`.
///
/// Note that care should be taken to not tamper with the underlying sink as
/// it may corrupt the sequence of frames otherwise being worked with.
pub fn get_mut(&mut self) -> &mut T {
self.inner.get_mut()
}
}
impl<T, U> Sink for WriteBincode<T, U>
where
T: Sink<SinkItem = BytesMut, SinkError = IoErrorWrapper>,
U: Serialize,
{
type SinkItem = U;
type SinkError = <T as Sink>::SinkError;
fn start_send(&mut self, item: U) -> StartSend<U, Self::SinkError> {
self.inner.start_send(item)
}
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
self.inner.poll_complete()
}
fn close(&mut self) -> Poll<(), Self::SinkError> {
self.inner.poll_complete()
}
}
impl<T, U> Stream for WriteBincode<T, U>
where
T: Stream + Sink,
{
type Item = T::Item;
type Error = T::Error;
fn poll(&mut self) -> Poll<Option<T::Item>, T::Error> {
self.get_mut().poll()
}
}
impl<T> Deserializer<T> for Bincode<T>
where
T: for<'de> Deserialize<'de>,
{
type Error = Error;
fn deserialize(&mut self, src: &Bytes) -> Result<T, Error> {
bincode::deserialize(src)
}
}
impl<T: Serialize> Serializer<T> for Bincode<T> {
type Error = Error;
fn serialize(&mut self, item: &T) -> Result<BytesMut, Self::Error> {
bincode::serialize(item).map(Into::into)
}
}

View File

@@ -1,116 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Tests client/server control flow.
#![feature(
test,
integer_atomics,
futures_api,
generators,
await_macro,
async_await
)]
extern crate test;
use self::test::stats::Stats;
use futures::{compat::TokioDefaultSpawner, prelude::*};
use rpc::{
client::{self, Client},
context,
server::{self, Handler, Server},
};
use std::{
io,
time::{Duration, Instant},
};
async fn bench() -> io::Result<()> {
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
tokio_executor::spawn(
Server::<u32, u32>::new(server::Config::default())
.incoming(listener)
.take(1)
.respond_with(|_ctx, request| futures::future::ready(Ok(request)))
.unit_error()
.boxed()
.compat()
);
let conn = await!(tarpc_bincode_transport::connect(&addr))?;
let client = &mut await!(Client::<u32, u32>::new(client::Config::default(), conn))?;
let total = 10_000usize;
let mut successful = 0u32;
let mut unsuccessful = 0u32;
let mut durations = vec![];
for _ in 1..=total {
let now = Instant::now();
let response = await!(client.call(context::current(), 0u32));
let elapsed = now.elapsed();
match response {
Ok(_) => successful += 1,
Err(_) => unsuccessful += 1,
};
durations.push(elapsed);
}
let durations_nanos = durations
.iter()
.map(|duration| duration.as_secs() as f64 * 1E9 + duration.subsec_nanos() as f64)
.collect::<Vec<_>>();
let (lower, median, upper) = durations_nanos.quartiles();
println!("Of {} runs:", durations_nanos.len());
println!("\tSuccessful: {}", successful);
println!("\tUnsuccessful: {}", unsuccessful);
println!(
"\tMean: {:?}",
Duration::from_nanos(durations_nanos.mean() as u64)
);
println!("\tMedian: {:?}", Duration::from_nanos(median as u64));
println!(
"\tStd Dev: {:?}",
Duration::from_nanos(durations_nanos.std_dev() as u64)
);
println!(
"\tMin: {:?}",
Duration::from_nanos(durations_nanos.min() as u64)
);
println!(
"\tMax: {:?}",
Duration::from_nanos(durations_nanos.max() as u64)
);
println!(
"\tQuartiles: ({:?}, {:?}, {:?})",
Duration::from_nanos(lower as u64),
Duration::from_nanos(median as u64),
Duration::from_nanos(upper as u64)
);
Ok(())
}
#[test]
fn bench_small_packet() -> io::Result<()> {
env_logger::init();
rpc::init(TokioDefaultSpawner);
tokio::run(
bench()
.map_err(|e| panic!(e.to_string()))
.boxed()
.compat(),
);
println!("done");
Ok(())
}

View File

@@ -1,152 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Tests client/server control flow.
#![feature(generators, await_macro, async_await, futures_api,)]
use futures::{
compat::{Future01CompatExt, TokioDefaultSpawner},
prelude::*,
stream,
};
use log::{info, trace};
use rand::distributions::{Distribution, Normal};
use rpc::{
client::{self, Client},
context,
server::{self, Server},
};
use std::{
io,
time::{Duration, Instant, SystemTime},
};
use tokio::timer::Delay;
pub trait AsDuration {
/// Delay of 0 if self is in the past
fn as_duration(&self) -> Duration;
}
impl AsDuration for SystemTime {
fn as_duration(&self) -> Duration {
self.duration_since(SystemTime::now()).unwrap_or_default()
}
}
async fn run() -> io::Result<()> {
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let server = Server::<String, String>::new(server::Config::default())
.incoming(listener)
.take(1)
.for_each(async move |channel| {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let handler = channel.respond_with(move |ctx, request| {
// Sleep for a time sampled from a normal distribution with:
// - mean: 1/2 the deadline.
// - std dev: 1/2 the deadline.
let deadline: Duration = ctx.deadline.as_duration();
let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64;
let distribution =
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.);
let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.);
let delay = Duration::from_millis(delay_millis as u64);
trace!(
"[{}/{}] Responding to request in {:?}.",
ctx.trace_id(),
client_addr,
delay,
);
let wait = Delay::new(Instant::now() + delay).compat();
async move {
await!(wait).unwrap();
Ok(request)
}
});
tokio_executor::spawn(handler.unit_error().boxed().compat());
});
tokio_executor::spawn(server.unit_error().boxed().compat());
let conn = await!(tarpc_bincode_transport::connect(&addr))?;
let client = await!(Client::<String, String>::new(
client::Config::default(),
conn
))?;
// Proxy service
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let proxy_server = Server::<String, String>::new(server::Config::default())
.incoming(listener)
.take(1)
.for_each(move |channel| {
let client = client.clone();
async move {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let handler = channel.respond_with(move |ctx, request| {
trace!("[{}/{}] Proxying request.", ctx.trace_id(), client_addr);
let mut client = client.clone();
async move { await!(client.call(ctx, request)) }
});
tokio_executor::spawn(handler.unit_error().boxed().compat());
}
});
tokio_executor::spawn(proxy_server.unit_error().boxed().compat());
let mut config = client::Config::default();
config.max_in_flight_requests = 10;
config.pending_request_buffer = 10;
let client = await!(Client::<String, String>::new(
config,
await!(tarpc_bincode_transport::connect(&addr))?
))?;
// Make 3 speculative requests, returning only the quickest.
let mut clients: Vec<_> = (1..=3u32).map(|_| client.clone()).collect();
let mut requests = vec![];
for client in &mut clients {
let mut ctx = context::current();
ctx.deadline = SystemTime::now() + Duration::from_millis(200);
let trace_id = *ctx.trace_id();
let response = client.call(ctx, "ping".into());
requests.push(response.map(move |r| (trace_id, r)));
}
let (fastest_response, _) = await!(stream::futures_unordered(requests).into_future());
let (trace_id, resp) = fastest_response.unwrap();
info!("[{}] fastest_response = {:?}", trace_id, resp);
Ok::<_, io::Error>(())
}
#[test]
fn cancel_slower() -> io::Result<()> {
env_logger::init();
rpc::init(TokioDefaultSpawner);
tokio::run(
run()
.boxed()
.map_err(|e| panic!(e))
.compat(),
);
Ok(())
}

View File

@@ -1,120 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Tests client/server control flow.
#![feature(generators, await_macro, async_await, futures_api,)]
use futures::{
compat::{Future01CompatExt, TokioDefaultSpawner},
prelude::*,
};
use log::{error, info, trace};
use rand::distributions::{Distribution, Normal};
use rpc::{
client::{self, Client},
context,
server::{self, Server},
};
use std::{
io,
time::{Duration, Instant, SystemTime},
};
use tokio::timer::Delay;
pub trait AsDuration {
/// Delay of 0 if self is in the past
fn as_duration(&self) -> Duration;
}
impl AsDuration for SystemTime {
fn as_duration(&self) -> Duration {
self.duration_since(SystemTime::now()).unwrap_or_default()
}
}
async fn run() -> io::Result<()> {
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let server = Server::<String, String>::new(server::Config::default())
.incoming(listener)
.take(1)
.for_each(async move |channel| {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let handler = channel.respond_with(move |ctx, request| {
// Sleep for a time sampled from a normal distribution with:
// - mean: 1/2 the deadline.
// - std dev: 1/2 the deadline.
let deadline: Duration = ctx.deadline.as_duration();
let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64;
let distribution =
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.);
let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.);
let delay = Duration::from_millis(delay_millis as u64);
trace!(
"[{}/{}] Responding to request in {:?}.",
ctx.trace_id(),
client_addr,
delay,
);
let sleep = Delay::new(Instant::now() + delay).compat();
async {
await!(sleep).unwrap();
Ok(request)
}
});
tokio_executor::spawn(handler.unit_error().boxed().compat());
});
tokio_executor::spawn(server.unit_error().boxed().compat());
let mut config = client::Config::default();
config.max_in_flight_requests = 10;
config.pending_request_buffer = 10;
let conn = await!(tarpc_bincode_transport::connect(&addr))?;
let client = await!(Client::<String, String>::new(config, conn))?;
let clients = (1..=100u32).map(|_| client.clone()).collect::<Vec<_>>();
for mut client in clients {
let ctx = context::current();
tokio_executor::spawn(
async move {
let trace_id = *ctx.trace_id();
let response = client.call(ctx, "ping".into());
match await!(response) {
Ok(response) => info!("[{}] response: {}", trace_id, response),
Err(e) => error!("[{}] request error: {:?}: {}", trace_id, e.kind(), e),
}
}.unit_error().boxed().compat()
);
}
Ok(())
}
#[test]
fn ping_pong() -> io::Result<()> {
env_logger::init();
rpc::init(TokioDefaultSpawner);
tokio::run(
run()
.map_ok(|_| println!("done"))
.map_err(|e| panic!(e.to_string()))
.boxed()
.compat(),
);
Ok(())
}

View File

@@ -1,26 +1,32 @@
cargo-features = ["rename-dependency"]
[package]
name = "tarpc-example-service"
version = "0.1.0"
version = "0.15.0"
rust-version = "1.56"
authors = ["Tim Kuehn <tikue@google.com>"]
edition = "2018"
edition = "2021"
license = "MIT"
documentation = "https://docs.rs/tarpc-example-service"
homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "server", "api", "microservices", "example"]
keywords = ["rpc", "network", "server", "microservices", "example"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
description = "An example server built on tarpc."
[dependencies]
bincode-transport = { package = "tarpc-bincode-transport", version = "0.1", path = "../bincode-transport" }
futures-preview = { version = "0.3.0-alpha.8", features = ["compat", "tokio-compat"] }
serde = { version = "1.0" }
tarpc = { version = "0.13", path = "../tarpc", features = ["serde1"] }
tokio = "0.1"
tokio-executor = "0.1"
anyhow = "1.0"
clap = { version = "3.0.0-rc.9", features = ["derive"] }
log = "0.4"
futures = "0.3"
opentelemetry = { version = "0.21.0" }
opentelemetry-jaeger = { version = "0.20.0", features = ["rt-tokio"] }
rand = "0.8"
tarpc = { version = "0.34", path = "../tarpc", features = ["full"] }
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
tracing = { version = "0.1" }
tracing-opentelemetry = "0.22.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
opentelemetry_sdk = "0.21.1"
[lib]
name = "service"
@@ -28,4 +34,8 @@ path = "src/lib.rs"
[[bin]]
name = "server"
path = "src/main.rs"
path = "src/server.rs"
[[bin]]
name = "client"
path = "src/client.rs"

15
example-service/README.md Normal file
View File

@@ -0,0 +1,15 @@
# Example
Example service to demonstrate how to set up `tarpc` with [Jaeger](https://www.jaegertracing.io). To see traces Jaeger, run the following with `RUST_LOG=trace`.
## Server
```bash
cargo run --bin server -- --port 50051
```
## Client
```bash
cargo run --bin client -- --server-addr "[::1]:50051" --name "Bob"
```

View File

@@ -0,0 +1,56 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use clap::Parser;
use service::{init_tracing, WorldClient};
use std::{net::SocketAddr, time::Duration};
use tarpc::{client, context, tokio_serde::formats::Json};
use tokio::time::sleep;
use tracing::Instrument;
#[derive(Parser)]
struct Flags {
/// Sets the server address to connect to.
#[clap(long)]
server_addr: SocketAddr,
/// Sets the name to say hello to.
#[clap(long)]
name: String,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let flags = Flags::parse();
init_tracing("Tarpc Example Client")?;
let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
transport.config_mut().max_frame_length(usize::MAX);
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
// config and any Transport as input.
let client = WorldClient::new(client::Config::default(), transport.await?).spawn();
let hello = async move {
// Send the request twice, just to be safe! ;)
tokio::select! {
hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
}
}
.instrument(tracing::info_span!("Two Hellos"))
.await;
match hello {
Ok(hello) => tracing::info!("{hello:?}"),
Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)),
}
// Let the background span processor finish.
sleep(Duration::from_micros(1)).await;
opentelemetry::global::shutdown_tracer_provider();
Ok(())
}

View File

@@ -4,18 +4,31 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
futures_api,
pin,
arbitrary_self_types,
await_macro,
async_await,
proc_macro_hygiene,
)]
use std::env;
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};
// This is the service definition. It looks a lot like a trait definition.
// It defines one RPC, hello, which takes one arg, name, and returns a String.
tarpc::service! {
/// This is the service definition. It looks a lot like a trait definition.
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
#[tarpc::service]
pub trait World {
/// Returns a greeting for name.
rpc hello(name: String) -> String;
async fn hello(name: String) -> String;
}
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_service_name(service_name)
.with_max_packet_size(2usize.pow(13))
.install_batch(opentelemetry_sdk::runtime::Tokio)?;
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer().with_span_events(FmtSpan::NEW | FmtSpan::CLOSE))
.with(tracing_opentelemetry::layer().with_tracer(tracer))
.try_init()?;
Ok(())
}

View File

@@ -1,85 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
futures_api,
pin,
arbitrary_self_types,
await_macro,
async_await,
)]
use futures::{
compat::TokioDefaultSpawner,
future::{self, Ready},
prelude::*,
};
use tarpc::{
client, context,
server::{self, Handler, Server},
};
use std::io;
// This is the type that implements the generated Service trait. It is the business logic
// and is used to start the server.
#[derive(Clone)]
struct HelloServer;
impl service::Service for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
// an associated type representing the future output by the fn.
type HelloFut = Ready<String>;
fn hello(&self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
}
}
async fn run() -> io::Result<()> {
// bincode_transport is provided by the associated crate bincode-transport. It makes it easy
// to start up a serde-powered bincode serialization strategy over TCP.
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = transport.local_addr();
// The server is configured with the defaults.
let server = Server::new(server::Config::default())
// Server can listen on any type that implements the Transport trait.
.incoming(transport)
// Close the stream after the client connects
.take(1)
// serve is generated by the service! macro. It takes as input any type implementing
// the generated Service trait.
.respond_with(service::serve(HelloServer));
tokio_executor::spawn(server.unit_error().boxed().compat());
let transport = await!(bincode_transport::connect(&addr))?;
// new_stub is generated by the service! macro. Like Server, it takes a config and any
// Transport as input, and returns a Client, also generated by the macro.
// by the service mcro.
let mut client = await!(service::new_stub(client::Config::default(), transport))?;
// The client has an RPC method for each RPC defined in service!. It takes the same args
// as defined, with the addition of a Context, which is always the first arg. The Context
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = await!(client.hello(context::current(), "Stim".to_string()))?;
println!("{}", hello);
Ok(())
}
fn main() {
tarpc::init(TokioDefaultSpawner);
tokio::run(run()
.map_err(|e| eprintln!("Oh no: {}", e))
.boxed()
.compat()
);
}

View File

@@ -0,0 +1,80 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use clap::Parser;
use futures::{future, prelude::*};
use rand::{
distributions::{Distribution, Uniform},
thread_rng,
};
use service::{init_tracing, World};
use std::{
net::{IpAddr, Ipv6Addr, SocketAddr},
time::Duration,
};
use tarpc::{
context,
server::{self, incoming::Incoming, Channel},
tokio_serde::formats::Json,
};
use tokio::time;
#[derive(Parser)]
struct Flags {
/// Sets the port number to listen on.
#[clap(long)]
port: u16,
}
// This is the type that implements the generated World trait. It is the business logic
// and is used to start the server.
#[derive(Clone)]
struct HelloServer(SocketAddr);
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
let sleep_time =
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
time::sleep(sleep_time).await;
format!("Hello, {name}! You are connected from {}", self.0)
}
}
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let flags = Flags::parse();
init_tracing("Tarpc Example Server")?;
let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), flags.port);
// JSON transport is provided by the json_transport tarpc module. It makes it easy
// to start up a serde-powered json serialization strategy over TCP.
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
tracing::info!("Listening on port {}", listener.local_addr().port());
listener.config_mut().max_frame_length(usize::MAX);
listener
// Ignore accept errors.
.filter_map(|r| future::ready(r.ok()))
.map(server::BaseChannel::with_defaults)
// Limit channels to 1 per IP.
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
// serve is generated by the service attribute. It takes as input any type implementing
// the generated World trait.
.map(|channel| {
let server = HelloServer(channel.transport().peer_addr().unwrap());
channel.execute(server.serve()).for_each(spawn)
})
// Max 10 channels.
.buffer_unordered(10)
.for_each(|_| async {})
.await;
Ok(())
}

View File

@@ -67,7 +67,7 @@ else
fi
printf "${PREFIX} Checking for rustfmt ... "
command -v cargo fmt &>/dev/null
command -v rustfmt &>/dev/null
if [ $? == 0 ]; then
printf "${SUCCESS}\n"
else
@@ -93,19 +93,19 @@ diff=""
for file in $(git diff --name-only --cached);
do
if [ ${file: -3} == ".rs" ]; then
diff="$diff$(cargo fmt -- --skip-children --write-mode=diff $file)"
diff="$diff$(rustfmt --edition 2018 --check $file)"
if [ $? != 0 ]; then
FMTRESULT=1
fi
fi
done
if grep --quiet "^Diff at line" <<< "$diff"; then
FMTRESULT=1
fi
if [ "${TARPC_SKIP_RUSTFMT}" == 1 ]; then
printf "${SKIPPED}\n"$?
elif [ ${FMTRESULT} != 0 ]; then
FAILED=1
printf "${FAILURE}\n"
echo "$diff" | sed 's/Using rustfmt config file.*$/d/'
echo "$diff"
else
printf "${SUCCESS}\n"
fi

View File

@@ -84,14 +84,19 @@ command -v rustup &>/dev/null
if [ "$?" == 0 ]; then
printf "${SUCCESS}\n"
try_run "Building ... " cargo +stable build --color=always
try_run "Testing ... " cargo +stable test --color=always
try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
for EXAMPLE in $(cargo +stable run --example 2>&1 | grep ' ' | awk '{print $1}')
do
try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
done
check_toolchain nightly
if [ ${TOOLCHAIN_RESULT} == 1 ]; then
exit 1
if [ ${TOOLCHAIN_RESULT} != 1 ]; then
try_run "Running clippy ... " cargo +nightly clippy --color=always -Z unstable-options -- --deny warnings
fi
try_run "Building ... " cargo build --color=always
try_run "Testing ... " cargo test --color=always
try_run "Doc Test ... " cargo clean && cargo build --tests && rustdoc --test README.md --edition 2018 -L target/debug/deps -Z unstable-options
fi

View File

@@ -1,7 +1,9 @@
[package]
name = "tarpc-plugins"
version = "0.5.0"
version = "0.13.0"
rust-version = "1.56"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2021"
license = "MIT"
documentation = "https://docs.rs/tarpc-plugins"
homepage = "https://github.com/google/tarpc"
@@ -11,14 +13,22 @@ categories = ["asynchronous", "network-programming"]
readme = "../README.md"
description = "Proc macros for tarpc."
[features]
serde1 = []
[badges]
travis-ci = { repository = "google/tarpc" }
[dependencies]
itertools = "0.7"
syn = { version = "0.15", features = ["full", "extra-traits"] }
quote = "0.6"
proc-macro2 = "0.4"
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "1.0", features = ["full"] }
[lib]
proc-macro = true
[dev-dependencies]
assert-type-eq = "0.1.0"
futures = "0.3"
serde = { version = "1.0", features = ["derive"] }
tarpc = { path = "../tarpc", features = ["serde1"] }

9
plugins/LICENSE Normal file
View File

@@ -0,0 +1,9 @@
The MIT License (MIT)
Copyright 2016 Google Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -1 +1 @@
edition = "Edition2018"
edition = "2018"

View File

@@ -4,88 +4,651 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![recursion_limit = "512"]
extern crate proc_macro;
extern crate proc_macro2;
extern crate syn;
extern crate itertools;
extern crate quote;
extern crate syn;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
use syn::{
braced,
ext::IdentExt,
parenthesized,
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
spanned::Spanned,
token::Comma,
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
Visibility,
};
use itertools::Itertools;
use quote::ToTokens;
use syn::{Ident, TraitItemType, TypePath, parse};
use proc_macro2::Span;
use std::str::FromStr;
/// Accumulates multiple errors into a result.
/// Only use this for recoverable errors, i.e. non-parse errors. Fatal errors should early exit to
/// avoid further complications.
macro_rules! extend_errors {
($errors: ident, $e: expr) => {
match $errors {
Ok(_) => $errors = Err($e),
Err(ref mut errors) => errors.extend($e),
}
};
}
#[proc_macro]
pub fn snake_to_camel(input: TokenStream) -> TokenStream {
let i = input.clone();
let mut assoc_type = parse::<TraitItemType>(input).unwrap_or_else(|_| panic!("Could not parse trait item from:\n{}", i));
struct Service {
attrs: Vec<Attribute>,
vis: Visibility,
ident: Ident,
rpcs: Vec<RpcMethod>,
}
let old_ident = convert(&mut assoc_type.ident);
struct RpcMethod {
attrs: Vec<Attribute>,
ident: Ident,
args: Vec<PatType>,
output: ReturnType,
}
for mut attr in &mut assoc_type.attrs {
if let Some(pair) = attr.path.segments.first() {
if pair.value().ident == "doc" {
attr.tts = proc_macro2::TokenStream::from_str(&attr.tts.to_string().replace("{}", &old_ident)).unwrap();
impl Parse for Service {
fn parse(input: ParseStream) -> syn::Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;
let vis = input.parse()?;
input.parse::<Token![trait]>()?;
let ident: Ident = input.parse()?;
let content;
braced!(content in input);
let mut rpcs = Vec::<RpcMethod>::new();
while !content.is_empty() {
rpcs.push(content.parse()?);
}
let mut ident_errors = Ok(());
for rpc in &rpcs {
if rpc.ident == "new" {
extend_errors!(
ident_errors,
syn::Error::new(
rpc.ident.span(),
format!(
"method name conflicts with generated fn `{}Client::new`",
ident.unraw()
)
)
);
}
if rpc.ident == "serve" {
extend_errors!(
ident_errors,
syn::Error::new(
rpc.ident.span(),
format!("method name conflicts with generated fn `{ident}::serve`")
)
);
}
}
ident_errors?;
Ok(Self {
attrs,
vis,
ident,
rpcs,
})
}
}
impl Parse for RpcMethod {
fn parse(input: ParseStream) -> syn::Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;
input.parse::<Token![async]>()?;
input.parse::<Token![fn]>()?;
let ident = input.parse()?;
let content;
parenthesized!(content in input);
let mut args = Vec::new();
let mut errors = Ok(());
for arg in content.parse_terminated::<FnArg, Comma>(FnArg::parse)? {
match arg {
FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => {
args.push(captured);
}
FnArg::Typed(captured) => {
extend_errors!(
errors,
syn::Error::new(captured.pat.span(), "patterns aren't allowed in RPC args")
);
}
FnArg::Receiver(_) => {
extend_errors!(
errors,
syn::Error::new(arg.span(), "method args cannot start with self")
);
}
}
}
errors?;
let output = input.parse()?;
input.parse::<Token![;]>()?;
Ok(Self {
attrs,
ident,
args,
output,
})
}
}
// If `derive_serde` meta item is not present, defaults to cfg!(feature = "serde1").
// `derive_serde` can only be true when serde1 is enabled.
struct DeriveSerde(bool);
impl Parse for DeriveSerde {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut result = Ok(None);
let mut derive_serde = Vec::new();
let meta_items = input.parse_terminated::<MetaNameValue, Comma>(MetaNameValue::parse)?;
for meta in meta_items {
if meta.path.segments.len() != 1 {
extend_errors!(
result,
syn::Error::new(
meta.span(),
"tarpc::service does not support this meta item"
)
);
continue;
}
let segment = meta.path.segments.first().unwrap();
if segment.ident != "derive_serde" {
extend_errors!(
result,
syn::Error::new(
meta.span(),
"tarpc::service does not support this meta item"
)
);
continue;
}
match meta.lit {
Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => {
result = result.and(Ok(Some(true)))
}
Lit::Bool(LitBool { value: true, .. }) => {
extend_errors!(
result,
syn::Error::new(
meta.span(),
"To enable serde, first enable the `serde1` feature of tarpc"
)
);
}
Lit::Bool(LitBool { value: false, .. }) => result = result.and(Ok(Some(false))),
_ => extend_errors!(
result,
syn::Error::new(
meta.lit.span(),
"`derive_serde` expects a value of type `bool`"
)
),
}
derive_serde.push(meta);
}
if derive_serde.len() > 1 {
for (i, derive_serde) in derive_serde.iter().enumerate() {
extend_errors!(
result,
syn::Error::new(
derive_serde.span(),
format!(
"`derive_serde` appears more than once (occurrence #{})",
i + 1
)
)
);
}
}
let derive_serde = result?.unwrap_or(cfg!(feature = "serde1"));
Ok(Self(derive_serde))
}
}
/// A helper attribute to avoid a direct dependency on Serde.
///
/// Adds the following annotations to the annotated item:
///
/// ```rust
/// #[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
/// #[serde(crate = "tarpc::serde")]
/// # struct Foo;
/// ```
#[proc_macro_attribute]
pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut gen: proc_macro2::TokenStream = quote! {
#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
#[serde(crate = "tarpc::serde")]
};
gen.extend(proc_macro2::TokenStream::from(item));
proc_macro::TokenStream::from(gen)
}
/// Generates:
/// - service trait
/// - serve fn
/// - client stub struct
/// - new_stub client factory fn
/// - Request and Response enums
/// - ResponseFut Future
#[proc_macro_attribute]
pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
let derive_serde = parse_macro_input!(attr as DeriveSerde);
let unit_type: &Type = &parse_quote!(());
let Service {
ref attrs,
ref vis,
ref ident,
ref rpcs,
} = parse_macro_input!(input as Service);
let camel_case_fn_names: &Vec<_> = &rpcs
.iter()
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
.collect();
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
let derive_serialize = if derive_serde.0 {
Some(
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
#[serde(crate = "tarpc::serde")]},
)
} else {
None
};
let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
let request_names = methods
.iter()
.map(|m| format!("{ident}.{m}"))
.collect::<Vec<_>>();
ServiceGenerator {
service_ident: ident,
client_stub_ident: &format_ident!("{}Stub", ident),
server_ident: &format_ident!("Serve{}", ident),
client_ident: &format_ident!("{}Client", ident),
request_ident: &format_ident!("{}Request", ident),
response_ident: &format_ident!("{}Response", ident),
vis,
args,
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
method_idents: &methods,
request_names: &request_names,
attrs,
rpcs,
return_types: &rpcs
.iter()
.map(|rpc| match rpc.output {
ReturnType::Type(_, ref ty) => ty,
ReturnType::Default => unit_type,
})
.collect::<Vec<_>>(),
arg_pats: &args
.iter()
.map(|args| args.iter().map(|arg| &*arg.pat).collect())
.collect::<Vec<_>>(),
camel_case_idents: &rpcs
.iter()
.zip(camel_case_fn_names.iter())
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
.collect::<Vec<_>>(),
derive_serialize: derive_serialize.as_ref(),
}
.into_token_stream()
.into()
}
// Things needed to generate the service items: trait, serve impl, request/response enums, and
// the client stub.
struct ServiceGenerator<'a> {
service_ident: &'a Ident,
client_stub_ident: &'a Ident,
server_ident: &'a Ident,
client_ident: &'a Ident,
request_ident: &'a Ident,
response_ident: &'a Ident,
vis: &'a Visibility,
attrs: &'a [Attribute],
rpcs: &'a [RpcMethod],
camel_case_idents: &'a [Ident],
method_idents: &'a [&'a Ident],
request_names: &'a [String],
method_attrs: &'a [&'a [Attribute]],
args: &'a [&'a [PatType]],
return_types: &'a [&'a Type],
arg_pats: &'a [Vec<&'a Pat>],
derive_serialize: Option<&'a TokenStream2>,
}
impl<'a> ServiceGenerator<'a> {
fn trait_service(&self) -> TokenStream2 {
let &Self {
attrs,
rpcs,
vis,
return_types,
service_ident,
client_stub_ident,
request_ident,
response_ident,
server_ident,
..
} = self;
let rpc_fns = rpcs
.iter()
.zip(return_types.iter())
.map(
|(
RpcMethod {
attrs, ident, args, ..
},
output,
)| {
quote! {
#( #attrs )*
async fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> #output;
}
},
);
let stub_doc = format!("The stub trait for service [`{service_ident}`].");
quote! {
#( #attrs )*
#vis trait #service_ident: Sized {
#( #rpc_fns )*
/// Returns a serving function to use with
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
fn serve(self) -> #server_ident<Self> {
#server_ident { service: self }
}
}
#[doc = #stub_doc]
#vis trait #client_stub_ident: tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
}
impl<S> #client_stub_ident for S
where S: tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
{
}
}
}
assoc_type.into_token_stream().into()
}
fn struct_server(&self) -> TokenStream2 {
let &Self {
vis, server_ident, ..
} = self;
#[proc_macro]
pub fn ty_snake_to_camel(input: TokenStream) -> TokenStream {
let mut path = parse::<TypePath>(input).unwrap();
// Only capitalize the final segment
convert(&mut path.path
.segments
.last_mut()
.unwrap()
.into_value()
.ident);
path.into_token_stream().into()
}
/// Converts an ident in-place to CamelCase and returns the previous ident.
fn convert(ident: &mut Ident) -> String {
let ident_str = ident.to_string();
let mut camel_ty = String::new();
{
// Find the first non-underscore and add it capitalized.
let mut chars = ident_str.chars();
// Find the first non-underscore char, uppercase it, and append it.
// Guaranteed to succeed because all idents must have at least one non-underscore char.
camel_ty.extend(chars.find(|&c| c != '_').unwrap().to_uppercase());
// When we find an underscore, we remove it and capitalize the next char. To do this,
// we need to ensure the next char is not another underscore.
let mut chars = chars.coalesce(|c1, c2| {
if c1 == '_' && c2 == '_' {
Ok(c1)
} else {
Err((c1, c2))
quote! {
/// A serving function to use with [tarpc::server::InFlightRequest::execute].
#[derive(Clone)]
#vis struct #server_ident<S> {
service: S,
}
});
}
}
while let Some(c) = chars.next() {
if c != '_' {
camel_ty.push(c);
} else if let Some(c) = chars.next() {
fn impl_serve_for_server(&self) -> TokenStream2 {
let &Self {
request_ident,
server_ident,
service_ident,
response_ident,
camel_case_idents,
arg_pats,
method_idents,
request_names,
..
} = self;
quote! {
impl<S> tarpc::server::Serve for #server_ident<S>
where S: #service_ident
{
type Req = #request_ident;
type Resp = #response_ident;
fn method(&self, req: &#request_ident) -> Option<&'static str> {
Some(match req {
#(
#request_ident::#camel_case_idents{..} => {
#request_names
}
)*
})
}
async fn serve(self, ctx: tarpc::context::Context, req: #request_ident)
-> Result<#response_ident, tarpc::ServerError> {
match req {
#(
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
Ok(#response_ident::#camel_case_idents(
#service_ident::#method_idents(
self.service, ctx, #( #arg_pats ),*
).await
))
}
)*
}
}
}
}
}
fn enum_request(&self) -> TokenStream2 {
let &Self {
derive_serialize,
vis,
request_ident,
camel_case_idents,
args,
..
} = self;
quote! {
/// The request sent over the wire from the client to the server.
#[allow(missing_docs)]
#[derive(Debug)]
#derive_serialize
#vis enum #request_ident {
#( #camel_case_idents{ #( #args ),* } ),*
}
}
}
fn enum_response(&self) -> TokenStream2 {
let &Self {
derive_serialize,
vis,
response_ident,
camel_case_idents,
return_types,
..
} = self;
quote! {
/// The response sent over the wire from the server to the client.
#[allow(missing_docs)]
#[derive(Debug)]
#derive_serialize
#vis enum #response_ident {
#( #camel_case_idents(#return_types) ),*
}
}
}
fn struct_client(&self) -> TokenStream2 {
let &Self {
vis,
client_ident,
request_ident,
response_ident,
..
} = self;
quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. All request methods return
/// [Futures](std::future::Future).
#vis struct #client_ident<
Stub = tarpc::client::Channel<#request_ident, #response_ident>
>(Stub);
}
}
fn impl_client_new(&self) -> TokenStream2 {
let &Self {
client_ident,
vis,
request_ident,
response_ident,
..
} = self;
quote! {
impl #client_ident {
/// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: tarpc::client::Config, transport: T)
-> tarpc::client::NewClient<
Self,
tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
>
where
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
{
let new_client = tarpc::client::new(config, transport);
tarpc::client::NewClient {
client: #client_ident(new_client.client),
dispatch: new_client.dispatch,
}
}
}
impl<Stub> From<Stub> for #client_ident<Stub>
where Stub: tarpc::client::stub::Stub<
Req = #request_ident,
Resp = #response_ident>
{
/// Returns a new client stub that sends requests over the given transport.
fn from(stub: Stub) -> Self {
#client_ident(stub)
}
}
}
}
fn impl_client_rpc_methods(&self) -> TokenStream2 {
let &Self {
client_ident,
request_ident,
response_ident,
method_attrs,
vis,
method_idents,
request_names,
args,
return_types,
arg_pats,
camel_case_idents,
..
} = self;
quote! {
impl<Stub> #client_ident<Stub>
where Stub: tarpc::client::stub::Stub<
Req = #request_ident,
Resp = #response_ident>
{
#(
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = self.0.call(ctx, #request_names, request);
async move {
match resp.await? {
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
_ => unreachable!(),
}
}
}
)*
}
}
}
}
impl<'a> ToTokens for ServiceGenerator<'a> {
fn to_tokens(&self, output: &mut TokenStream2) {
output.extend(vec![
self.trait_service(),
self.struct_server(),
self.impl_serve_for_server(),
self.enum_request(),
self.enum_response(),
self.struct_client(),
self.impl_client_new(),
self.impl_client_rpc_methods(),
])
}
}
fn snake_to_camel(ident_str: &str) -> String {
let mut camel_ty = String::with_capacity(ident_str.len());
let mut last_char_was_underscore = true;
for c in ident_str.chars() {
match c {
'_' => last_char_was_underscore = true,
c if last_char_was_underscore => {
camel_ty.extend(c.to_uppercase());
last_char_was_underscore = false;
}
c => camel_ty.extend(c.to_lowercase()),
}
}
// The Fut suffix is hardcoded right now; this macro isn't really meant to be general-purpose.
camel_ty.push_str("Fut");
*ident = Ident::new(&camel_ty, Span::call_site());
ident_str
camel_ty.shrink_to_fit();
camel_ty
}
#[test]
fn snake_to_camel_basic() {
assert_eq!(snake_to_camel("abc_def"), "AbcDef");
}
#[test]
fn snake_to_camel_underscore_suffix() {
assert_eq!(snake_to_camel("abc_def_"), "AbcDef");
}
#[test]
fn snake_to_camel_underscore_prefix() {
assert_eq!(snake_to_camel("_abc_def"), "AbcDef");
}
#[test]
fn snake_to_camel_underscore_consecutive() {
assert_eq!(snake_to_camel("abc__def"), "AbcDef");
}
#[test]
fn snake_to_camel_capital_in_middle() {
assert_eq!(snake_to_camel("aBc_dEf"), "AbcDef");
}

44
plugins/tests/server.rs Normal file
View File

@@ -0,0 +1,44 @@
// these need to be out here rather than inside the function so that the
// assert_type_eq macro can pick them up.
#[tarpc::service]
trait Foo {
async fn two_part(s: String, i: i32) -> (String, i32);
async fn bar(s: String) -> String;
async fn baz();
}
#[allow(non_camel_case_types)]
#[test]
fn raw_idents_work() {
type r#yield = String;
#[tarpc::service]
trait r#trait {
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async();
}
}
#[test]
fn syntax() {
#[tarpc::service]
trait Syntax {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict();
async fn hello() -> String;
#[doc = "attr"]
async fn attr(s: String) -> String;
async fn no_args_no_return();
async fn no_args() -> ();
async fn one_arg(one: String) -> i32;
async fn two_args_no_return(one: String, two: u64);
async fn two_args(one: String, two: u64) -> String;
async fn no_args_ret_error() -> i32;
async fn one_arg_ret_error(one: String) -> String;
async fn no_arg_implicit_return_error();
#[doc = "attr"]
async fn one_arg_implicit_return_error(one: String);
}
}

80
plugins/tests/service.rs Normal file
View File

@@ -0,0 +1,80 @@
use tarpc::context;
#[test]
fn att_service_trait() {
#[tarpc::service]
trait Foo {
async fn two_part(s: String, i: i32) -> (String, i32);
async fn bar(s: String) -> String;
async fn baz();
}
impl Foo for () {
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
(s, i)
}
async fn bar(self, _: context::Context, s: String) -> String {
s
}
async fn baz(self, _: context::Context) {
()
}
}
}
#[allow(non_camel_case_types)]
#[test]
fn raw_idents() {
type r#yield = String;
#[tarpc::service]
trait r#trait {
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async();
}
impl r#trait for () {
async fn r#await(
self,
_: context::Context,
r#struct: r#yield,
r#enum: i32,
) -> (r#yield, i32) {
(r#struct, r#enum)
}
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
r#impl
}
async fn r#async(self, _: context::Context) {
()
}
}
}
#[test]
fn syntax() {
#[tarpc::service]
trait Syntax {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict();
async fn hello() -> String;
#[doc = "attr"]
async fn attr(s: String) -> String;
async fn no_args_no_return();
async fn no_args() -> ();
async fn one_arg(one: String) -> i32;
async fn two_args_no_return(one: String, two: u64);
async fn two_args(one: String, two: u64) -> String;
async fn no_args_ret_error() -> i32;
async fn one_arg_ret_error(one: String) -> String;
async fn no_arg_implicit_return_error();
#[doc = "attr"]
async fn one_arg_implicit_return_error(one: String);
}
}

View File

@@ -1,38 +0,0 @@
cargo-features = ["rename-dependency"]
[package]
name = "tarpc-lib"
version = "0.1.0"
authors = ["Tim Kuehn <tikue@google.com>"]
edition = '2018'
license = "MIT"
documentation = "https://docs.rs/tarpc-lib"
homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "server", "api", "microservices"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
description = "An RPC framework for Rust with a focus on ease of use."
[features]
default = []
serde1 = ["trace/serde", "serde", "serde/derive"]
[dependencies]
fnv = "1.0"
humantime = "1.0"
log = "0.4"
pin-utils = "0.1.0-alpha.2"
rand = "0.5"
tokio-timer = "0.2"
trace = { package = "tarpc-trace", version = "0.1", path = "../trace" }
serde = { optional = true, version = "1.0" }
[target.'cfg(not(test))'.dependencies]
futures-preview = { version = "0.3.0-alpha.8", features = ["compat"] }
[dev-dependencies]
futures-preview = { version = "0.3.0-alpha.8", features = ["compat", "tokio-compat"] }
futures-test-preview = { version = "0.3.0-alpha.8" }
env_logger = "0.5"
tokio = "0.1"

View File

@@ -1 +0,0 @@
edition = "Edition2018"

View File

@@ -1,714 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
context,
util::{deadline_compat, AsDuration, Compact},
ClientMessage, ClientMessageKind, Request, Response, Transport,
};
use fnv::FnvHashMap;
use futures::{
Poll,
channel::{mpsc, oneshot},
prelude::*,
ready,
stream::Fuse,
task::LocalWaker,
};
use humantime::format_rfc3339;
use log::{debug, error, info, trace};
use pin_utils::unsafe_pinned;
use std::{
io,
net::SocketAddr,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Instant,
};
use trace::SpanId;
use super::Config;
/// Handles communication from the client to request dispatch.
#[derive(Debug)]
pub(crate) struct Channel<Req, Resp> {
to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
/// Channel to send a cancel message to the dispatcher.
cancellation: RequestCancellation,
/// The ID to use for the next request to stage.
next_request_id: Arc<AtomicU64>,
server_addr: SocketAddr,
}
impl<Req, Resp> Clone for Channel<Req, Resp> {
fn clone(&self) -> Self {
Self {
to_dispatch: self.to_dispatch.clone(),
cancellation: self.cancellation.clone(),
next_request_id: self.next_request_id.clone(),
server_addr: self.server_addr,
}
}
}
impl<Req, Resp> Channel<Req, Resp> {
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves when the request is sent (not when the response is received).
pub(crate) async fn send(
&mut self,
mut ctx: context::Context,
request: Req,
) -> io::Result<DispatchResponse<Resp>> {
// Convert the context to the call context.
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
let timeout = ctx.deadline.as_duration();
let deadline = Instant::now() + timeout;
trace!(
"[{}/{}] Queuing request with deadline {} (timeout {:?}).",
ctx.trace_id(),
self.server_addr,
format_rfc3339(ctx.deadline),
timeout,
);
let (response_completion, response) = oneshot::channel();
let cancellation = self.cancellation.clone();
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
await!(self.to_dispatch.send(DispatchRequest {
ctx,
request_id,
request,
response_completion,
})).map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset))?;
Ok(DispatchResponse {
response: deadline_compat::Deadline::new(response, deadline),
complete: false,
request_id,
cancellation,
ctx,
server_addr: self.server_addr,
})
}
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves to the response.
pub(crate) async fn call(
&mut self,
context: context::Context,
request: Req,
) -> io::Result<Resp> {
let response_future = await!(self.send(context, request))?;
await!(response_future)
}
}
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
#[derive(Debug)]
pub struct DispatchResponse<Resp> {
response: deadline_compat::Deadline<oneshot::Receiver<Response<Resp>>>,
ctx: context::Context,
complete: bool,
cancellation: RequestCancellation,
request_id: u64,
server_addr: SocketAddr,
}
impl<Resp> DispatchResponse<Resp> {
unsafe_pinned!(server_addr: SocketAddr);
unsafe_pinned!(ctx: context::Context);
}
impl<Resp> Future for DispatchResponse<Resp> {
type Output = io::Result<Resp>;
fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<Resp>> {
let resp = ready!(self.response.poll_unpin(waker));
self.complete = true;
Poll::Ready(match resp {
Ok(resp) => Ok(resp.message?),
Err(e) => Err({
let trace_id = *self.ctx().trace_id();
let server_addr = *self.server_addr();
if e.is_elapsed() {
io::Error::new(
io::ErrorKind::TimedOut,
"Client dropped expired request.".to_string(),
)
} else if e.is_timer() {
let e = e.into_timer().unwrap();
if e.is_at_capacity() {
io::Error::new(
io::ErrorKind::Other,
"Cancelling request because an expiration could not be set \
due to the timer being at capacity."
.to_string(),
)
} else if e.is_shutdown() {
panic!("[{}/{}] Timer was shutdown", trace_id, server_addr)
} else {
panic!(
"[{}/{}] Unrecognized timer error: {}",
trace_id, server_addr, e
)
}
} else if e.is_inner() {
// The oneshot is Canceled when the dispatch task ends.
io::Error::from(io::ErrorKind::ConnectionReset)
} else {
panic!(
"[{}/{}] Unrecognized deadline error: {}",
trace_id, server_addr, e
)
}
}),
})
}
}
// Cancels the request when dropped, if not already complete.
impl<Resp> Drop for DispatchResponse<Resp> {
fn drop(&mut self) {
if !self.complete {
// The receiver needs to be closed to handle the edge case that the request has not
// yet been received by the dispatch task. It is possible for the cancel message to
// arrive before the request itself, in which case the request could get stuck in the
// dispatch map forever if the server never responds (e.g. if the server dies while
// responding). Even if the server does respond, it will have unnecessarily done work
// for a client no longer waiting for a response. To avoid this, the dispatch task
// checks if the receiver is closed before inserting the request in the map. By
// closing the receiver before sending the cancel message, it is guaranteed that if the
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self.response.get_mut().close();
self.cancellation.cancel(self.request_id);
}
}
}
/// Spawns a dispatch task on the default executor that manages the lifecycle of requests initiated
/// by the returned [`Channel`].
pub async fn spawn<Req, Resp, C>(
config: Config,
transport: C,
server_addr: SocketAddr,
) -> io::Result<Channel<Req, Resp>>
where
Req: Send,
Resp: Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>> + Send,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
crate::spawn(
RequestDispatch {
config,
server_addr,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
pending_requests: pending_requests.fuse(),
}.unwrap_or_else(move |e| error!("[{}] Connection broken: {}", server_addr, e))
).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
"Could not spawn client dispatch task. Is shutdown: {}",
e.is_shutdown()
),
)
})?;
Ok(Channel {
to_dispatch,
cancellation,
server_addr,
next_request_id: Arc::new(AtomicU64::new(0)),
})
}
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
struct RequestDispatch<Req, Resp, C> {
/// Writes requests to the wire and reads responses off the wire.
transport: Fuse<C>,
/// Requests waiting to be written to the wire.
pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
/// Requests that were dropped.
canceled_requests: CanceledRequests,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
/// The address of the server connected to.
server_addr: SocketAddr,
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
Req: Send,
Resp: Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>>,
{
unsafe_pinned!(server_addr: SocketAddr);
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>);
unsafe_pinned!(canceled_requests: CanceledRequests);
unsafe_pinned!(pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>);
unsafe_pinned!(transport: Fuse<C>);
fn pump_read(self: &mut Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<io::Result<()>>> {
Poll::Ready(match ready!(self.transport().poll_next(waker)?) {
Some(response) => {
self.complete(response);
Some(Ok(()))
}
None => {
trace!("[{}] read half closed", self.server_addr());
None
}
})
}
fn pump_write(self: &mut Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<io::Result<()>>> {
enum ReceiverStatus {
NotReady,
Closed,
}
let pending_requests_status = match self.poll_next_request(waker)? {
Poll::Ready(Some(dispatch_request)) => {
self.write_request(dispatch_request)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
let canceled_requests_status = match self.poll_next_cancellation(waker)? {
Poll::Ready(Some((context, request_id))) => {
self.write_cancel(context, request_id)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
match (pending_requests_status, canceled_requests_status) {
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
ready!(self.transport().poll_flush(waker)?);
Poll::Ready(None)
}
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self.transport().poll_flush(waker)?);
// Even if we fully-flush, we return Pending, because we have no more requests
// or cancellations right now.
Poll::Pending
}
}
}
/// Yields the next pending request, if one is ready to be sent.
fn poll_next_request(
self: &mut Pin<&mut Self>,
waker: &LocalWaker,
) -> Poll<Option<io::Result<DispatchRequest<Req, Resp>>>> {
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
info!(
"At in-flight request capacity ({}/{}).",
self.in_flight_requests().len(),
self.config.max_in_flight_requests
);
// No need to schedule a wakeup, because timers and responses are responsible
// for clearing out in-flight requests.
return Poll::Pending;
}
while let Poll::Pending = self.transport().poll_ready(waker)? {
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
ready!(self.transport().poll_flush(waker)?);
}
loop {
match ready!(self.pending_requests().poll_next_unpin(waker)) {
Some(request) => {
if request.response_completion.is_canceled() {
trace!(
"[{}] Request canceled before being sent.",
request.ctx.trace_id()
);
continue;
}
return Poll::Ready(Some(Ok(request)));
}
None => {
trace!("[{}] pending_requests closed", self.server_addr());
return Poll::Ready(None);
}
}
}
}
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
fn poll_next_cancellation(
self: &mut Pin<&mut Self>,
waker: &LocalWaker,
) -> Poll<Option<io::Result<(context::Context, u64)>>> {
while let Poll::Pending = self.transport().poll_ready(waker)? {
ready!(self.transport().poll_flush(waker)?);
}
loop {
match ready!(self.canceled_requests().poll_next_unpin(waker)) {
Some(request_id) => {
if let Some(in_flight_data) = self.in_flight_requests().remove(&request_id) {
self.in_flight_requests().compact(0.1);
debug!(
"[{}/{}] Removed request.",
in_flight_data.ctx.trace_id(),
self.server_addr()
);
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
}
}
None => {
trace!("[{}] canceled_requests closed.", self.server_addr());
return Poll::Ready(None);
}
}
}
}
fn write_request(
self: &mut Pin<&mut Self>,
dispatch_request: DispatchRequest<Req, Resp>,
) -> io::Result<()> {
let request_id = dispatch_request.request_id;
let request = ClientMessage {
trace_context: dispatch_request.ctx.trace_context,
message: ClientMessageKind::Request(Request {
id: request_id,
message: dispatch_request.request,
deadline: dispatch_request.ctx.deadline,
}),
};
self.transport().start_send(request)?;
self.in_flight_requests().insert(
request_id,
InFlightData {
ctx: dispatch_request.ctx,
response_completion: dispatch_request.response_completion,
},
);
Ok(())
}
fn write_cancel(
self: &mut Pin<&mut Self>,
context: context::Context,
request_id: u64,
) -> io::Result<()> {
let trace_id = *context.trace_id();
let cancel = ClientMessage {
trace_context: context.trace_context,
message: ClientMessageKind::Cancel { request_id },
};
self.transport().start_send(cancel)?;
trace!("[{}/{}] Cancel message sent.", trace_id, self.server_addr());
return Ok(());
}
/// Sends a server response to the client task that initiated the associated request.
fn complete(self: &mut Pin<&mut Self>, response: Response<Resp>) -> bool {
if let Some(in_flight_data) = self.in_flight_requests().remove(&response.request_id) {
self.in_flight_requests().compact(0.1);
trace!(
"[{}/{}] Received response.",
in_flight_data.ctx.trace_id(),
self.server_addr()
);
let _ = in_flight_data.response_completion.send(response);
return true;
}
debug!(
"[{}] No in-flight request found for request_id = {}.",
self.server_addr(),
response.request_id
);
// If the response completion was absent, then the request was already canceled.
false
}
}
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
Req: Send,
Resp: Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>>,
{
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<()>> {
trace!("[{}] RequestDispatch::poll", self.server_addr());
loop {
match (self.pump_read(waker)?, self.pump_write(waker)?) {
(read, write @ Poll::Ready(None)) => {
if self.in_flight_requests().is_empty() {
info!(
"[{}] Shutdown: write half closed, and no requests in flight.",
self.server_addr()
);
return Poll::Ready(Ok(()));
}
match read {
Poll::Ready(Some(())) => continue,
_ => {
trace!(
"[{}] read: {:?}, write: {:?}, (not ready)",
self.server_addr(),
read,
write,
);
return Poll::Pending;
}
}
}
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
trace!(
"[{}] read: {:?}, write: {:?}",
self.server_addr(),
read,
write,
)
}
(read, write) => {
trace!(
"[{}] read: {:?}, write: {:?} (not ready)",
self.server_addr(),
read,
write,
);
return Poll::Pending;
}
}
}
}
}
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
/// the lifecycle of the request.
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
ctx: context::Context,
request_id: u64,
request: Req,
response_completion: oneshot::Sender<Response<Resp>>,
}
struct InFlightData<Resp> {
ctx: context::Context,
response_completion: oneshot::Sender<Response<Resp>>,
}
/// Sends request cancellation signals.
#[derive(Debug, Clone)]
struct RequestCancellation(mpsc::UnboundedSender<u64>);
/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests. Additionally, each request has a clone
// of the sender, so the bounded channel would have the same behavior,
// since it guarantees a slot.
let (tx, rx) = mpsc::unbounded();
(RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
/// Cancels the request with ID `request_id`.
fn cancel(&mut self, request_id: u64) {
let _ = self.0.unbounded_send(request_id);
}
}
impl Stream for CanceledRequests {
type Item = u64;
fn poll_next(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<u64>> {
self.0.poll_next_unpin(waker)
}
}
#[cfg(test)]
mod tests {
use super::{CanceledRequests, Channel, RequestCancellation, RequestDispatch};
use crate::{
client::Config,
context,
transport::{self, channel::UnboundedChannel},
ClientMessage, Response,
};
use fnv::FnvHashMap;
use futures::{Poll, channel::mpsc, prelude::*};
use futures_test::task::{noop_local_waker_ref};
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
pin::Pin,
sync::atomic::AtomicU64,
sync::Arc,
};
#[test]
fn stage_request() {
let (mut dispatch, mut channel, _server_channel) = set_up();
// Test that a request future dropped before it's processed by dispatch will cause the request
// to not be added to the in-flight request map.
let _resp = tokio::runtime::current_thread::block_on_all(
channel
.send(context::current(), "hi".to_string())
.boxed()
.compat(),
);
let mut dispatch = Pin::new(&mut dispatch);
let waker = &noop_local_waker_ref();
let req = dispatch.poll_next_request(waker).ready();
assert!(req.is_some());
let req = req.unwrap();
assert_eq!(req.request_id, 0);
assert_eq!(req.request, "hi".to_string());
}
#[test]
fn stage_request_response_future_dropped() {
let (mut dispatch, mut channel, _server_channel) = set_up();
// Test that a request future dropped before it's processed by dispatch will cause the request
// to not be added to the in-flight request map.
let resp = tokio::runtime::current_thread::block_on_all(
channel
.send(context::current(), "hi".into())
.boxed()
.compat(),
).unwrap();
drop(resp);
drop(channel);
let mut dispatch = Pin::new(&mut dispatch);
let waker = &noop_local_waker_ref();
dispatch.poll_next_cancellation(waker).unwrap();
assert!(dispatch.poll_next_request(waker).ready().is_none());
}
#[test]
fn stage_request_response_future_closed() {
let (mut dispatch, mut channel, _server_channel) = set_up();
// Test that a request future that's closed its receiver but not yet canceled its request --
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
// map.
let resp = tokio::runtime::current_thread::block_on_all(
channel
.send(context::current(), "hi".into())
.boxed()
.compat(),
).unwrap();
drop(resp);
drop(channel);
let mut dispatch = Pin::new(&mut dispatch);
let waker = &noop_local_waker_ref();
assert!(dispatch.poll_next_request(waker).ready().is_none());
}
fn set_up() -> (
RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
Channel<String, String>,
UnboundedChannel<ClientMessage<String>, Response<String>>,
) {
let _ = env_logger::try_init();
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancel_tx, canceled_requests) = mpsc::unbounded();
let (client_channel, server_channel) = transport::channel::unbounded();
let dispatch = RequestDispatch::<String, String, _> {
transport: client_channel.fuse(),
pending_requests: pending_requests.fuse(),
canceled_requests: CanceledRequests(canceled_requests),
in_flight_requests: FnvHashMap::default(),
config: Config::default(),
server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
};
let cancellation = RequestCancellation(cancel_tx);
let channel = Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicU64::new(0)),
server_addr: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
};
(dispatch, channel, server_channel)
}
trait PollTest {
type T;
fn unwrap(self) -> Poll<Self::T>;
fn ready(self) -> Self::T;
}
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
where
E: ::std::fmt::Display + Send + 'static,
{
type T = Option<T>;
fn unwrap(self) -> Poll<Option<T>> {
match self {
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
Poll::Pending => Poll::Pending,
}
}
fn ready(self) -> Option<T> {
match self {
Poll::Ready(Some(Ok(t))) => Some(t),
Poll::Ready(None) => None,
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
Poll::Pending => panic!("Pending"),
}
}
}
}

View File

@@ -1,91 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a client that connects to a server and sends multiplexed requests.
use crate::{context::Context, ClientMessage, Response, Transport};
use log::warn;
use std::{
io,
net::{Ipv4Addr, SocketAddr},
};
mod dispatch;
/// Sends multiplexed requests to, and receives responses from, a server.
#[derive(Debug)]
pub struct Client<Req, Resp> {
/// Channel to send requests to the dispatch task.
channel: dispatch::Channel<Req, Resp>,
}
impl<Req, Resp> Clone for Client<Req, Resp> {
fn clone(&self) -> Self {
Client {
channel: self.channel.clone(),
}
}
}
/// Settings that control the behavior of the client.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
/// The number of requests that can be in flight at once.
/// `max_in_flight_requests` controls the size of the map used by the client
/// for storing pending requests.
pub max_in_flight_requests: usize,
/// The number of requests that can be buffered client-side before being sent.
/// `pending_requests_buffer` controls the size of the channel clients use
/// to communicate with the request dispatch task.
pub pending_request_buffer: usize,
}
impl Default for Config {
fn default() -> Self {
Config {
max_in_flight_requests: 1_000,
pending_request_buffer: 100,
}
}
}
impl<Req, Resp> Client<Req, Resp>
where
Req: Send,
Resp: Send,
{
/// Creates a new Client by wrapping a [`Transport`] and spawning a dispatch task
/// that manages the lifecycle of requests.
///
/// Must only be called from on an executor.
pub async fn new<T>(config: Config, transport: T) -> io::Result<Self>
where
T: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>> + Send,
{
let server_addr = transport.peer_addr().unwrap_or_else(|e| {
warn!(
"Setting peer to unspecified because peer could not be determined: {}",
e
);
SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0)
});
Ok(Client {
channel: await!(dispatch::spawn(config, transport, server_addr))?,
})
}
/// Initiates a request, sending it to the dispatch task.
///
/// Returns a [`Future`] that resolves to this client and the future response
/// once the request is successfully enqueued.
///
/// [`Future`]: futures::Future
pub async fn call(&mut self, ctx: Context, request: Req) -> io::Result<Resp> {
await!(self.channel.call(ctx, request))
}
}

View File

@@ -1,45 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a request context that carries a deadline and trace context. This context is sent from
//! client to server and is used by the server to enforce response deadlines.
use std::time::{Duration, SystemTime};
use trace::{self, TraceId};
/// A request context that carries request-scoped information like deadlines and trace information.
/// It is sent from client to server and is used by the server to enforce response deadlines.
///
/// The context should not be stored directly in a server implementation, because the context will
/// be different for each request in scope.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct Context {
/// When the client expects the request to be complete by. The server should cancel the request
/// if it is not complete by this time.
pub deadline: SystemTime,
/// Uniquely identifies requests originating from the same source.
/// When a service handles a request by making requests itself, those requests should
/// include the same `trace_id` as that included on the original request. This way,
/// users can trace related actions across a distributed system.
pub trace_context: trace::Context,
}
/// Returns the context for the current request, or a default Context if no request is active.
// TODO: populate Context with request-scoped data, with default fallbacks.
pub fn current() -> Context {
Context {
deadline: SystemTime::now() + Duration::from_secs(10),
trace_context: trace::Context::new_root(),
}
}
impl Context {
/// Returns the ID of the request-scoped trace.
pub fn trace_id(&self) -> &TraceId {
&self.trace_context.trace_id
}
}

View File

@@ -1,215 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
const_fn,
non_exhaustive,
integer_atomics,
try_trait,
nll,
futures_api,
pin,
arbitrary_self_types,
await_macro,
async_await,
generators,
optin_builtin_traits,
generator_trait,
gen_future,
decl_macro,
)]
#![deny(missing_docs, missing_debug_implementations)]
//! An RPC framework providing client and server.
//!
//! Features:
//! * RPC deadlines, both client- and server-side.
//! * Cascading cancellation (works with multiple hops).
//! * Configurable limits
//! * In-flight requests, both client and server-side.
//! * Server-side limit is per-connection.
//! * When the server reaches the in-flight request maximum, it returns a throttled error
//! to the client.
//! * When the client reaches the in-flight request max, messages are buffered up to a
//! configurable maximum, beyond which the requests are back-pressured.
//! * Server connections.
//! * Total and per-IP limits.
//! * When an incoming connection is accepted, if already at maximum, the connection is
//! dropped.
//! * Transport agnostic.
pub mod client;
pub mod context;
pub mod server;
pub mod transport;
pub(crate) mod util;
pub use crate::{client::Client, server::Server, transport::Transport};
use futures::{Future, task::{Spawn, SpawnExt, SpawnError}};
use std::{cell::RefCell, io, sync::Once, time::SystemTime};
/// A message from a client to a server.
#[derive(Debug)]
#[cfg_attr(
feature = "serde1",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub struct ClientMessage<T> {
/// The trace context associates the message with a specific chain of causally-related actions,
/// possibly orchestrated across many distributed systems.
pub trace_context: trace::Context,
/// The message payload.
pub message: ClientMessageKind<T>,
}
/// Different messages that can be sent from a client to a server.
#[derive(Debug)]
#[cfg_attr(
feature = "serde1",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub enum ClientMessageKind<T> {
/// A request initiated by a user. The server responds to a request by invoking a
/// service-provided request handler. The handler completes with a [`response`](Response), which
/// the server sends back to the client.
Request(Request<T>),
/// A command to cancel an in-flight request, automatically sent by the client when a response
/// future is dropped.
///
/// When received, the server will immediately cancel the main task (top-level future) of the
/// request handler for the associated request. Any tasks spawned by the request handler will
/// not be canceled, because the framework layer does not
/// know about them.
Cancel {
/// The ID of the request to cancel.
request_id: u64,
},
}
/// A request from a client to a server.
#[derive(Debug)]
#[cfg_attr(
feature = "serde1",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub struct Request<T> {
/// Uniquely identifies the request across all requests sent over a single channel.
pub id: u64,
/// The request body.
pub message: T,
/// When the client expects the request to be complete by. The server will cancel the request
/// if it is not complete by this time.
#[cfg_attr(
feature = "serde1",
serde(serialize_with = "util::serde::serialize_epoch_secs")
)]
#[cfg_attr(
feature = "serde1",
serde(deserialize_with = "util::serde::deserialize_epoch_secs")
)]
pub deadline: SystemTime,
}
/// A response from a server to a client.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde1",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub struct Response<T> {
/// The ID of the request being responded to.
pub request_id: u64,
/// The response body, or an error if the request failed.
pub message: Result<T, ServerError>,
}
/// An error response from a server to a client.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde1",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub struct ServerError {
#[cfg_attr(
feature = "serde1",
serde(serialize_with = "util::serde::serialize_io_error_kind_as_u32")
)]
#[cfg_attr(
feature = "serde1",
serde(deserialize_with = "util::serde::deserialize_io_error_kind_from_u32")
)]
/// The type of error that occurred to fail the request.
pub kind: io::ErrorKind,
/// A message describing more detail about the error that occurred.
pub detail: Option<String>,
}
impl From<ServerError> for io::Error {
fn from(e: ServerError) -> io::Error {
io::Error::new(e.kind, e.detail.unwrap_or_default())
}
}
impl<T> Request<T> {
/// Returns the deadline for this request.
pub fn deadline(&self) -> &SystemTime {
&self.deadline
}
}
static INIT: Once = Once::new();
static mut SEED_SPAWN: Option<Box<dyn CloneSpawn>> = None;
thread_local! {
static SPAWN: RefCell<Box<dyn CloneSpawn>> = {
unsafe {
// INIT must always be called before accessing SPAWN.
// Otherwise, accessing SPAWN can trigger undefined behavior due to race conditions.
INIT.call_once(|| {});
RefCell::new(SEED_SPAWN.clone().expect("init() must be called."))
}
};
}
/// Initializes the RPC library with a mechanism to spawn futures on the user's runtime.
/// Client stubs and servers both use the initialized spawn.
///
/// Init only has an effect the first time it is called. If called previously, successive calls to
/// init are noops.
pub fn init(spawn: impl Spawn + Clone + 'static) {
unsafe {
INIT.call_once(|| {
SEED_SPAWN = Some(Box::new(spawn));
});
}
}
pub(crate) fn spawn(future: impl Future<Output = ()> + Send + 'static) -> Result<(), SpawnError> {
SPAWN.with(|spawn| {
spawn.borrow_mut().spawn(future)
})
}
trait CloneSpawn: Spawn {
fn box_clone(&self) -> Box<dyn CloneSpawn>;
}
impl Clone for Box<dyn CloneSpawn> {
fn clone(&self) -> Self {
self.box_clone()
}
}
impl<S: Spawn + Clone + 'static> CloneSpawn for S {
fn box_clone(&self) -> Box<dyn CloneSpawn> {
Box::new(self.clone())
}
}

View File

@@ -1,257 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
server::{Channel, Config},
util::Compact,
ClientMessage, Response, Transport,
};
use fnv::FnvHashMap;
use futures::{channel::mpsc, prelude::*, ready, stream::Fuse, task::{LocalWaker, Poll}};
use log::{debug, error, info, trace, warn};
use pin_utils::unsafe_pinned;
use std::{
collections::hash_map::Entry,
io,
marker::PhantomData,
net::{IpAddr, SocketAddr},
ops::Try,
option::NoneError,
pin::Pin,
};
/// Drops connections under configurable conditions:
///
/// 1. If the max number of connections is reached.
/// 2. If the max number of connections for a single IP is reached.
#[derive(Debug)]
pub struct ConnectionFilter<S, Req, Resp> {
listener: Fuse<S>,
closed_connections: mpsc::UnboundedSender<SocketAddr>,
closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>,
config: Config,
connections_per_ip: FnvHashMap<IpAddr, usize>,
open_connections: usize,
ghost: PhantomData<(Req, Resp)>,
}
enum NewConnection<Req, Resp, C> {
Filtered,
Accepted(Channel<Req, Resp, C>),
}
impl<Req, Resp, C> Try for NewConnection<Req, Resp, C> {
type Ok = Channel<Req, Resp, C>;
type Error = NoneError;
fn into_result(self) -> Result<Channel<Req, Resp, C>, NoneError> {
match self {
NewConnection::Filtered => Err(NoneError),
NewConnection::Accepted(channel) => Ok(channel),
}
}
fn from_error(_: NoneError) -> Self {
NewConnection::Filtered
}
fn from_ok(channel: Channel<Req, Resp, C>) -> Self {
NewConnection::Accepted(channel)
}
}
impl<S, Req, Resp> ConnectionFilter<S, Req, Resp> {
unsafe_pinned!(open_connections: usize);
unsafe_pinned!(config: Config);
unsafe_pinned!(connections_per_ip: FnvHashMap<IpAddr, usize>);
unsafe_pinned!(closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>);
unsafe_pinned!(listener: Fuse<S>);
/// Sheds new connections to stay under configured limits.
pub fn filter<C>(listener: S, config: Config) -> Self
where
S: Stream<Item = Result<C, io::Error>>,
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
let (closed_connections, closed_connections_rx) = mpsc::unbounded();
ConnectionFilter {
listener: listener.fuse(),
closed_connections,
closed_connections_rx,
config,
connections_per_ip: FnvHashMap::default(),
open_connections: 0,
ghost: PhantomData,
}
}
fn handle_new_connection<C>(self: &mut Pin<&mut Self>, stream: C) -> NewConnection<Req, Resp, C>
where
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
let peer = match stream.peer_addr() {
Ok(peer) => peer,
Err(e) => {
warn!("Could not get peer_addr of new connection: {}", e);
return NewConnection::Filtered;
}
};
let open_connections = *self.open_connections();
if open_connections >= self.config().max_connections {
warn!(
"[{}] Shedding connection because the maximum open connections \
limit is reached ({}/{}).",
peer,
open_connections,
self.config().max_connections
);
return NewConnection::Filtered;
}
let config = self.config.clone();
let open_connections_for_ip = self.increment_connections_for_ip(&peer)?;
*self.open_connections() += 1;
debug!(
"[{}] Opening channel ({}/{} connections for IP, {} total).",
peer,
open_connections_for_ip,
config.max_connections_per_ip,
self.open_connections(),
);
NewConnection::Accepted(Channel {
client_addr: peer,
closed_connections: self.closed_connections.clone(),
transport: stream.fuse(),
config,
ghost: PhantomData,
})
}
fn handle_closed_connection(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
*self.open_connections() -= 1;
debug!(
"[{}] Closing channel. {} open connections remaining.",
addr, self.open_connections
);
self.decrement_connections_for_ip(&addr);
self.connections_per_ip().compact(0.1);
}
fn increment_connections_for_ip(self: &mut Pin<&mut Self>, peer: &SocketAddr) -> Option<usize> {
let max_connections_per_ip = self.config().max_connections_per_ip;
let mut occupied;
let mut connections_per_ip = self.connections_per_ip();
let occupied = match connections_per_ip.entry(peer.ip()) {
Entry::Vacant(vacant) => vacant.insert(0),
Entry::Occupied(o) => {
if *o.get() < max_connections_per_ip {
// Store the reference outside the block to extend the lifetime.
occupied = o;
occupied.get_mut()
} else {
info!(
"[{}] Opened max connections from IP ({}/{}).",
peer,
o.get(),
max_connections_per_ip
);
return None;
}
}
};
*occupied += 1;
Some(*occupied)
}
fn decrement_connections_for_ip(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
let should_compact = match self.connections_per_ip().entry(addr.ip()) {
Entry::Vacant(_) => {
error!("[{}] Got vacant entry when closing connection.", addr);
return;
}
Entry::Occupied(mut occupied) => {
*occupied.get_mut() -= 1;
if *occupied.get() == 0 {
occupied.remove();
true
} else {
false
}
}
};
if should_compact {
self.connections_per_ip().compact(0.1);
}
}
fn poll_listener<C>(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<NewConnection<Req, Resp, C>>>>
where
S: Stream<Item = Result<C, io::Error>>,
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
match ready!(self.listener().poll_next_unpin(cx)?) {
Some(codec) => Poll::Ready(Some(Ok(self.handle_new_connection(codec)))),
None => Poll::Ready(None),
}
}
fn poll_closed_connections(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
match ready!(self.closed_connections_rx().poll_next_unpin(cx)) {
Some(addr) => {
self.handle_closed_connection(&addr);
Poll::Ready(Ok(()))
}
None => unreachable!("Holding a copy of closed_connections and didn't close it."),
}
}
}
impl<S, Req, Resp, T> Stream for ConnectionFilter<S, Req, Resp>
where
S: Stream<Item = Result<T, io::Error>>,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
type Item = io::Result<Channel<Req, Resp, T>>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<Channel<Req, Resp, T>>>> {
loop {
match (self.poll_listener(cx)?, self.poll_closed_connections(cx)?) {
(Poll::Ready(Some(NewConnection::Accepted(channel))), _) => {
return Poll::Ready(Some(Ok(channel)))
}
(Poll::Ready(Some(NewConnection::Filtered)), _) | (_, Poll::Ready(())) => {
trace!("Filtered a connection; {} open.", self.open_connections());
continue;
}
(Poll::Pending, Poll::Pending) => return Poll::Pending,
(Poll::Ready(None), Poll::Pending) => {
if *self.open_connections() > 0 {
trace!(
"Listener closed; {} open connections.",
self.open_connections()
);
return Poll::Pending;
}
trace!("Shutting down listener: all connections closed, and no more coming.");
return Poll::Ready(None);
}
}
}
}
}

View File

@@ -1,605 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a server that concurrently handles many connections sending multiplexed requests.
use crate::{
context::Context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage,
ClientMessageKind, Request, Response, ServerError, Transport,
};
use fnv::FnvHashMap;
use futures::{
channel::mpsc,
future::{abortable, AbortHandle},
prelude::*,
ready,
stream::Fuse,
task::{LocalWaker, Poll},
try_ready,
};
use humantime::format_rfc3339;
use log::{debug, error, info, trace, warn};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{
error::Error as StdError,
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
time::{Instant, SystemTime},
};
use tokio_timer::timeout;
use trace::{self, TraceId};
mod filter;
/// Manages clients, serving multiplexed requests over each connection.
#[derive(Debug)]
pub struct Server<Req, Resp> {
config: Config,
ghost: PhantomData<(Req, Resp)>,
}
/// Settings that control the behavior of the server.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
/// The maximum number of clients that can be connected to the server at once. When at the
/// limit, existing connections are honored and new connections are rejected.
pub max_connections: usize,
/// The maximum number of clients per IP address that can be connected to the server at once.
/// When an IP is at the limit, existing connections are honored and new connections on that IP
/// address are rejected.
pub max_connections_per_ip: usize,
/// The maximum number of requests that can be in flight for each client. When a client is at
/// the in-flight request limit, existing requests are fulfilled and new requests are rejected.
/// Rejected requests are sent a response error.
pub max_in_flight_requests_per_connection: usize,
/// The number of responses per client that can be buffered server-side before being sent.
/// `pending_response_buffer` controls the buffer size of the channel that a server's
/// response tasks use to send responses to the client handler task.
pub pending_response_buffer: usize,
}
impl Default for Config {
fn default() -> Self {
Config {
max_connections: 1_000_000,
max_connections_per_ip: 1_000,
max_in_flight_requests_per_connection: 1_000,
pending_response_buffer: 100,
}
}
}
impl<Req, Resp> Server<Req, Resp> {
/// Returns a new server with configuration specified `config`.
pub fn new(config: Config) -> Self {
Server {
config,
ghost: PhantomData,
}
}
/// Returns the config for this server.
pub fn config(&self) -> &Config {
&self.config
}
/// Returns a stream of the incoming connections to the server.
pub fn incoming<S, T>(
self,
listener: S,
) -> impl Stream<Item = io::Result<Channel<Req, Resp, T>>>
where
Req: Send,
Resp: Send,
S: Stream<Item = io::Result<T>>,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
self::filter::ConnectionFilter::filter(listener, self.config.clone())
}
}
/// The future driving the server.
#[derive(Debug)]
pub struct Running<S, F> {
incoming: S,
request_handler: F,
}
impl<S, F> Running<S, F> {
unsafe_pinned!(incoming: S);
unsafe_unpinned!(request_handler: F);
}
impl<S, T, Req, Resp, F, Fut> Future for Running<S, F>
where
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send + 'static,
F: FnMut(Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<()> {
while let Some(channel) = ready!(self.incoming().poll_next(cx)) {
match channel {
Ok(channel) => {
let peer = channel.client_addr;
if let Err(e) = crate::spawn(channel.respond_with(self.request_handler().clone()))
{
warn!("[{}] Failed to spawn connection handler: {:?}", peer, e);
}
}
Err(e) => {
warn!("Incoming connection error: {}", e);
}
}
}
info!("Server shutting down.");
return Poll::Ready(());
}
}
/// A utility trait enabling a stream to fluently chain a request handler.
pub trait Handler<T, Req, Resp>
where
Self: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send,
Resp: Send,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
/// Responds to all requests with `request_handler`.
fn respond_with<F, Fut>(self, request_handler: F) -> Running<Self, F>
where
F: FnMut(Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
Running {
incoming: self,
request_handler,
}
}
}
impl<T, Req, Resp, S> Handler<T, Req, Resp> for S
where
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send,
Resp: Send,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{}
/// Responds to all requests with `request_handler`.
/// The server end of an open connection with a client.
#[derive(Debug)]
pub struct Channel<Req, Resp, T> {
/// Writes responses to the wire and reads requests off the wire.
transport: Fuse<T>,
/// Signals the connection is closed when `Channel` is dropped.
closed_connections: mpsc::UnboundedSender<SocketAddr>,
/// Channel limits to prevent unlimited resource usage.
config: Config,
/// The address of the server connected to.
client_addr: SocketAddr,
/// Types the request and response.
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp, T> Drop for Channel<Req, Resp, T> {
fn drop(&mut self) {
trace!("[{}] Closing channel.", self.client_addr);
// Even in a bounded channel, each connection would have a guaranteed slot, so using
// an unbounded sender is actually no different. And, the bound is on the maximum number
// of open connections.
if self
.closed_connections
.unbounded_send(self.client_addr)
.is_err()
{
warn!(
"[{}] Failed to send closed connection message.",
self.client_addr
);
}
}
}
impl<Req, Resp, T> Channel<Req, Resp, T> {
unsafe_pinned!(transport: Fuse<T>);
}
impl<Req, Resp, T> Channel<Req, Resp, T>
where
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
Req: Send,
Resp: Send,
{
pub(crate) fn start_send(self: &mut Pin<&mut Self>, response: Response<Resp>) -> io::Result<()> {
self.transport().start_send(response)
}
pub(crate) fn poll_ready(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
self.transport().poll_ready(cx)
}
pub(crate) fn poll_flush(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
self.transport().poll_flush(cx)
}
pub(crate) fn poll_next(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<ClientMessage<Req>>>> {
self.transport().poll_next(cx)
}
/// Returns the address of the client connected to the channel.
pub fn client_addr(&self) -> &SocketAddr {
&self.client_addr
}
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
/// responses and resolves when the connection is closed.
pub fn respond_with<F, Fut>(self, f: F) -> impl Future<Output = ()>
where
F: FnMut(Context, Req) -> Fut + Send + 'static,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
Req: 'static,
Resp: 'static,
{
let (responses_tx, responses) = mpsc::channel(self.config.pending_response_buffer);
let responses = responses.fuse();
let peer = self.client_addr;
ClientHandler {
channel: self,
f,
pending_responses: responses,
responses_tx,
in_flight_requests: FnvHashMap::default(),
}.unwrap_or_else(move |e| {
info!("[{}] ClientHandler errored out: {}", peer, e);
})
}
}
#[derive(Debug)]
struct ClientHandler<Req, Resp, T, F> {
channel: Channel<Req, Resp, T>,
/// Responses waiting to be written to the wire.
pending_responses: Fuse<mpsc::Receiver<(Context, Response<Resp>)>>,
/// Handed out to request handlers to fan in responses.
responses_tx: mpsc::Sender<(Context, Response<Resp>)>,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
/// Request handler.
f: F,
}
impl<Req, Resp, T, F> ClientHandler<Req, Resp, T, F> {
unsafe_pinned!(channel: Channel<Req, Resp, T>);
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(Context, Response<Resp>)>>);
unsafe_pinned!(responses_tx: mpsc::Sender<(Context, Response<Resp>)>);
// For this to be safe, field f must be private, and code in this module must never
// construct PinMut<F>.
unsafe_unpinned!(f: F);
}
impl<Req, Resp, T, F, Fut> ClientHandler<Req, Resp, T, F>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
F: FnMut(Context, Req) -> Fut + Send + 'static,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
/// If at max in-flight requests, check that there's room to immediately write a throttled
/// response.
fn poll_ready_if_throttling(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
if self.in_flight_requests.len()
>= self.channel.config.max_in_flight_requests_per_connection
{
let peer = self.channel().client_addr;
while let Poll::Pending = self.channel().poll_ready(cx)? {
info!(
"[{}] In-flight requests at max ({}), and transport is not ready.",
peer,
self.in_flight_requests().len(),
);
try_ready!(self.channel().poll_flush(cx));
}
}
Poll::Ready(Ok(()))
}
fn pump_read(self: &mut Pin<&mut Self>, cx: &LocalWaker) -> Poll<Option<io::Result<()>>> {
ready!(self.poll_ready_if_throttling(cx)?);
Poll::Ready(match ready!(self.channel().poll_next(cx)?) {
Some(message) => {
match message.message {
ClientMessageKind::Request(request) => {
self.handle_request(message.trace_context, request)?;
}
ClientMessageKind::Cancel { request_id } => {
self.cancel_request(&message.trace_context, request_id);
}
}
Some(Ok(()))
}
None => {
trace!("[{}] Read half closed", self.channel.client_addr);
None
}
})
}
fn pump_write(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
read_half_closed: bool,
) -> Poll<Option<io::Result<()>>> {
match self.poll_next_response(cx)? {
Poll::Ready(Some((_, response))) => {
self.channel().start_send(response)?;
Poll::Ready(Some(Ok(())))
}
Poll::Ready(None) => {
// Shutdown can't be done before we finish pumping out remaining responses.
ready!(self.channel().poll_flush(cx)?);
Poll::Ready(None)
}
Poll::Pending => {
// No more requests to process, so flush any requests buffered in the transport.
ready!(self.channel().poll_flush(cx)?);
// Being here means there are no staged requests and all written responses are
// fully flushed. So, if the read half is closed and there are no in-flight
// requests, then we can close the write half.
if read_half_closed && self.in_flight_requests().is_empty() {
Poll::Ready(None)
} else {
Poll::Pending
}
}
}
}
fn poll_next_response(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<(Context, Response<Resp>)>>> {
// Ensure there's room to write a response.
while let Poll::Pending = self.channel().poll_ready(cx)? {
ready!(self.channel().poll_flush(cx)?);
}
let peer = self.channel().client_addr;
match ready!(self.pending_responses().poll_next(cx)) {
Some((ctx, response)) => {
if let Some(_) = self.in_flight_requests().remove(&response.request_id) {
self.in_flight_requests().compact(0.1);
}
trace!(
"[{}/{}] Staging response. In-flight requests = {}.",
ctx.trace_id(),
peer,
self.in_flight_requests().len(),
);
return Poll::Ready(Some(Ok((ctx, response))));
}
None => {
// This branch likely won't happen, since the ClientHandler is holding a Sender.
trace!("[{}] No new responses.", peer);
Poll::Ready(None)
}
}
}
fn handle_request(
self: &mut Pin<&mut Self>,
trace_context: trace::Context,
request: Request<Req>,
) -> io::Result<()> {
let request_id = request.id;
let peer = self.channel().client_addr;
let ctx = Context {
deadline: request.deadline,
trace_context,
};
let request = request.message;
if self.in_flight_requests().len()
>= self.channel().config.max_in_flight_requests_per_connection
{
debug!(
"[{}/{}] Client has reached in-flight request limit ({}/{}).",
ctx.trace_id(),
peer,
self.in_flight_requests().len(),
self.channel().config.max_in_flight_requests_per_connection
);
self.channel().start_send(Response {
request_id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
}),
})?;
return Ok(());
}
let deadline = ctx.deadline;
let timeout = deadline.as_duration();
trace!(
"[{}/{}] Received request with deadline {} (timeout {:?}).",
ctx.trace_id(),
peer,
format_rfc3339(deadline),
timeout,
);
let mut response_tx = self.responses_tx().clone();
let trace_id = *ctx.trace_id();
let response = self.f()(ctx.clone(), request);
let response = deadline_compat::Deadline::new(response, Instant::now() + timeout).then(
async move |result| {
let response = Response {
request_id,
message: match result {
Ok(message) => Ok(message),
Err(e) => Err(make_server_error(e, trace_id, peer, deadline)),
},
};
trace!("[{}/{}] Sending response.", trace_id, peer);
await!(response_tx.send((ctx, response)).unwrap_or_else(|_| ()));
},
);
let (abortable_response, abort_handle) = abortable(response);
crate::spawn(abortable_response.map(|_| ()))
.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
"Could not spawn response task. Is shutdown: {}",
e.is_shutdown()
),
)
})?;
self.in_flight_requests().insert(request_id, abort_handle);
Ok(())
}
fn cancel_request(self: &mut Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self.in_flight_requests().remove(&request_id) {
self.in_flight_requests().compact(0.1);
cancel_handle.abort();
let remaining = self.in_flight_requests().len();
trace!(
"[{}/{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
self.channel.client_addr,
remaining,
);
} else {
trace!(
"[{}/{}] Received cancellation, but response handler \
is already complete.",
trace_context.trace_id,
self.channel.client_addr
);
}
}
}
impl<Req, Resp, T, F, Fut> Future for ClientHandler<Req, Resp, T, F>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
F: FnMut(Context, Req) -> Fut + Send + 'static,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<io::Result<()>> {
trace!("[{}] ClientHandler::poll", self.channel.client_addr);
loop {
let read = self.pump_read(cx)?;
match (read, self.pump_write(cx, read == Poll::Ready(None))?) {
(Poll::Ready(None), Poll::Ready(None)) => {
info!("[{}] Client disconnected.", self.channel.client_addr);
return Poll::Ready(Ok(()));
}
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
trace!(
"[{}] read: {:?}, write: {:?}.",
self.channel.client_addr,
read,
write
)
}
(read, write) => {
trace!(
"[{}] read: {:?}, write: {:?} (not ready).",
self.channel.client_addr,
read,
write,
);
return Poll::Pending;
}
}
}
}
}
fn make_server_error(
e: timeout::Error<io::Error>,
trace_id: TraceId,
peer: SocketAddr,
deadline: SystemTime,
) -> ServerError {
if e.is_elapsed() {
debug!(
"[{}/{}] Response did not complete before deadline of {}s.",
trace_id,
peer,
format_rfc3339(deadline)
);
// No point in responding, since the client will have dropped the request.
ServerError {
kind: io::ErrorKind::TimedOut,
detail: Some(format!(
"Response did not complete before deadline of {}s.",
format_rfc3339(deadline)
)),
}
} else if e.is_timer() {
error!(
"[{}/{}] Response failed because of an issue with a timer: {}",
trace_id, peer, e
);
ServerError {
kind: io::ErrorKind::Other,
detail: Some(format!("{}", e)),
}
} else if e.is_inner() {
let e = e.into_inner().unwrap();
ServerError {
kind: e.kind(),
detail: Some(e.description().into()),
}
} else {
error!("[{}/{}] Unexpected response failure: {}", trace_id, peer, e);
ServerError {
kind: io::ErrorKind::Other,
detail: Some(format!("Server unexpectedly failed to respond: {}", e)),
}
}
}

View File

@@ -1,157 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Transports backed by in-memory channels.
use crate::Transport;
use futures::{channel::mpsc, task::{LocalWaker}, Poll, Sink, Stream};
use pin_utils::unsafe_pinned;
use std::pin::Pin;
use std::{
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
};
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
/// [`Sink`].
pub fn unbounded<SinkItem, Item>() -> (
UnboundedChannel<SinkItem, Item>,
UnboundedChannel<Item, SinkItem>,
) {
let (tx1, rx2) = mpsc::unbounded();
let (tx2, rx1) = mpsc::unbounded();
(
UnboundedChannel { tx: tx1, rx: rx1 },
UnboundedChannel { tx: tx2, rx: rx2 },
)
}
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
#[derive(Debug)]
pub struct UnboundedChannel<Item, SinkItem> {
rx: mpsc::UnboundedReceiver<Item>,
tx: mpsc::UnboundedSender<SinkItem>,
}
impl<Item, SinkItem> UnboundedChannel<Item, SinkItem> {
unsafe_pinned!(rx: mpsc::UnboundedReceiver<Item>);
unsafe_pinned!(tx: mpsc::UnboundedSender<SinkItem>);
}
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
type Item = Result<Item, io::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<Option<io::Result<Item>>> {
self.rx().poll_next(cx).map(|option| option.map(Ok))
}
}
impl<Item, SinkItem> Sink for UnboundedChannel<Item, SinkItem> {
type SinkItem = SinkItem;
type SinkError = io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<io::Result<()>> {
self.tx()
.poll_ready(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn start_send(mut self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.tx()
.start_send(item)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Result<(), Self::SinkError>> {
self.tx()
.poll_flush(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<io::Result<()>> {
self.tx()
.poll_close(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
}
impl<Item, SinkItem> Transport for UnboundedChannel<Item, SinkItem> {
type Item = Item;
type SinkItem = SinkItem;
fn peer_addr(&self) -> io::Result<SocketAddr> {
Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
}
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
}
}
#[cfg(test)]
mod tests {
use crate::{client::{self, Client}, context, server::{self, Handler, Server}, transport};
use futures::{prelude::*, stream, compat::TokioDefaultSpawner};
use log::trace;
use std::io;
#[test]
fn integration() {
let _ = env_logger::try_init();
crate::init(TokioDefaultSpawner);
let (client_channel, server_channel) = transport::channel::unbounded();
let server = Server::<String, u64>::new(server::Config::default())
.incoming(stream::once(future::ready(Ok(server_channel))))
.respond_with(|_ctx, request| {
future::ready(request.parse::<u64>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{:?} is not an int", request),
)
}))
});
let responses = async {
let mut client = await!(Client::new(client::Config::default(), client_channel))?;
let response1 = await!(client.call(context::current(), "123".into()));
let response2 = await!(client.call(context::current(), "abc".into()));
Ok::<_, io::Error>((response1, response2))
};
let (response1, response2) =
run_future(server.join(responses.unwrap_or_else(|e| panic!(e)))).1;
trace!("response1: {:?}, response2: {:?}", response1, response2);
assert!(response1.is_ok());
assert_eq!(response1.ok().unwrap(), 123);
assert!(response2.is_err());
assert_eq!(response2.err().unwrap().kind(), io::ErrorKind::InvalidInput);
}
fn run_future<F>(f: F) -> F::Output
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = futures::channel::oneshot::channel();
tokio::run(
f.map(|result| tx.send(result).unwrap_or_else(|_| unreachable!()))
.boxed()
.unit_error()
.compat(),
);
futures::executor::block_on(rx).unwrap()
}
}

View File

@@ -1,32 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a [`Transport`] trait as well as implementations.
//!
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`]
//! can be plugged in, using whatever protocol it wants.
use futures::prelude::*;
use std::{io, net::SocketAddr};
pub mod channel;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport
where
Self: Stream<Item = io::Result<<Self as Transport>::Item>>,
Self: Sink<SinkItem = <Self as Transport>::SinkItem, SinkError = io::Error>,
{
/// The type read off the transport.
type Item;
/// The type written to the transport.
type SinkItem;
/// The address of the remote peer this transport is in communication with.
fn peer_addr(&self) -> io::Result<SocketAddr>;
/// The address of the local half of this transport.
fn local_addr(&self) -> io::Result<SocketAddr>;
}

View File

@@ -1,69 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use futures::{
compat::{Compat01As03, Future01CompatExt},
prelude::*,
ready, task::{Poll, LocalWaker},
};
use pin_utils::unsafe_pinned;
use std::pin::Pin;
use std::time::Instant;
use tokio_timer::{timeout, Delay};
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct Deadline<T> {
future: T,
delay: Compat01As03<Delay>,
}
impl<T> Deadline<T> {
unsafe_pinned!(future: T);
unsafe_pinned!(delay: Compat01As03<Delay>);
/// Create a new `Deadline` that completes when `future` completes or when
/// `deadline` is reached.
pub fn new(future: T, deadline: Instant) -> Deadline<T> {
Deadline::new_with_delay(future, Delay::new(deadline))
}
pub(crate) fn new_with_delay(future: T, delay: Delay) -> Deadline<T> {
Deadline {
future,
delay: delay.compat(),
}
}
/// Gets a mutable reference to the underlying future in this deadline.
pub fn get_mut(&mut self) -> &mut T {
&mut self.future
}
}
impl<T> Future for Deadline<T>
where
T: TryFuture,
{
type Output = Result<T::Ok, timeout::Error<T::Error>>;
fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<Self::Output> {
// First, try polling the future
match self.future().try_poll(waker) {
Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)),
Poll::Pending => {}
Poll::Ready(Err(e)) => return Poll::Ready(Err(timeout::Error::inner(e))),
}
let delay = self.delay().poll_unpin(waker);
// Now check the timer
match ready!(delay) {
Ok(_) => Poll::Ready(Err(timeout::Error::elapsed())),
Err(e) => Poll::Ready(Err(timeout::Error::timer(e))),
}
}
}

View File

@@ -1,46 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use std::{
collections::HashMap,
hash::{BuildHasher, Hash},
time::{Duration, SystemTime},
};
pub mod deadline_compat;
#[cfg(feature = "serde")]
pub mod serde;
/// Types that can be represented by a [`Duration`].
pub trait AsDuration {
fn as_duration(&self) -> Duration;
}
impl AsDuration for SystemTime {
/// Duration of 0 if self is earlier than [`SystemTime::now`].
fn as_duration(&self) -> Duration {
self.duration_since(SystemTime::now()).unwrap_or_default()
}
}
/// Collection compaction; configurable `shrink_to_fit`.
pub trait Compact {
/// Compacts space if the ratio of length : capacity is less than `usage_ratio_threshold`.
fn compact(&mut self, usage_ratio_threshold: f64);
}
impl<K, V, H> Compact for HashMap<K, V, H>
where
K: Eq + Hash,
H: BuildHasher,
{
fn compact(&mut self, usage_ratio_threshold: f64) {
let usage_ratio = self.len() as f64 / self.capacity() as f64;
if usage_ratio < usage_ratio_threshold {
self.shrink_to_fit();
}
}
}

View File

@@ -1,38 +1,119 @@
cargo-features = ["rename-dependency"]
[package]
name = "tarpc"
version = "0.13.0"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018"
version = "0.34.0"
rust-version = "1.58.0"
authors = [
"Adam Wright <adam.austin.wright@gmail.com>",
"Tim Kuehn <timothy.j.kuehn@gmail.com>",
]
edition = "2021"
license = "MIT"
documentation = "https://docs.rs/tarpc"
homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "server", "api", "microservices"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
readme = "README.md"
description = "An RPC framework for Rust with a focus on ease of use."
[features]
serde1 = ["rpc/serde1", "serde", "serde/derive"]
default = []
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"]
tokio1 = ["tokio/rt"]
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
serde-transport-json = ["tokio-serde/json"]
serde-transport-bincode = ["tokio-serde/bincode"]
tcp = ["tokio/net"]
unix = ["tokio/net"]
full = [
"serde1",
"tokio1",
"serde-transport",
"serde-transport-json",
"serde-transport-bincode",
"tcp",
"unix",
]
[badges]
travis-ci = { repository = "google/tarpc" }
[dependencies]
log = "0.4"
serde = { optional = true, version = "1.0" }
tarpc-plugins = { path = "../plugins", version = "0.5.0" }
rpc = { package = "tarpc-lib", path = "../rpc", version = "0.1" }
anyhow = "1.0"
fnv = "1.0"
futures = "0.3"
humantime = "2.0"
pin-project = "1.0"
rand = "0.8"
serde = { optional = true, version = "1.0", features = ["derive"] }
static_assertions = "1.1.0"
tarpc-plugins = { path = "../plugins", version = "0.13" }
thiserror = "1.0"
tokio = { version = "1", features = ["time"] }
tokio-util = { version = "0.7.3", features = ["time"] }
tokio-serde = { optional = true, version = "0.8" }
tracing = { version = "0.1", default-features = false, features = [
"attributes",
"log",
] }
tracing-opentelemetry = { version = "0.18.0", default-features = false }
opentelemetry = { version = "0.18.0", default-features = false }
[target.'cfg(not(test))'.dependencies]
futures-preview = "0.3.0-alpha.8"
[dev-dependencies]
humantime = "1.0"
futures-preview = { version = "0.3.0-alpha.8", features = ["compat", "tokio-compat"] }
bincode-transport = { package = "tarpc-bincode-transport", version = "0.1", path = "../bincode-transport" }
env_logger = "0.5"
tokio = "0.1"
tokio-executor = "0.1"
assert_matches = "1.4"
bincode = "1.3"
bytes = { version = "1", features = ["serde"] }
flate2 = "1.0"
futures-test = "0.3"
opentelemetry = { version = "0.18.0", default-features = false, features = [
"rt-tokio",
] }
opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] }
pin-utils = "0.1.0-alpha"
serde_bytes = "0.11"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio = { version = "1", features = ["full", "test-util", "tracing"] }
console-subscriber = "0.1"
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
trybuild = "1.0"
tokio-rustls = "0.23"
rustls-pemfile = "1.0"
[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
[[example]]
name = "compression"
required-features = ["serde-transport", "tcp"]
[[example]]
name = "tracing"
required-features = ["full"]
[[example]]
name = "readme"
required-features = ["full"]
[[example]]
name = "pubsub"
required-features = ["full"]
[[example]]
name = "custom_transport"
required-features = ["serde1", "tokio1", "serde-transport"]
[[example]]
name = "tls_over_tcp"
required-features = ["full"]
[[test]]
name = "service_functional"
required-features = ["serde-transport"]
[[test]]
name = "dataservice"
required-features = ["serde-transport", "tcp"]

9
tarpc/LICENSE Normal file
View File

@@ -0,0 +1,9 @@
The MIT License (MIT)
Copyright 2016 Google Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

1
tarpc/README.md Symbolic link
View File

@@ -0,0 +1 @@
../README.md

View File

@@ -0,0 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA
NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/
BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O
BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE
fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF
BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137
izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK
-----END CERTIFICATE-----

View File

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
-----END CERTIFICATE-----

View File

@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH
-----END PRIVATE KEY-----

View File

@@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc
RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E
AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow
RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM
EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t
ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF
9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq
amD2TBup4eNUCsQB
-----END CERTIFICATE-----

View File

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
-----END CERTIFICATE-----

View File

@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2
-----END PRIVATE KEY-----

View File

@@ -0,0 +1,138 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use serde_bytes::ByteBuf;
use std::{io, io::Read, io::Write};
use tarpc::{
client, context,
serde_transport::tcp,
server::{BaseChannel, Channel},
tokio_serde::formats::Bincode,
};
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
pub enum CompressionAlgorithm {
Deflate,
}
#[derive(Debug, Deserialize, Serialize)]
pub enum CompressedMessage<T> {
Uncompressed(T),
Compressed {
algorithm: CompressionAlgorithm,
payload: ByteBuf,
},
}
#[derive(Deserialize, Serialize)]
enum CompressionType {
Uncompressed,
Compressed,
}
async fn compress<T>(message: T) -> io::Result<CompressedMessage<T>>
where
T: Serialize,
{
let message = serialize(message)?;
let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&message).unwrap();
let compressed = encoder.finish()?;
Ok(CompressedMessage::Compressed {
algorithm: CompressionAlgorithm::Deflate,
payload: ByteBuf::from(compressed),
})
}
async fn decompress<T>(message: CompressedMessage<T>) -> io::Result<T>
where
for<'a> T: Deserialize<'a>,
{
match message {
CompressedMessage::Compressed { algorithm, payload } => {
if algorithm != CompressionAlgorithm::Deflate {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Compression algorithm {algorithm:?} not supported"),
));
}
let mut deflater = DeflateDecoder::new(payload.as_slice());
let mut payload = ByteBuf::new();
deflater.read_to_end(&mut payload)?;
let message = deserialize(payload)?;
Ok(message)
}
CompressedMessage::Uncompressed(message) => Ok(message),
}
}
fn serialize<T: Serialize>(t: T) -> io::Result<ByteBuf> {
bincode::serialize(&t)
.map(ByteBuf::from)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn deserialize<D>(message: ByteBuf) -> io::Result<D>
where
for<'a> D: Deserialize<'a>,
{
bincode::deserialize(message.as_ref()).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn add_compression<In, Out>(
transport: impl Stream<Item = io::Result<CompressedMessage<In>>>
+ Sink<CompressedMessage<Out>, Error = io::Error>,
) -> impl Stream<Item = io::Result<In>> + Sink<Out, Error = io::Error>
where
Out: Serialize,
for<'a> In: Deserialize<'a>,
{
transport.with(compress).and_then(decompress)
}
#[tarpc::service]
pub trait World {
async fn hello(name: String) -> String;
}
#[derive(Clone, Debug)]
struct HelloServer;
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hey, {name}!")
}
}
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
let addr = incoming.local_addr();
tokio::spawn(async move {
let transport = incoming.next().await.unwrap().unwrap();
BaseChannel::with_defaults(add_compression(transport))
.execute(HelloServer.serve())
.for_each(spawn)
.await;
});
let transport = tcp::connect(addr, Bincode::default).await?;
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn();
println!(
"{}",
client.hello(context::current(), "friend".into()).await?
);
Ok(())
}

View File

@@ -0,0 +1,59 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use futures::prelude::*;
use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
use tokio::net::{UnixListener, UnixStream};
#[tarpc::service]
pub trait PingService {
async fn ping();
}
#[derive(Clone)]
struct Service;
impl PingService for Service {
async fn ping(self, _: Context) {}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
let _ = std::fs::remove_file(bind_addr);
let listener = UnixListener::bind(bind_addr).unwrap();
let codec_builder = LengthDelimitedCodec::builder();
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
tokio::spawn(async move {
loop {
let (conn, _addr) = listener.accept().await.unwrap();
let framed = codec_builder.new_framed(conn);
let transport = transport::new(framed, Bincode::default());
let fut = BaseChannel::with_defaults(transport)
.execute(Service.serve())
.for_each(spawn);
tokio::spawn(fut);
}
});
let conn = UnixStream::connect(bind_addr).await?;
let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
PingServiceClient::new(Default::default(), transport)
.spawn()
.ping(tarpc::context::current())
.await?;
Ok(())
}

View File

@@ -4,188 +4,359 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
arbitrary_self_types,
pin,
futures_api,
await_macro,
async_await,
existential_type,
proc_macro_hygiene,
)]
/// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher"
/// port. Because both publishers and subscribers initiate their connections to the PubSub
/// server, the server requires no prior knowledge of either publishers or subscribers.
///
/// - Subscribers connect to the server on the server's "subscriber" port. Once a connection is
/// established, the server acts as the client of the Subscriber service, initially requesting
/// the topics the subscriber is interested in, and subsequently sending topical messages to the
/// subscriber.
///
/// - Publishers connect to the server on the "publisher" port and, once connected, they send
/// topical messages via Publisher service to the server. The server then broadcasts each
/// messages to all clients subscribed to the topic of that message.
///
/// Subscriber Publisher PubSub Server
/// T1 | | |
/// T2 |-----Connect------------------------------------------------------>|
/// T3 | | |
/// T2 |<-------------------------------------------------------Topics-----|
/// T2 |-----(OK) Topics-------------------------------------------------->|
/// T3 | | |
/// T4 | |-----Connect-------------------->|
/// T5 | | |
/// T6 | |-----Publish-------------------->|
/// T7 | | |
/// T8 |<------------------------------------------------------Receive-----|
/// T9 |-----(OK) Receive------------------------------------------------->|
/// T10 | | |
/// T11 | |<--------------(OK) Publish------|
use anyhow::anyhow;
use futures::{
future::{self, Ready},
channel::oneshot,
future::{self, AbortHandle},
prelude::*,
Future,
};
use rpc::{
client, context,
server::{self, Handler, Server},
};
use publisher::Publisher as _;
use std::{
collections::HashMap,
env,
error::Error,
io,
net::SocketAddr,
sync::{Arc, Mutex},
thread,
time::Duration,
sync::{Arc, Mutex, RwLock},
};
use subscriber::Subscriber as _;
use tarpc::{
client, context,
serde_transport::tcp,
server::{self, Channel},
tokio_serde::formats::Json,
};
use tokio::net::ToSocketAddrs;
use tracing::info;
use tracing_subscriber::prelude::*;
pub mod subscriber {
tarpc::service! {
rpc receive(message: String);
#[tarpc::service]
pub trait Subscriber {
async fn topics() -> Vec<String>;
async fn receive(topic: String, message: String);
}
}
pub mod publisher {
use std::net::SocketAddr;
tarpc::service! {
rpc broadcast(message: String);
rpc subscribe(id: u32, address: SocketAddr) -> Result<(), String>;
rpc unsubscribe(id: u32);
#[tarpc::service]
pub trait Publisher {
async fn publish(topic: String, message: String);
}
}
#[derive(Clone, Debug)]
struct Subscriber {
id: u32,
local_addr: SocketAddr,
topics: Vec<String>,
}
impl subscriber::Service for Subscriber {
type ReceiveFut = Ready<()>;
impl subscriber::Subscriber for Subscriber {
async fn topics(self, _: context::Context) -> Vec<String> {
self.topics.clone()
}
fn receive(&self, _: context::Context, message: String) -> Self::ReceiveFut {
println!("{} received message: {}", self.id, message);
future::ready(())
async fn receive(self, _: context::Context, topic: String, message: String) {
info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage")
}
}
struct SubscriberHandle(AbortHandle);
impl Drop for SubscriberHandle {
fn drop(&mut self) {
self.0.abort();
}
}
impl Subscriber {
async fn listen(id: u32, config: server::Config) -> io::Result<SocketAddr> {
let incoming = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = incoming.local_addr();
tokio_executor::spawn(
Server::new(config)
.incoming(incoming)
.take(1)
.respond_with(subscriber::serve(Subscriber { id }))
.unit_error()
.boxed()
.compat()
);
Ok(addr)
async fn connect(
publisher_addr: impl ToSocketAddrs,
topics: Vec<String>,
) -> anyhow::Result<SubscriberHandle> {
let publisher = tcp::connect(publisher_addr, Json::default).await?;
let local_addr = publisher.local_addr()?;
let mut handler = server::BaseChannel::with_defaults(publisher).requests();
let subscriber = Subscriber { local_addr, topics };
// The first request is for the topics being subscribed to.
match handler.next().await {
Some(init_topics) => init_topics?.execute(subscriber.clone().serve()).await,
None => {
return Err(anyhow!(
"[{}] Server never initialized the subscriber.",
local_addr
))
}
};
let (handler, abort_handle) =
future::abortable(handler.execute(subscriber.serve()).for_each(spawn));
tokio::spawn(async move {
match handler.await {
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
}
});
Ok(SubscriberHandle(abort_handle))
}
}
#[derive(Debug)]
struct Subscription {
topics: Vec<String>,
}
#[derive(Clone, Debug)]
struct Publisher {
clients: Arc<Mutex<HashMap<u32, subscriber::Client>>>,
clients: Arc<Mutex<HashMap<SocketAddr, Subscription>>>,
subscriptions: Arc<RwLock<HashMap<String, HashMap<SocketAddr, subscriber::SubscriberClient>>>>,
}
struct PublisherAddrs {
publisher: SocketAddr,
subscriptions: SocketAddr,
}
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
impl Publisher {
fn new() -> Publisher {
Publisher {
clients: Arc::new(Mutex::new(HashMap::new())),
}
async fn start(self) -> io::Result<PublisherAddrs> {
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
let publisher_addrs = PublisherAddrs {
publisher: connecting_publishers.local_addr(),
subscriptions: self.clone().start_subscription_manager().await?,
};
info!(publisher_addr = %publisher_addrs.publisher, "listening for publishers.",);
tokio::spawn(async move {
// Because this is just an example, we know there will only be one publisher. In more
// realistic code, this would be a loop to continually accept new publisher
// connections.
let publisher = connecting_publishers.next().await.unwrap().unwrap();
info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected.");
server::BaseChannel::with_defaults(publisher)
.execute(self.serve())
.for_each(spawn)
.await
});
Ok(publisher_addrs)
}
}
impl publisher::Service for Publisher {
existential type BroadcastFut: Future<Output = ()>;
async fn start_subscription_manager(mut self) -> io::Result<SocketAddr> {
let mut connecting_subscribers = tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
info!(?new_subscriber_addr, "listening for subscribers.");
fn broadcast(&self, _: context::Context, message: String) -> Self::BroadcastFut {
async fn broadcast(clients: Arc<Mutex<HashMap<u32, subscriber::Client>>>, message: String) {
let mut clients = clients.lock().unwrap().clone();
for client in clients.values_mut() {
// Ignore failing subscribers. In a real pubsub,
// you'd want to continually retry until subscribers
// ack.
let _ = await!(client.receive(context::current(), message.clone()));
tokio::spawn(async move {
while let Some(conn) = connecting_subscribers.next().await {
let subscriber_addr = conn.peer_addr().unwrap();
let tarpc::client::NewClient {
client: subscriber,
dispatch,
} = subscriber::SubscriberClient::new(client::Config::default(), conn);
let (ready_tx, ready) = oneshot::channel();
self.clone()
.start_subscriber_gc(subscriber_addr, dispatch, ready);
// Populate the topics
self.initialize_subscription(subscriber_addr, subscriber)
.await;
// Signal that initialization is done.
ready_tx.send(()).unwrap();
}
});
Ok(new_subscriber_addr)
}
async fn initialize_subscription(
&mut self,
subscriber_addr: SocketAddr,
subscriber: subscriber::SubscriberClient,
) {
// Populate the topics
if let Ok(topics) = subscriber.topics(context::current()).await {
self.clients.lock().unwrap().insert(
subscriber_addr,
Subscription {
topics: topics.clone(),
},
);
info!(%subscriber_addr, ?topics, "subscribed to new topics");
let mut subscriptions = self.subscriptions.write().unwrap();
for topic in topics {
subscriptions
.entry(topic)
.or_insert_with(HashMap::new)
.insert(subscriber_addr, subscriber.clone());
}
}
broadcast(self.clients.clone(), message)
}
existential type SubscribeFut: Future<Output = Result<(), String>>;
fn subscribe(&self, _: context::Context, id: u32, addr: SocketAddr) -> Self::SubscribeFut {
async fn subscribe(
clients: Arc<Mutex<HashMap<u32, subscriber::Client>>>,
id: u32,
addr: SocketAddr,
) -> io::Result<()> {
let conn = await!(bincode_transport::connect(&addr))?;
let subscriber = await!(subscriber::new_stub(client::Config::default(), conn))?;
println!("Subscribing {}.", id);
clients.lock().unwrap().insert(id, subscriber);
Ok(())
}
subscribe(Arc::clone(&self.clients), id, addr).map_err(|e| e.to_string())
}
existential type UnsubscribeFut: Future<Output = ()>;
fn unsubscribe(&self, _: context::Context, id: u32) -> Self::UnsubscribeFut {
println!("Unsubscribing {}", id);
let mut clients = self.clients.lock().unwrap();
if let None = clients.remove(&id) {
eprintln!(
"Client {} not found. Existings clients: {:?}",
id, &*clients
);
}
future::ready(())
fn start_subscriber_gc<E: Error>(
self,
subscriber_addr: SocketAddr,
client_dispatch: impl Future<Output = Result<(), E>> + Send + 'static,
subscriber_ready: oneshot::Receiver<()>,
) {
tokio::spawn(async move {
if let Err(e) = client_dispatch.await {
info!(
%subscriber_addr,
error = %e,
"subscriber connection broken");
}
// Don't clean up the subscriber until initialization is done.
let _ = subscriber_ready.await;
if let Some(subscription) = self.clients.lock().unwrap().remove(&subscriber_addr) {
info!(
"[{} unsubscribing from topics: {:?}",
subscriber_addr, subscription.topics
);
let mut subscriptions = self.subscriptions.write().unwrap();
for topic in subscription.topics {
let subscribers = subscriptions.get_mut(&topic).unwrap();
subscribers.remove(&subscriber_addr);
if subscribers.is_empty() {
subscriptions.remove(&topic);
}
}
}
});
}
}
async fn run() -> io::Result<()> {
env_logger::init();
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let publisher_addr = transport.local_addr();
tokio_executor::spawn(
Server::new(server::Config::default())
.incoming(transport)
.take(1)
.respond_with(publisher::serve(Publisher::new()))
.unit_error()
.boxed()
.compat()
);
let subscriber1 = await!(Subscriber::listen(0, server::Config::default()))?;
let subscriber2 = await!(Subscriber::listen(1, server::Config::default()))?;
let publisher_conn = bincode_transport::connect(&publisher_addr);
let publisher_conn = await!(publisher_conn)?;
let mut publisher = await!(publisher::new_stub(
client::Config::default(),
publisher_conn
))?;
if let Err(e) = await!(publisher.subscribe(context::current(), 0, subscriber1))? {
eprintln!("Couldn't subscribe subscriber 0: {}", e);
}
if let Err(e) = await!(publisher.subscribe(context::current(), 1, subscriber2))? {
eprintln!("Couldn't subscribe subscriber 1: {}", e);
impl publisher::Publisher for Publisher {
async fn publish(self, _: context::Context, topic: String, message: String) {
info!("received message to publish.");
let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) {
None => return,
Some(subscriptions) => subscriptions.clone(),
};
let mut publications = Vec::new();
for client in subscribers.values_mut() {
publications.push(client.receive(context::current(), topic.clone(), message.clone()));
}
// Ignore failing subscribers. In a real pubsub, you'd want to continually retry until
// subscribers ack. Of course, a lot would be different in a real pubsub :)
for response in future::join_all(publications).await {
if let Err(e) = response {
info!("failed to broadcast to subscriber: {}", e);
}
}
}
}
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_service_name(service_name)
.with_max_packet_size(2usize.pow(13))
.install_batch(opentelemetry::runtime::Tokio)?;
tracing_subscriber::registry()
.with(tracing_subscriber::filter::EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer())
.with(tracing_opentelemetry::layer().with_tracer(tracer))
.try_init()?;
println!("Broadcasting...");
await!(publisher.broadcast(context::current(), "hello to all".to_string()))?;
await!(publisher.unsubscribe(context::current(), 1))?;
await!(publisher.broadcast(context::current(), "hi again".to_string()))?;
Ok(())
}
fn main() {
tokio::run(
run()
.boxed()
.map_err(|e| panic!(e))
.boxed()
.compat(),
);
thread::sleep(Duration::from_millis(100));
#[tokio::main]
async fn main() -> anyhow::Result<()> {
init_tracing("Pub/Sub")?;
let addrs = Publisher {
clients: Arc::new(Mutex::new(HashMap::new())),
subscriptions: Arc::new(RwLock::new(HashMap::new())),
}
.start()
.await?;
let _subscriber0 = Subscriber::connect(
addrs.subscriptions,
vec!["calculus".into(), "cool shorts".into()],
)
.await?;
let _subscriber1 = Subscriber::connect(
addrs.subscriptions,
vec!["cool shorts".into(), "history".into()],
)
.await?;
let publisher = publisher::PublisherClient::new(
client::Config::default(),
tcp::connect(addrs.publisher, Json::default).await?,
)
.spawn();
publisher
.publish(context::current(), "calculus".into(), "sqrt(2)".into())
.await?;
publisher
.publish(
context::current(),
"cool shorts".into(),
"hello to all".into(),
)
.await?;
publisher
.publish(context::current(), "history".into(), "napoleon".to_string())
.await?;
drop(_subscriber0);
publisher
.publish(
context::current(),
"cool shorts".into(),
"hello to who?".into(),
)
.await?;
opentelemetry::global::shutdown_tracer_provider();
info!("done.");
Ok(())
}

View File

@@ -4,88 +4,51 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
futures_api,
pin,
arbitrary_self_types,
await_macro,
async_await,
proc_macro_hygiene,
)]
use futures::{
future::{self, Ready},
prelude::*,
};
use rpc::{
use futures::prelude::*;
use tarpc::{
client, context,
server::{self, Handler, Server},
server::{self, Channel},
};
use std::io;
// This is the service definition. It looks a lot like a trait definition.
// It defines one RPC, hello, which takes one arg, name, and returns a String.
tarpc::service! {
rpc hello(name: String) -> String;
/// This is the service definition. It looks a lot like a trait definition.
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
#[tarpc::service]
pub trait World {
async fn hello(name: String) -> String;
}
// This is the type that implements the generated Service trait. It is the business logic
// and is used to start the server.
/// This is the type that implements the generated World trait. It is the business logic
/// and is used to start the server.
#[derive(Clone)]
struct HelloServer;
impl Service for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
// an associated type representing the future output by the fn.
type HelloFut = Ready<String>;
fn hello(&self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
async fn run() -> io::Result<()> {
// bincode_transport is provided by the associated crate bincode-transport. It makes it easy
// to start up a serde-powered bincode serialization strategy over TCP.
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = transport.local_addr();
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
// The server is configured with the defaults.
let server = Server::new(server::Config::default())
// Server can listen on any type that implements the Transport trait.
.incoming(transport)
// Close the stream after the client connects
.take(1)
// serve is generated by the tarpc::service! macro. It takes as input any type implementing
// the generated Service trait.
.respond_with(serve(HelloServer));
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
tokio_executor::spawn(server.unit_error().boxed().compat());
let server = server::BaseChannel::with_defaults(server_transport);
tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn));
let transport = await!(bincode_transport::connect(&addr))?;
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
// that takes a config and any Transport as input.
let client = WorldClient::new(client::Config::default(), client_transport).spawn();
// new_stub is generated by the tarpc::service! macro. Like Server, it takes a config and any
// Transport as input, and returns a Client, also generated by the macro.
// by the service mcro.
let mut client = await!(new_stub(client::Config::default(), transport))?;
// The client has an RPC method for each RPC defined in tarpc::service!. It takes the same args
// as defined, with the addition of a Context, which is always the first arg. The Context
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = await!(client.hello(context::current(), "Stim".to_string()))?;
let hello = client.hello(context::current(), "Stim".to_string()).await?;
println!("{}", hello);
println!("{hello}");
Ok(())
}
fn main() {
tokio::run(
run()
.map_err(|e| eprintln!("Oh no: {}", e))
.boxed()
.compat(),
);
}

View File

@@ -1,111 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
existential_type,
arbitrary_self_types,
pin,
futures_api,
await_macro,
async_await,
proc_macro_hygiene,
)]
use crate::{add::Service as AddService, double::Service as DoubleService};
use futures::{
future::{self, Ready},
prelude::*,
};
use rpc::{
client, context,
server::{self, Handler, Server},
};
use std::io;
pub mod add {
tarpc::service! {
/// Add two ints together.
rpc add(x: i32, y: i32) -> i32;
}
}
pub mod double {
tarpc::service! {
/// 2 * x
rpc double(x: i32) -> Result<i32, String>;
}
}
#[derive(Clone)]
struct AddServer;
impl AddService for AddServer {
type AddFut = Ready<i32>;
fn add(&self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
future::ready(x + y)
}
}
#[derive(Clone)]
struct DoubleServer {
add_client: add::Client,
}
impl DoubleService for DoubleServer {
existential type DoubleFut: Future<Output = Result<i32, String>> + Send;
fn double(&self, _: context::Context, x: i32) -> Self::DoubleFut {
async fn double(mut client: add::Client, x: i32) -> Result<i32, String> {
let result = await!(client.add(context::current(), x, x));
result.map_err(|e| e.to_string())
}
double(self.add_client.clone(), x)
}
}
async fn run() -> io::Result<()> {
let add_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = add_listener.local_addr();
let add_server = Server::new(server::Config::default())
.incoming(add_listener)
.take(1)
.respond_with(add::serve(AddServer));
tokio_executor::spawn(add_server.unit_error().boxed().compat());
let to_add_server = await!(bincode_transport::connect(&addr))?;
let add_client = await!(add::new_stub(client::Config::default(), to_add_server))?;
let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = double_listener.local_addr();
let double_server = rpc::Server::new(server::Config::default())
.incoming(double_listener)
.take(1)
.respond_with(double::serve(DoubleServer { add_client }));
tokio_executor::spawn(double_server.unit_error().boxed().compat());
let to_double_server = await!(bincode_transport::connect(&addr))?;
let mut double_client = await!(double::new_stub(
client::Config::default(),
to_double_server
))?;
for i in 1..=5 {
println!("{:?}", await!(double_client.double(context::current(), i))?);
}
Ok(())
}
fn main() {
env_logger::init();
tokio::run(
run()
.map_err(|e| panic!(e))
.boxed()
.compat(),
);
}

View File

@@ -0,0 +1,150 @@
// Copyright 2023 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use futures::prelude::*;
use rustls_pemfile::certs;
use std::io::{BufReader, Cursor};
use std::net::{IpAddr, Ipv4Addr};
use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio_rustls::rustls::{self, RootCertStore};
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
#[tarpc::service]
pub trait PingService {
async fn ping() -> String;
}
#[derive(Clone)]
struct Service;
impl PingService for Service {
async fn ping(self, _: Context) -> String {
"🔒".to_owned()
}
}
// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
// used on client-side for server tls
const END_CHAIN: &str = include_str!("certs/eddsa/end.chain");
// used on client-side for client-auth
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
// used on server-side for server tls
const END_CERT: &str = include_str!("certs/eddsa/end.cert");
const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
// used on server-side for client-auth
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
pub fn load_certs(data: &str) -> Vec<rustls::Certificate> {
certs(&mut BufReader::new(Cursor::new(data)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect()
}
pub fn load_private_key(key: &str) -> rustls::PrivateKey {
let mut reader = BufReader::new(Cursor::new(key));
loop {
match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key),
Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key),
Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key),
None => break,
_ => {}
}
}
panic!("no keys found in {:?} (encrypted keys not supported)", key);
}
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// -------------------- start here to setup tls tcp tokio stream --------------------------
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
let cert = load_certs(END_CERT);
let key = load_private_key(END_PRIVATEKEY);
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
// ------------- server side client_auth cert loading start
let mut client_auth_roots = RootCertStore::empty();
for root in load_certs(CLIENT_CHAIN_CLIENT_AUTH) {
client_auth_roots.add(&root).unwrap();
}
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
// ------------- server side client_auth cert loading end
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth
.with_single_cert(cert, key)
.unwrap();
let acceptor = TlsAcceptor::from(Arc::new(config));
let listener = TcpListener::bind(&server_addr).await.unwrap();
let codec_builder = LengthDelimitedCodec::builder();
// ref ./custom_transport.rs server side
tokio::spawn(async move {
loop {
let (stream, _peer_addr) = listener.accept().await.unwrap();
let tls_stream = acceptor.accept(stream).await.unwrap();
let framed = codec_builder.new_framed(tls_stream);
let transport = transport::new(framed, Bincode::default());
let fut = BaseChannel::with_defaults(transport)
.execute(Service.serve())
.for_each(spawn);
tokio::spawn(fut);
}
});
// ---------------------- client connection ---------------------
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
let mut root_store = rustls::RootCertStore::empty();
for root in load_certs(END_CHAIN) {
root_store.add(&root).unwrap();
}
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
let client_auth_certs = load_certs(CLIENT_CERT_CLIENT_AUTH);
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_single_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth
let domain = rustls::ServerName::try_from("localhost")?;
let connector = TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(server_addr).await?;
let stream = connector.connect(domain, stream).await?;
let transport = transport::new(codec_builder.new_framed(stream), Bincode::default());
let answer = PingServiceClient::new(Default::default(), transport)
.spawn()
.ping(tarpc::context::current())
.await?;
println!("ping answer: {answer}");
Ok(())
}

190
tarpc/examples/tracing.rs Normal file
View File

@@ -0,0 +1,190 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
add::{Add as AddService, AddStub},
double::Double as DoubleService,
};
use futures::{future, prelude::*};
use std::{
io,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use tarpc::{
client::{
self,
stub::{load_balance, retry},
RpcError,
},
context, serde_transport,
server::{
incoming::{spawn_incoming, Incoming},
request_hook::{self, BeforeRequestList},
BaseChannel,
},
tokio_serde::formats::Json,
ClientMessage, Response, ServerError, Transport,
};
use tokio::net::TcpStream;
use tracing_subscriber::prelude::*;
pub mod add {
#[tarpc::service]
pub trait Add {
/// Add two ints together.
async fn add(x: i32, y: i32) -> i32;
}
}
pub mod double {
#[tarpc::service]
pub trait Double {
/// 2 * x
async fn double(x: i32) -> Result<i32, String>;
}
}
#[derive(Clone)]
struct AddServer;
impl AddService for AddServer {
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
x + y
}
}
#[derive(Clone)]
struct DoubleServer<Stub> {
add_client: add::AddClient<Stub>,
}
impl<Stub> DoubleService for DoubleServer<Stub>
where
Stub: AddStub + Clone + Send + Sync + 'static,
{
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
self.add_client
.add(context::current(), x, x)
.await
.map_err(|e| e.to_string())
}
}
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_service_name(service_name)
.with_auto_split_batch(true)
.with_max_packet_size(2usize.pow(13))
.install_batch(opentelemetry::runtime::Tokio)?;
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer())
.with(tracing_opentelemetry::layer().with_tracer(tracer))
.try_init()?;
Ok(())
}
async fn listen_on_random_port<Item, SinkItem>() -> anyhow::Result<(
impl Stream<Item = serde_transport::Transport<TcpStream, Item, SinkItem, Json<Item, SinkItem>>>,
std::net::SocketAddr,
)>
where
Item: for<'de> serde::Deserialize<'de>,
SinkItem: serde::Serialize,
{
let listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()))
.take(1);
let addr = listener.get_ref().get_ref().local_addr();
Ok((listener, addr))
}
fn make_stub<Req, Resp, const N: usize>(
backends: [impl Transport<ClientMessage<Arc<Req>>, Response<Resp>> + Send + Sync + 'static; N],
) -> retry::Retry<
impl Fn(&Result<Resp, RpcError>, u32) -> bool + Clone,
load_balance::RoundRobin<client::Channel<Arc<Req>, Resp>>,
>
where
Req: Send + Sync + 'static,
Resp: Send + Sync + 'static,
{
let stub = load_balance::RoundRobin::new(
backends
.into_iter()
.map(|transport| tarpc::client::new(client::Config::default(), transport).spawn())
.collect(),
);
let stub = retry::Retry::new(stub, |resp, attempts| {
if let Err(e) = resp {
tracing::warn!("Got an error: {e:?}");
attempts < 3
} else {
false
}
});
stub
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
init_tracing("tarpc_tracing_example")?;
let (add_listener1, addr1) = listen_on_random_port().await?;
let (add_listener2, addr2) = listen_on_random_port().await?;
let something_bad_happened = Arc::new(AtomicBool::new(false));
let server = request_hook::before()
.then_fn(move |_: &mut _, _: &_| {
let something_bad_happened = something_bad_happened.clone();
async move {
if something_bad_happened.fetch_xor(true, Ordering::Relaxed) {
Err(ServerError::new(
io::ErrorKind::NotFound,
"Gamma Ray!".into(),
))
} else {
Ok(())
}
}
})
.serving(AddServer.serve());
let add_server = add_listener1
.chain(add_listener2)
.map(BaseChannel::with_defaults);
tokio::spawn(spawn_incoming(add_server.execute(server)));
let add_client = add::AddClient::from(make_stub([
tarpc::serde_transport::tcp::connect(addr1, Json::default).await?,
tarpc::serde_transport::tcp::connect(addr2, Json::default).await?,
]));
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let addr = double_listener.get_ref().local_addr();
let double_server = double_listener.map(BaseChannel::with_defaults).take(1);
let server = DoubleServer { add_client }.serve();
tokio::spawn(spawn_incoming(double_server.execute(server)));
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
let double_client =
double::DoubleClient::new(client::Config::default(), to_double_server).spawn();
let ctx = context::current();
for _ in 1..=5 {
tracing::info!("{:?}", double_client.double(ctx, 1).await?);
}
opentelemetry::global::shutdown_tracer_provider();
Ok(())
}

View File

@@ -1 +1 @@
edition = "Edition2018"
edition = "2018"

View File

@@ -0,0 +1,49 @@
use futures::{prelude::*, task::*};
use std::pin::Pin;
use tokio::sync::mpsc;
/// Sends request cancellation signals.
#[derive(Debug, Clone)]
pub struct RequestCancellation(mpsc::UnboundedSender<u64>);
/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
pub struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
pub fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests.
let (tx, rx) = mpsc::unbounded_channel();
(RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
/// Cancels the request with ID `request_id`.
///
/// No validation is done of `request_id`. There is no way to know if the request id provided
/// corresponds to a request actually tracked by the backing channel. `RequestCancellation` is
/// a one-way communication channel.
///
/// Once request data is cleaned up, a response will never be received by the client. This is
/// useful primarily when request processing ends prematurely for requests with long deadlines
/// which would otherwise continue to be tracked by the backing channel—a kind of leak.
pub fn cancel(&self, request_id: u64) {
let _ = self.0.send(request_id);
}
}
impl CanceledRequests {
/// Polls for a cancelled request.
pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.0.poll_recv(cx)
}
}
impl Stream for CanceledRequests {
type Item = u64;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.poll_recv(cx)
}
}

1065
tarpc/src/client.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,137 @@
use crate::{
context,
util::{Compact, TimeUntil},
};
use fnv::FnvHashMap;
use std::{
collections::hash_map,
task::{Context, Poll},
};
use tokio::sync::oneshot;
use tokio_util::time::delay_queue::{self, DelayQueue};
use tracing::Span;
/// Requests already written to the wire that haven't yet received responses.
#[derive(Debug)]
pub struct InFlightRequests<Resp> {
request_data: FnvHashMap<u64, RequestData<Resp>>,
deadlines: DelayQueue<u64>,
}
impl<Resp> Default for InFlightRequests<Resp> {
fn default() -> Self {
Self {
request_data: Default::default(),
deadlines: Default::default(),
}
}
}
#[derive(Debug)]
struct RequestData<Res> {
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Res>,
/// The key to remove the timer for the request's deadline.
deadline_key: delay_queue::Key,
}
/// An error returned when an attempt is made to insert a request with an ID that is already in
/// use.
#[derive(Debug)]
pub struct AlreadyExistsError;
impl<Res> InFlightRequests<Res> {
/// Returns the number of in-flight requests.
pub fn len(&self) -> usize {
self.request_data.len()
}
/// Returns true iff there are no requests in flight.
pub fn is_empty(&self) -> bool {
self.request_data.is_empty()
}
/// Starts a request, unless a request with the same ID is already in flight.
pub fn insert_request(
&mut self,
request_id: u64,
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Res>,
) -> Result<(), AlreadyExistsError> {
match self.request_data.entry(request_id) {
hash_map::Entry::Vacant(vacant) => {
let timeout = ctx.deadline.time_until();
let deadline_key = self.deadlines.insert(request_id, timeout);
vacant.insert(RequestData {
ctx,
span,
response_completion,
deadline_key,
});
Ok(())
}
hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
}
}
/// Removes a request without aborting. Returns true iff the request was found.
pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option<Span> {
if let Some(request_data) = self.request_data.remove(&request_id) {
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
let _ = request_data.response_completion.send(result);
return Some(request_data.span);
}
tracing::debug!("No in-flight request found for request_id = {request_id}.");
// If the response completion was absent, then the request was already canceled.
None
}
/// Completes all requests using the provided function.
/// Returns Spans for all completes requests.
pub fn complete_all_requests<'a>(
&'a mut self,
mut result: impl FnMut() -> Res + 'a,
) -> impl Iterator<Item = Span> + 'a {
self.deadlines.clear();
self.request_data.drain().map(move |(_, request_data)| {
let _ = request_data.response_completion.send(result());
request_data.span
})
}
/// Cancels a request without completing (typically used when a request handle was dropped
/// before the request completed).
pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
if let Some(request_data) = self.request_data.remove(&request_id) {
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
Some((request_data.ctx, request_data.span))
} else {
None
}
}
/// Yields a request that has expired, completing it with a TimedOut error.
/// The caller should send cancellation messages for any yielded request ID.
pub fn poll_expired(
&mut self,
cx: &mut Context,
expired_error: impl Fn() -> Res,
) -> Poll<Option<u64>> {
self.deadlines.poll_expired(cx).map(|expired| {
let request_id = expired?.into_inner();
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::error!("DeadlineExceeded");
self.request_data.compact(0.1);
let _ = request_data.response_completion.send(expired_error());
}
Some(request_id)
})
}
}

45
tarpc/src/client/stub.rs Normal file
View File

@@ -0,0 +1,45 @@
//! Provides a Stub trait, implemented by types that can call remote services.
use crate::{
client::{Channel, RpcError},
context,
};
pub mod load_balance;
pub mod retry;
#[cfg(test)]
mod mock;
/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
#[allow(async_fn_in_trait)]
pub trait Stub {
/// The service request type.
type Req;
/// The service response type.
type Resp;
/// Calls a remote service.
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Result<Self::Resp, RpcError>;
}
impl<Req, Resp> Stub for Channel<Req, Resp> {
type Req = Req;
type Resp = Resp;
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Req,
) -> Result<Self::Resp, RpcError> {
Self::call(self, ctx, request_name, request).await
}
}

View File

@@ -0,0 +1,279 @@
//! Provides load-balancing [Stubs](crate::client::stub::Stub).
pub use consistent_hash::ConsistentHash;
pub use round_robin::RoundRobin;
/// Provides a stub that load-balances with a simple round-robin strategy.
mod round_robin {
use crate::{
client::{stub, RpcError},
context,
};
use cycle::AtomicCycle;
impl<Stub> stub::Stub for RoundRobin<Stub>
where
Stub: stub::Stub,
{
type Req = Stub::Req;
type Resp = Stub::Resp;
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let next = self.stubs.next();
next.call(ctx, request_name, request).await
}
}
/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct RoundRobin<Stub> {
stubs: AtomicCycle<Stub>,
}
impl<Stub> RoundRobin<Stub>
where
Stub: stub::Stub,
{
/// Returns a new RoundRobin stub.
pub fn new(stubs: Vec<Stub>) -> Self {
Self {
stubs: AtomicCycle::new(stubs),
}
}
}
mod cycle {
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
/// Cycles endlessly and atomically over a collection of elements of type T.
#[derive(Clone, Debug)]
pub struct AtomicCycle<T>(Arc<State<T>>);
#[derive(Debug)]
struct State<T> {
elements: Vec<T>,
next: AtomicUsize,
}
impl<T> AtomicCycle<T> {
pub fn new(elements: Vec<T>) -> Self {
Self(Arc::new(State {
elements,
next: Default::default(),
}))
}
pub fn next(&self) -> &T {
self.0.next()
}
}
impl<T> State<T> {
pub fn next(&self) -> &T {
let next = self.next.fetch_add(1, Ordering::Relaxed);
&self.elements[next % self.elements.len()]
}
}
#[test]
fn test_cycle() {
let cycle = AtomicCycle::new(vec![1, 2, 3]);
assert_eq!(cycle.next(), &1);
assert_eq!(cycle.next(), &2);
assert_eq!(cycle.next(), &3);
assert_eq!(cycle.next(), &1);
}
}
}
/// Provides a stub that load-balances with a consistent hashing strategy.
///
/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use
/// the same stub.
mod consistent_hash {
use crate::{
client::{stub, RpcError},
context,
};
use std::{
collections::hash_map::RandomState,
hash::{BuildHasher, Hash, Hasher},
num::TryFromIntError,
};
impl<Stub, S> stub::Stub for ConsistentHash<Stub, S>
where
Stub: stub::Stub,
Stub::Req: Hash,
S: BuildHasher,
{
type Req = Stub::Req;
type Resp = Stub::Resp;
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
"invariant broken: stubs_len is not larger than a usize, \
so the hash modulo stubs_len should always fit in a usize",
);
let next = &self.stubs[index];
next.call(ctx, request_name, request).await
}
}
/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct ConsistentHash<Stub, S = RandomState> {
stubs: Vec<Stub>,
stubs_len: u64,
hasher: S,
}
impl<Stub> ConsistentHash<Stub, RandomState>
where
Stub: stub::Stub,
Stub::Req: Hash,
{
/// Returns a new RoundRobin stub.
/// Returns an err if the length of `stubs` overflows a u64.
pub fn new(stubs: Vec<Stub>) -> Result<Self, TryFromIntError> {
Ok(Self {
stubs_len: stubs.len().try_into()?,
stubs,
hasher: RandomState::new(),
})
}
}
impl<Stub, S> ConsistentHash<Stub, S>
where
Stub: stub::Stub,
Stub::Req: Hash,
S: BuildHasher,
{
/// Returns a new RoundRobin stub.
/// Returns an err if the length of `stubs` overflows a u64.
pub fn with_hasher(stubs: Vec<Stub>, hasher: S) -> Result<Self, TryFromIntError> {
Ok(Self {
stubs_len: stubs.len().try_into()?,
stubs,
hasher,
})
}
fn hash_request(&self, req: &Stub::Req) -> u64 {
let mut hasher = self.hasher.build_hasher();
req.hash(&mut hasher);
hasher.finish()
}
}
#[cfg(test)]
mod tests {
use super::ConsistentHash;
use crate::{
client::stub::{mock::Mock, Stub},
context,
};
use std::{
collections::HashMap,
hash::{BuildHasher, Hash, Hasher},
rc::Rc,
};
#[tokio::test]
async fn test() -> anyhow::Result<()> {
let stub = ConsistentHash::<_, FakeHasherBuilder>::with_hasher(
vec![
// For easier reading of the assertions made in this test, each Mock's response
// value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 %
// 3 = 1, etc.
Mock::new([('a', 3), ('b', 3), ('c', 3)]),
Mock::new([('a', 1), ('b', 1), ('c', 1)]),
Mock::new([('a', 2), ('b', 2), ('c', 2)]),
],
FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]),
)?;
for _ in 0..2 {
let resp = stub.call(context::current(), "", 'a').await?;
assert_eq!(resp, 1);
let resp = stub.call(context::current(), "", 'b').await?;
assert_eq!(resp, 2);
let resp = stub.call(context::current(), "", 'c').await?;
assert_eq!(resp, 3);
}
Ok(())
}
struct HashRecorder(Vec<u8>);
impl Hasher for HashRecorder {
fn write(&mut self, bytes: &[u8]) {
self.0 = Vec::from(bytes);
}
fn finish(&self) -> u64 {
0
}
}
struct FakeHasherBuilder {
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
}
struct FakeHasher {
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
output: u64,
}
impl BuildHasher for FakeHasherBuilder {
type Hasher = FakeHasher;
fn build_hasher(&self) -> Self::Hasher {
FakeHasher {
recorded_hashes: self.recorded_hashes.clone(),
output: 0,
}
}
}
impl FakeHasherBuilder {
fn new<T: Hash, const N: usize>(fake_hashes: [(T, u64); N]) -> Self {
let mut recorded_hashes = HashMap::new();
for (to_hash, fake_hash) in fake_hashes {
let mut recorder = HashRecorder(vec![]);
to_hash.hash(&mut recorder);
recorded_hashes.insert(recorder.0, fake_hash);
}
Self {
recorded_hashes: Rc::new(recorded_hashes),
}
}
}
impl Hasher for FakeHasher {
fn write(&mut self, bytes: &[u8]) {
if let Some(hash) = self.recorded_hashes.get(bytes) {
self.output = *hash;
}
}
fn finish(&self) -> u64 {
self.output
}
}
}
}

View File

@@ -0,0 +1,49 @@
use crate::{
client::{stub::Stub, RpcError},
context, ServerError,
};
use std::{collections::HashMap, hash::Hash, io};
/// A mock stub that returns user-specified responses.
pub struct Mock<Req, Resp> {
responses: HashMap<Req, Resp>,
}
impl<Req, Resp> Mock<Req, Resp>
where
Req: Eq + Hash,
{
/// Returns a new mock, mocking the specified (request, response) pairs.
pub fn new<const N: usize>(responses: [(Req, Resp); N]) -> Self {
Self {
responses: HashMap::from(responses),
}
}
}
impl<Req, Resp> Stub for Mock<Req, Resp>
where
Req: Eq + Hash,
Resp: Clone,
{
type Req = Req;
type Resp = Resp;
async fn call(
&self,
_: context::Context,
_: &'static str,
request: Self::Req,
) -> Result<Resp, RpcError> {
self.responses
.get(&request)
.cloned()
.map(Ok)
.unwrap_or_else(|| {
Err(RpcError::Server(ServerError {
kind: io::ErrorKind::NotFound,
detail: "mock (request, response) entry not found".into(),
}))
})
}
}

View File

@@ -0,0 +1,56 @@
//! Provides a stub that retries requests based on response contents..
use crate::{
client::{stub, RpcError},
context,
};
use std::sync::Arc;
impl<Stub, Req, F> stub::Stub for Retry<F, Stub>
where
Stub: stub::Stub<Req = Arc<Req>>,
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
{
type Req = Req;
type Resp = Stub::Resp;
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let request = Arc::new(request);
for i in 1.. {
let result = self
.stub
.call(ctx, request_name, Arc::clone(&request))
.await;
if (self.should_retry)(&result, i) {
tracing::trace!("Retrying on attempt {i}");
continue;
}
return result;
}
unreachable!("Wow, that was a lot of attempts!");
}
}
/// A Stub that retries requests based on response contents.
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
#[derive(Clone, Debug)]
pub struct Retry<F, Stub> {
should_retry: F,
stub: Stub,
}
impl<Stub, Req, F> Retry<F, Stub>
where
Stub: stub::Stub<Req = Arc<Req>>,
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
{
/// Creates a new Retry stub that delegates calls to the underlying `stub`.
pub fn new(stub: Stub, should_retry: F) -> Self {
Self { stub, should_retry }
}
}

152
tarpc/src/context.rs Normal file
View File

@@ -0,0 +1,152 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a request context that carries a deadline and trace context. This context is sent from
//! client to server and is used by the server to enforce response deadlines.
use crate::trace::{self, TraceId};
use opentelemetry::trace::TraceContextExt;
use static_assertions::assert_impl_all;
use std::{
convert::TryFrom,
time::{Duration, SystemTime},
};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// A request context that carries request-scoped information like deadlines and trace information.
/// It is sent from client to server and is used by the server to enforce response deadlines.
///
/// The context should not be stored directly in a server implementation, because the context will
/// be different for each request in scope.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Context {
/// When the client expects the request to be complete by. The server should cancel the request
/// if it is not complete by this time.
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
// Serialized as a Duration to prevent clock skew issues.
#[cfg_attr(feature = "serde1", serde(with = "absolute_to_relative_time"))]
pub deadline: SystemTime,
/// Uniquely identifies requests originating from the same source.
/// When a service handles a request by making requests itself, those requests should
/// include the same `trace_id` as that included on the original request. This way,
/// users can trace related actions across a distributed system.
pub trace_context: trace::Context,
}
#[cfg(feature = "serde1")]
mod absolute_to_relative_time {
pub use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub use std::time::{Duration, SystemTime};
pub fn serialize<S>(deadline: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let deadline = deadline
.duration_since(SystemTime::now())
.unwrap_or(Duration::ZERO);
deadline.serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
where
D: Deserializer<'de>,
{
let deadline = Duration::deserialize(deserializer)?;
Ok(SystemTime::now() + deadline)
}
#[cfg(test)]
#[derive(serde::Serialize, serde::Deserialize)]
struct AbsoluteToRelative(#[serde(with = "self")] SystemTime);
#[test]
fn test_serialize() {
let now = SystemTime::now();
let deadline = now + Duration::from_secs(10);
let serialized_deadline = bincode::serialize(&AbsoluteToRelative(deadline)).unwrap();
let deserialized_deadline: Duration = bincode::deserialize(&serialized_deadline).unwrap();
// TODO: how to avoid flakiness?
assert!(deserialized_deadline > Duration::from_secs(9));
}
#[test]
fn test_deserialize() {
let deadline = Duration::from_secs(10);
let serialized_deadline = bincode::serialize(&deadline).unwrap();
let AbsoluteToRelative(deserialized_deadline) =
bincode::deserialize(&serialized_deadline).unwrap();
// TODO: how to avoid flakiness?
assert!(deserialized_deadline > SystemTime::now() + Duration::from_secs(9));
}
}
assert_impl_all!(Context: Send, Sync);
fn ten_seconds_from_now() -> SystemTime {
SystemTime::now() + Duration::from_secs(10)
}
/// Returns the context for the current request, or a default Context if no request is active.
pub fn current() -> Context {
Context::current()
}
#[derive(Clone)]
struct Deadline(SystemTime);
impl Default for Deadline {
fn default() -> Self {
Self(ten_seconds_from_now())
}
}
impl Context {
/// Returns the context for the current request, or a default Context if no request is active.
pub fn current() -> Self {
let span = tracing::Span::current();
Self {
trace_context: trace::Context::try_from(&span)
.unwrap_or_else(|_| trace::Context::default()),
deadline: span
.context()
.get::<Deadline>()
.cloned()
.unwrap_or_default()
.0,
}
}
/// Returns the ID of the request-scoped trace.
pub fn trace_id(&self) -> &TraceId {
&self.trace_context.trace_id
}
}
/// An extension trait for [`tracing::Span`] for propagating tarpc Contexts.
pub(crate) trait SpanExt {
/// Sets the given context on this span. Newly-created spans will be children of the given
/// context's trace context.
fn set_context(&self, context: &Context);
}
impl SpanExt for tracing::Span {
fn set_context(&self, context: &Context) {
self.set_parent(
opentelemetry::Context::new()
.with_remote_span_context(opentelemetry::trace::SpanContext::new(
opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
opentelemetry::trace::SpanId::from(context.trace_context.span_id),
opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision),
true,
opentelemetry::trace::TraceState::default(),
))
.with_value(Deadline(context.deadline)),
);
}
}

View File

@@ -3,11 +3,14 @@
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! *Disclaimer*: This is not an official Google product.
//!
//! tarpc is an RPC framework for rust with a focus on ease of use. Defining a
//! service can be done in just a few lines of code, and most of the boilerplate of
//! writing a server is taken care of for you.
//!
//! [Documentation](https://docs.rs/crate/tarpc/)
//!
//! ## What is an RPC framework?
//! "RPC" stands for "Remote Procedure Call," a function call where the work of
//! producing the return value is being done somewhere else. When an rpc function is
@@ -21,116 +24,374 @@
//!
//! tarpc differentiates itself from other RPC frameworks by defining the schema in code,
//! rather than in a separate language such as .proto. This means there's no separate compilation
//! process, and no cognitive context switching between different languages. Additionally, it
//! works with the community-backed library serde: any serde-serializable type can be used as
//! arguments to tarpc fns.
//! process, and no context switching between different languages.
//!
//! Some other features of tarpc:
//! - Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
//! used as a transport to connect the client and server.
//! - `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
//! - Cascading cancellation: dropping a request will send a cancellation message to the server.
//! The server will cease any unfinished work on the request, subsequently cancelling any of its
//! own requests, repeating for the entire chain of transitive dependencies.
//! - Configurable deadlines and deadline propagation: request deadlines default to 10s if
//! unspecified. The server will automatically cease work when the deadline has passed. Any
//! requests sent by the server that use the request context will propagate the request deadline.
//! For example, if a server is handling a request with a 10s deadline, does 2s of work, then
//! sends a request to another server, that server will see an 8s deadline.
//! - Distributed tracing: tarpc is instrumented with
//! [tracing](https://github.com/tokio-rs/tracing) primitives extended with
//! [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
//! [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
//! each RPC can be traced through the client, server, and other dependencies downstream of the
//! server. Even for applications not connected to a distributed tracing collector, the
//! instrumentation can also be ingested by regular loggers like
//! [env_logger](https://github.com/env-logger-rs/env_logger/).
//! - Serde serialization: enabling the `serde1` Cargo feature will make service requests and
//! responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
//! be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
//!
//! ## Usage
//! Add to your `Cargo.toml` dependencies:
//!
//! ```toml
//! tarpc = "0.29"
//! ```
//!
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
//! These generated types make it easy and ergonomic to write servers with less boilerplate.
//! Simply implement the generated service trait, and you're off to the races!
//!
//! ## Example
//!
//! Here's a small service.
//! This example uses [tokio](https://tokio.rs), so add the following dependencies to
//! your `Cargo.toml`:
//!
//! ```toml
//! anyhow = "1.0"
//! futures = "0.3"
//! tarpc = { version = "0.29", features = ["tokio1"] }
//! tokio = { version = "1.0", features = ["macros"] }
//! ```
//!
//! In the following example, we use an in-process channel for communication between
//! client and server. In real code, you will likely communicate over the network.
//! For a more real-world example, see [example-service](example-service).
//!
//! First, let's set up the dependencies and service definition.
//!
//! ```rust
//! #![feature(futures_api, pin, arbitrary_self_types, await_macro, async_await, proc_macro_hygiene)]
//!
//! # extern crate futures;
//!
//! use futures::{
//! compat::TokioDefaultSpawner,
//! future::{self, Ready},
//! prelude::*,
//! };
//! use tarpc::{
//! client, context,
//! server::{self, Handler, Server},
//! server::{self, incoming::Incoming, Channel},
//! };
//! use std::io;
//!
//! // This is the service definition. It looks a lot like a trait definition.
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! tarpc::service! {
//! #[tarpc::service]
//! trait World {
//! /// Returns a greeting for name.
//! rpc hello(name: String) -> String;
//! async fn hello(name: String) -> String;
//! }
//! ```
//!
//! // This is the type that implements the generated Service trait. It is the business logic
//! This service definition generates a trait called `World`. Next we need to
//! implement it for our Server struct.
//!
//! ```rust
//! # extern crate futures;
//! # use futures::{
//! # future::{self, Ready},
//! # prelude::*,
//! # };
//! # use tarpc::{
//! # client, context,
//! # server::{self, incoming::Incoming},
//! # };
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! # #[tarpc::service]
//! # trait World {
//! # /// Returns a greeting for name.
//! # async fn hello(name: String) -> String;
//! # }
//! // This is the type that implements the generated World trait. It is the business logic
//! // and is used to start the server.
//! #[derive(Clone)]
//! struct HelloServer;
//!
//! impl Service for HelloServer {
//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
//! // an associated type representing the future output by the fn.
//!
//! type HelloFut = Ready<String>;
//!
//! fn hello(&self, _: context::Context, name: String) -> Self::HelloFut {
//! future::ready(format!("Hello, {}!", name))
//! impl World for HelloServer {
//! // Each defined rpc generates an async fn that serves the RPC
//! async fn hello(self, _: context::Context, name: String) -> String {
//! format!("Hello, {name}!")
//! }
//! }
//! ```
//!
//! async fn run() -> io::Result<()> {
//! // bincode_transport is provided by the associated crate bincode-transport. It makes it easy
//! // to start up a serde-powered bincode serialization strategy over TCP.
//! let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
//! let addr = transport.local_addr();
//! Lastly let's write our `main` that will start the server. While this example uses an
//! [in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`]
//! behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
//! available behind the `tcp` feature.
//!
//! // The server is configured with the defaults.
//! let server = Server::new(server::Config::default())
//! // Server can listen on any type that implements the Transport trait.
//! .incoming(transport)
//! // Close the stream after the client connects
//! .take(1)
//! // serve is generated by the service! macro. It takes as input any type implementing
//! // the generated Service trait.
//! .respond_with(serve(HelloServer));
//! ```rust
//! # extern crate futures;
//! # use futures::{
//! # future::{self, Ready},
//! # prelude::*,
//! # };
//! # use tarpc::{
//! # client, context,
//! # server::{self, Channel},
//! # };
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! # #[tarpc::service]
//! # trait World {
//! # /// Returns a greeting for name.
//! # async fn hello(name: String) -> String;
//! # }
//! # // This is the type that implements the generated World trait. It is the business logic
//! # // and is used to start the server.
//! # #[derive(Clone)]
//! # struct HelloServer;
//! # impl World for HelloServer {
//! // Each defined rpc generates an async fn that serves the RPC
//! # async fn hello(self, _: context::Context, name: String) -> String {
//! # format!("Hello, {name}!")
//! # }
//! # }
//! # #[cfg(not(feature = "tokio1"))]
//! # fn main() {}
//! # #[cfg(feature = "tokio1")]
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
//!
//! tokio_executor::spawn(server.unit_error().boxed().compat());
//! let server = server::BaseChannel::with_defaults(server_transport);
//! tokio::spawn(
//! server.execute(HelloServer.serve())
//! // Handle all requests concurrently.
//! .for_each(|response| async move {
//! tokio::spawn(response);
//! }));
//!
//! let transport = await!(bincode_transport::connect(&addr))?;
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
//! // that takes a config and any Transport as input.
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
//!
//! // new_stub is generated by the service! macro. Like Server, it takes a config and any
//! // Transport as input, and returns a Client, also generated by the macro.
//! // by the service mcro.
//! let mut client = await!(new_stub(client::Config::default(), transport))?;
//!
//! // The client has an RPC method for each RPC defined in service!. It takes the same args
//! // as defined, with the addition of a Context, which is always the first arg. The Context
//! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
//! // args as defined, with the addition of a Context, which is always the first arg. The Context
//! // specifies a deadline and trace information which can be helpful in debugging requests.
//! let hello = await!(client.hello(context::current(), "Stim".to_string()))?;
//! let hello = client.hello(context::current(), "Stim".to_string()).await?;
//!
//! println!("{}", hello);
//! println!("{hello}");
//!
//! Ok(())
//! }
//!
//! fn main() {
//! tarpc::init(TokioDefaultSpawner);
//! tokio::run(run()
//! .map_err(|e| eprintln!("Oh no: {}", e))
//! .boxed()
//! .compat(),
//! );
//! }
//! ```
//!
//! ## Service Documentation
//!
//! Use `cargo doc` as you normally would to see the documentation created for all
//! items expanded by a `service!` invocation.
#![deny(missing_docs, missing_debug_implementations)]
#![feature(
futures_api,
pin,
await_macro,
async_await,
decl_macro,
)]
#![cfg_attr(test, feature(proc_macro_hygiene, arbitrary_self_types))]
#![deny(missing_docs)]
#![allow(clippy::type_complexity)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#[doc(hidden)]
pub use futures;
pub use rpc::*;
#[cfg(feature = "serde")]
#[cfg(feature = "serde1")]
#[doc(hidden)]
pub use serde;
#[doc(hidden)]
pub use tarpc_plugins::*;
/// Provides the macro used for constructing rpc services and client stubs.
#[macro_use]
mod macros;
#[cfg(feature = "serde-transport")]
pub use {tokio_serde, tokio_util};
#[cfg(feature = "serde-transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-transport")))]
pub mod serde_transport;
pub mod trace;
#[cfg(feature = "serde1")]
pub use tarpc_plugins::derive_serde;
/// The main macro that creates RPC services.
///
/// Rpc methods are specified, mirroring trait syntax:
///
/// ```
/// #[tarpc::service]
/// trait Service {
/// /// Say hello
/// async fn hello(name: String) -> String;
/// }
/// ```
///
/// Attributes can be attached to each rpc. These attributes
/// will then be attached to the generated service traits'
/// corresponding `fn`s, as well as to the client stubs' RPCs.
///
/// The following items are expanded in the enclosing module:
///
/// * `trait Service` -- defines the RPC service.
/// * `fn serve` -- turns a service impl into a request handler.
/// * `Client` -- a client stub with a fn for each RPC.
/// * `fn new_stub` -- creates a new Client stub.
pub use tarpc_plugins::service;
pub(crate) mod cancellations;
pub mod client;
pub mod context;
pub mod server;
pub mod transport;
pub(crate) mod util;
pub use crate::transport::sealed::Transport;
use anyhow::Context as _;
use futures::task::*;
use std::sync::Arc;
use std::{error::Error, fmt::Display, io, time::SystemTime};
/// A message from a client to a server.
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ClientMessage<T> {
/// A request initiated by a user. The server responds to a request by invoking a
/// service-provided request handler. The handler completes with a [`response`](Response), which
/// the server sends back to the client.
Request(Request<T>),
/// A command to cancel an in-flight request, automatically sent by the client when a response
/// future is dropped.
///
/// When received, the server will immediately cancel the main task (top-level future) of the
/// request handler for the associated request. Any tasks spawned by the request handler will
/// not be canceled, because the framework layer does not
/// know about them.
Cancel {
/// The trace context associates the message with a specific chain of causally-related actions,
/// possibly orchestrated across many distributed systems.
#[cfg_attr(feature = "serde1", serde(default))]
trace_context: trace::Context,
/// The ID of the request to cancel.
request_id: u64,
},
}
/// A request from a client to a server.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Request<T> {
/// Trace context, deadline, and other cross-cutting concerns.
pub context: context::Context,
/// Uniquely identifies the request across all requests sent over a single channel.
pub id: u64,
/// The request body.
pub message: T,
}
/// A response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Response<T> {
/// The ID of the request being responded to.
pub request_id: u64,
/// The response body, or an error if the request failed.
pub message: Result<T, ServerError>,
}
/// An error indicating the server aborted the request early, e.g., due to request throttling.
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
#[error("{kind:?}: {detail}")]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct ServerError {
#[cfg_attr(
feature = "serde1",
serde(serialize_with = "util::serde::serialize_io_error_kind_as_u32")
)]
#[cfg_attr(
feature = "serde1",
serde(deserialize_with = "util::serde::deserialize_io_error_kind_from_u32")
)]
/// The type of error that occurred to fail the request.
pub kind: io::ErrorKind,
/// A message describing more detail about the error that occurred.
pub detail: String,
}
/// Critical errors that result in a Channel disconnecting.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// Could not read from the transport.
#[error("could not read from the transport")]
Read(#[source] Arc<E>),
/// Could not ready the transport for writes.
#[error("could not ready the transport for writes")]
Ready(#[source] E),
/// Could not write to the transport.
#[error("could not write to the transport")]
Write(#[source] E),
/// Could not flush the transport.
#[error("could not flush the transport")]
Flush(#[source] E),
/// Could not close the write end of the transport.
#[error("could not close the write end of the transport")]
Close(#[source] E),
}
impl ServerError {
/// Returns a new server error with `kind` and `detail`.
pub fn new(kind: io::ErrorKind, detail: String) -> ServerError {
Self { kind, detail }
}
}
impl<T> Request<T> {
/// Returns the deadline for this request.
pub fn deadline(&self) -> &SystemTime {
&self.context.deadline
}
}
pub(crate) trait PollContext<T> {
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
where
C: Display + Send + Sync + 'static;
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
where
C: Display + Send + Sync + 'static,
F: FnOnce() -> C;
}
impl<T, E> PollContext<T> for Poll<Option<Result<T, E>>>
where
E: Error + Send + Sync + 'static,
{
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
where
C: Display + Send + Sync + 'static,
{
self.map(|o| o.map(|r| r.context(context)))
}
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
where
C: Display + Send + Sync + 'static,
F: FnOnce() -> C,
{
self.map(|o| o.map(|r| r.with_context(f)))
}
}

View File

@@ -1,364 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#[cfg(feature = "serde")]
#[doc(hidden)]
#[macro_export]
macro_rules! add_serde_if_enabled {
($(#[$attr:meta])* -- $i:item) => {
$(#[$attr])*
#[derive($crate::serde::Serialize, $crate::serde::Deserialize)]
$i
}
}
#[cfg(not(feature = "serde"))]
#[doc(hidden)]
#[macro_export]
macro_rules! add_serde_if_enabled {
($(#[$attr:meta])* -- $i:item) => {
$(#[$attr])*
$i
}
}
/// The main macro that creates RPC services.
///
/// Rpc methods are specified, mirroring trait syntax:
///
/// ```
/// # #![feature(await_macro, pin, arbitrary_self_types, async_await, futures_api, proc_macro_hygiene)]
/// # fn main() {}
/// # tarpc::service! {
/// /// Say hello
/// rpc hello(name: String) -> String;
/// # }
/// ```
///
/// Attributes can be attached to each rpc. These attributes
/// will then be attached to the generated service traits'
/// corresponding `fn`s, as well as to the client stubs' RPCs.
///
/// The following items are expanded in the enclosing module:
///
/// * `trait Service` -- defines the RPC service.
/// * `fn serve` -- turns a service impl into a request handler.
/// * `Client` -- a client stub with a fn for each RPC.
/// * `fn new_stub` -- creates a new Client stub.
///
#[macro_export]
macro_rules! service {
// Entry point
(
$(
$(#[$attr:meta])*
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) $(-> $out:ty)*;
)*
) => {
$crate::service! {{
$(
$(#[$attr])*
rpc $fn_name( $( $arg : $in_ ),* ) $(-> $out)*;
)*
}}
};
// Pattern for when the next rpc has an implicit unit return type.
(
{
$(#[$attr:meta])*
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* );
$( $unexpanded:tt )*
}
$( $expanded:tt )*
) => {
$crate::service! {
{ $( $unexpanded )* }
$( $expanded )*
$(#[$attr])*
rpc $fn_name( $( $arg : $in_ ),* ) -> ();
}
};
// Pattern for when the next rpc has an explicit return type.
(
{
$(#[$attr:meta])*
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty;
$( $unexpanded:tt )*
}
$( $expanded:tt )*
) => {
$crate::service! {
{ $( $unexpanded )* }
$( $expanded )*
$(#[$attr])*
rpc $fn_name( $( $arg : $in_ ),* ) -> $out;
}
};
// Pattern for when all return types have been expanded
(
{ } // none left to expand
$(
$(#[$attr:meta])*
rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty;
)*
) => {
$crate::add_serde_if_enabled! {
#[derive(Debug)]
#[doc(hidden)]
#[allow(non_camel_case_types, unused)]
--
pub enum Request__ {
$(
$fn_name{ $($arg: $in_,)* }
),*
}
}
$crate::add_serde_if_enabled! {
#[derive(Debug)]
#[doc(hidden)]
#[allow(non_camel_case_types, unused)]
--
pub enum Response__ {
$(
$fn_name($out)
),*
}
}
// TODO: proc_macro can't currently parse $crate, so this needs to be imported for the
// usage of snake_to_camel! to work.
use $crate::futures::Future as Future__;
/// Defines the RPC service. The additional trait bounds are required so that services can
/// multiplex requests across multiple tasks, potentially on multiple threads.
pub trait Service: Clone + Send + 'static {
$(
$crate::snake_to_camel! {
/// The type of future returned by `{}`.
type $fn_name: Future__<Output = $out> + Send;
}
$(#[$attr])*
fn $fn_name(&self, ctx: $crate::context::Context, $($arg:$in_),*) -> $crate::ty_snake_to_camel!(Self::$fn_name);
)*
}
// TODO: use an existential type instead of this when existential types work.
#[allow(non_camel_case_types)]
pub enum Response<S: Service> {
$(
$fn_name($crate::ty_snake_to_camel!(<S as Service>::$fn_name)),
)*
}
impl<S: Service> ::std::fmt::Debug for Response<S> {
fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
fmt.debug_struct("Response").finish()
}
}
impl<S: Service> ::std::future::Future for Response<S> {
type Output = ::std::io::Result<Response__>;
fn poll(self: ::std::pin::Pin<&mut Self>, waker: &::std::task::LocalWaker)
-> ::std::task::Poll<::std::io::Result<Response__>>
{
unsafe {
match ::std::pin::Pin::get_mut_unchecked(self) {
$(
Response::$fn_name(resp) =>
::std::pin::Pin::new_unchecked(resp)
.poll(waker)
.map(Response__::$fn_name)
.map(Ok),
)*
}
}
}
}
/// Returns a serving function to use with rpc::server::Server.
pub fn serve<S: Service>(service: S)
-> impl FnMut($crate::context::Context, Request__) -> Response<S> + Send + 'static + Clone {
move |ctx, req| {
match req {
$(
Request__::$fn_name{ $($arg,)* } => {
let resp = Service::$fn_name(&mut service.clone(), ctx, $($arg),*);
Response::$fn_name(resp)
}
)*
}
}
}
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
pub struct Client($crate::client::Client<Request__, Response__>);
/// Returns a new client stub that sends requests over the given transport.
pub async fn new_stub<T>(config: $crate::client::Config, transport: T)
-> ::std::io::Result<Client>
where
T: $crate::Transport<
Item = $crate::Response<Response__>,
SinkItem = $crate::ClientMessage<Request__>> + Send,
{
Ok(Client(await!($crate::client::Client::new(config, transport))?))
}
impl Client {
$(
#[allow(unused)]
$(#[$attr])*
pub fn $fn_name(&mut self, ctx: $crate::context::Context, $($arg: $in_),*)
-> impl ::std::future::Future<Output = ::std::io::Result<$out>> + '_ {
let request__ = Request__::$fn_name { $($arg,)* };
let resp = self.0.call(ctx, request__);
async move {
match await!(resp)? {
Response__::$fn_name(msg__) => ::std::result::Result::Ok(msg__),
_ => unreachable!(),
}
}
}
)*
}
}
}
// allow dead code; we're just testing that the macro expansion compiles
#[allow(dead_code)]
#[cfg(test)]
mod syntax_test {
service! {
#[deny(warnings)]
#[allow(non_snake_case)]
rpc TestCamelCaseDoesntConflict();
rpc hello() -> String;
#[doc="attr"]
rpc attr(s: String) -> String;
rpc no_args_no_return();
rpc no_args() -> ();
rpc one_arg(foo: String) -> i32;
rpc two_args_no_return(bar: String, baz: u64);
rpc two_args(bar: String, baz: u64) -> String;
rpc no_args_ret_error() -> i32;
rpc one_arg_ret_error(foo: String) -> String;
rpc no_arg_implicit_return_error();
#[doc="attr"]
rpc one_arg_implicit_return_error(foo: String);
}
}
#[cfg(test)]
mod functional_test {
use futures::{
compat::TokioDefaultSpawner,
future::{ready, Ready},
prelude::*,
};
use rpc::{
client, context,
server::{self, Handler},
transport::channel,
};
use std::io;
use tokio::runtime::current_thread;
service! {
rpc add(x: i32, y: i32) -> i32;
rpc hey(name: String) -> String;
}
#[derive(Clone)]
struct Server;
impl Service for Server {
type AddFut = Ready<i32>;
fn add(&self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
ready(x + y)
}
type HeyFut = Ready<String>;
fn hey(&self, _: context::Context, name: String) -> Self::HeyFut {
ready(format!("Hey, {}.", name))
}
}
#[test]
fn sequential() {
let _ = env_logger::try_init();
rpc::init(TokioDefaultSpawner);
let test = async {
let (tx, rx) = channel::unbounded();
tokio_executor::spawn(
rpc::Server::new(server::Config::default())
.incoming(stream::once(ready(Ok(rx))))
.respond_with(serve(Server))
.unit_error()
.boxed()
.compat()
);
let mut client = await!(new_stub(client::Config::default(), tx))?;
assert_eq!(3, await!(client.add(context::current(), 1, 2))?);
assert_eq!(
"Hey, Tim.",
await!(client.hey(context::current(), "Tim".to_string()))?
);
Ok::<_, io::Error>(())
}
.map_err(|e| panic!(e.to_string()));
current_thread::block_on_all(test.boxed().compat()).unwrap();
}
#[test]
fn concurrent() {
let _ = env_logger::try_init();
rpc::init(TokioDefaultSpawner);
let test = async {
let (tx, rx) = channel::unbounded();
tokio_executor::spawn(
rpc::Server::new(server::Config::default())
.incoming(stream::once(ready(Ok(rx))))
.respond_with(serve(Server))
.unit_error()
.boxed()
.compat()
);
let client = await!(new_stub(client::Config::default(), tx))?;
let mut c = client.clone();
let req1 = c.add(context::current(), 1, 2);
let mut c = client.clone();
let req2 = c.add(context::current(), 3, 4);
let mut c = client.clone();
let req3 = c.hey(context::current(), "Tim".to_string());
assert_eq!(3, await!(req1)?);
assert_eq!(7, await!(req2)?);
assert_eq!("Hey, Tim.", await!(req3)?);
Ok::<_, io::Error>(())
}
.map_err(|e| panic!("test failed: {}", e));
current_thread::block_on_all(test.boxed().compat()).unwrap();
}
}

View File

@@ -0,0 +1,672 @@
// Copyright 2019 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! A generic Serde-based `Transport` that can serialize anything supported by `tokio-serde` via any medium that implements `AsyncRead` and `AsyncWrite`.
#![deny(missing_docs)]
use futures::{prelude::*, task::*};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use std::{error::Error, io, pin::Pin};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_serde::{Framed as SerdeFramed, *};
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
/// A transport that serializes to, and deserializes from, a byte stream.
#[pin_project]
pub struct Transport<S, Item, SinkItem, Codec> {
#[pin]
inner: SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>,
}
impl<S, Item, SinkItem, Codec> Transport<S, Item, SinkItem, Codec> {
/// Returns the inner transport over which messages are sent and received.
pub fn get_ref(&self) -> &S {
self.inner.get_ref().get_ref()
}
}
impl<S, Item, SinkItem, Codec, CodecError> Stream for Transport<S, Item, SinkItem, Codec>
where
S: AsyncWrite + AsyncRead,
Item: for<'a> Deserialize<'a>,
Codec: Deserializer<Item>,
CodecError: Into<Box<dyn std::error::Error + Send + Sync>>,
SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>:
Stream<Item = Result<Item, CodecError>>,
{
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
self.project()
.inner
.poll_next(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}
impl<S, Item, SinkItem, Codec, CodecError> Sink<SinkItem> for Transport<S, Item, SinkItem, Codec>
where
S: AsyncWrite,
SinkItem: Serialize,
Codec: Serializer<SinkItem>,
CodecError: Into<Box<dyn Error + Send + Sync>>,
SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>:
Sink<SinkItem, Error = CodecError>,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project()
.inner
.poll_ready(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.project()
.inner
.start_send(item)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project()
.inner
.poll_flush(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project()
.inner
.poll_close(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}
/// Constructs a new transport from a framed transport and a serialization codec.
pub fn new<S, Item, SinkItem, Codec>(
framed_io: Framed<S, LengthDelimitedCodec>,
codec: Codec,
) -> Transport<S, Item, SinkItem, Codec>
where
S: AsyncWrite + AsyncRead,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
{
Transport {
inner: SerdeFramed::new(framed_io, codec),
}
}
impl<S, Item, SinkItem, Codec> From<(S, Codec)> for Transport<S, Item, SinkItem, Codec>
where
S: AsyncWrite + AsyncRead,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
{
fn from((io, codec): (S, Codec)) -> Self {
new(Framed::new(io, LengthDelimitedCodec::new()), codec)
}
}
#[cfg(feature = "tcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
/// TCP support for generic transport using Tokio.
pub mod tcp {
use {
super::*,
futures::ready,
std::{marker::PhantomData, net::SocketAddr},
tokio::net::{TcpListener, TcpStream, ToSocketAddrs},
tokio_util::codec::length_delimited,
};
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
/// Returns the peer address of the underlying TcpStream.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().peer_addr()
}
/// Returns the local address of the underlying TcpStream.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().local_addr()
}
}
/// A connection Future that also exposes the length-delimited framing config.
#[must_use]
#[pin_project]
pub struct Connect<T, Item, SinkItem, CodecFn> {
#[pin]
inner: T,
codec_fn: CodecFn,
config: length_delimited::Builder,
ghost: PhantomData<(fn(SinkItem), fn() -> Item)>,
}
impl<T, Item, SinkItem, Codec, CodecFn> Future for Connect<T, Item, SinkItem, CodecFn>
where
T: Future<Output = io::Result<TcpStream>>,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
type Output = io::Result<Transport<TcpStream, Item, SinkItem, Codec>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let io = ready!(self.as_mut().project().inner.poll(cx))?;
Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)())))
}
}
impl<T, Item, SinkItem, CodecFn> Connect<T, Item, SinkItem, CodecFn> {
/// Returns an immutable reference to the length-delimited codec's config.
pub fn config(&self) -> &length_delimited::Builder {
&self.config
}
/// Returns a mutable reference to the length-delimited codec's config.
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
&mut self.config
}
}
/// Connects to `addr`, wrapping the connection in a TCP transport.
pub fn connect<A, Item, SinkItem, Codec, CodecFn>(
addr: A,
codec_fn: CodecFn,
) -> Connect<impl Future<Output = io::Result<TcpStream>>, Item, SinkItem, CodecFn>
where
A: ToSocketAddrs,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
Connect {
inner: TcpStream::connect(addr),
codec_fn,
config: LengthDelimitedCodec::builder(),
ghost: PhantomData,
}
}
/// Listens on `addr`, wrapping accepted connections in TCP transports.
pub async fn listen<A, Item, SinkItem, Codec, CodecFn>(
addr: A,
codec_fn: CodecFn,
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
where
A: ToSocketAddrs,
Item: for<'de> Deserialize<'de>,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
Ok(Incoming {
listener,
codec_fn,
local_addr,
config: LengthDelimitedCodec::builder(),
ghost: PhantomData,
})
}
/// A [`TcpListener`] that wraps connections in [transports](Transport).
#[pin_project]
#[derive(Debug)]
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
listener: TcpListener,
local_addr: SocketAddr,
codec_fn: CodecFn,
config: length_delimited::Builder,
ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>,
}
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
/// Returns the address being listened on.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
/// Returns an immutable reference to the length-delimited codec's config.
pub fn config(&self) -> &length_delimited::Builder {
&self.config
}
/// Returns a mutable reference to the length-delimited codec's config.
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
&mut self.config
}
}
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
type Item = io::Result<Transport<TcpStream, Item, SinkItem, Codec>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let conn: TcpStream =
ready!(Pin::new(&mut self.as_mut().project().listener).poll_accept(cx)?).0;
Poll::Ready(Some(Ok(new(
self.config.new_framed(conn),
(self.codec_fn)(),
))))
}
}
}
#[cfg(all(unix, feature = "unix"))]
#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "unix"))))]
/// Unix Domain Socket support for generic transport using Tokio.
pub mod unix {
use {
super::*,
futures::ready,
std::{marker::PhantomData, path::Path},
tokio::net::{unix::SocketAddr, UnixListener, UnixStream},
tokio_util::codec::length_delimited,
};
impl<Item, SinkItem, Codec> Transport<UnixStream, Item, SinkItem, Codec> {
/// Returns the socket address of the remote half of the underlying [`UnixStream`].
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().peer_addr()
}
/// Returns the socket address of the local half of the underlying [`UnixStream`].
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().local_addr()
}
}
/// A connection Future that also exposes the length-delimited framing config.
#[must_use]
#[pin_project]
pub struct Connect<T, Item, SinkItem, CodecFn> {
#[pin]
inner: T,
codec_fn: CodecFn,
config: length_delimited::Builder,
ghost: PhantomData<(fn(SinkItem), fn() -> Item)>,
}
impl<T, Item, SinkItem, Codec, CodecFn> Future for Connect<T, Item, SinkItem, CodecFn>
where
T: Future<Output = io::Result<UnixStream>>,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
type Output = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let io = ready!(self.as_mut().project().inner.poll(cx))?;
Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)())))
}
}
impl<T, Item, SinkItem, CodecFn> Connect<T, Item, SinkItem, CodecFn> {
/// Returns an immutable reference to the length-delimited codec's config.
pub fn config(&self) -> &length_delimited::Builder {
&self.config
}
/// Returns a mutable reference to the length-delimited codec's config.
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
&mut self.config
}
}
/// Connects to socket named by `path`, wrapping the connection in a Unix Domain Socket
/// transport.
pub fn connect<P, Item, SinkItem, Codec, CodecFn>(
path: P,
codec_fn: CodecFn,
) -> Connect<impl Future<Output = io::Result<UnixStream>>, Item, SinkItem, CodecFn>
where
P: AsRef<Path>,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
Connect {
inner: UnixStream::connect(path),
codec_fn,
config: LengthDelimitedCodec::builder(),
ghost: PhantomData,
}
}
/// Listens on the socket named by `path`, wrapping accepted connections in Unix Domain Socket
/// transports.
pub async fn listen<P, Item, SinkItem, Codec, CodecFn>(
path: P,
codec_fn: CodecFn,
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
where
P: AsRef<Path>,
Item: for<'de> Deserialize<'de>,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
let listener = UnixListener::bind(path)?;
let local_addr = listener.local_addr()?;
Ok(Incoming {
listener,
codec_fn,
local_addr,
config: LengthDelimitedCodec::builder(),
ghost: PhantomData,
})
}
/// A [`UnixListener`] that wraps connections in [transports](Transport).
#[pin_project]
#[derive(Debug)]
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
listener: UnixListener,
local_addr: SocketAddr,
codec_fn: CodecFn,
config: length_delimited::Builder,
ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>,
}
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
/// Returns the the socket address being listened on.
pub fn local_addr(&self) -> &SocketAddr {
&self.local_addr
}
/// Returns an immutable reference to the length-delimited codec's config.
pub fn config(&self) -> &length_delimited::Builder {
&self.config
}
/// Returns a mutable reference to the length-delimited codec's config.
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
&mut self.config
}
}
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
type Item = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let conn: UnixStream = ready!(self.as_mut().project().listener.poll_accept(cx)?).0;
Poll::Ready(Some(Ok(new(
self.config.new_framed(conn),
(self.codec_fn)(),
))))
}
}
/// A temporary `PathBuf` that lives in `std::env::temp_dir` and is removed on drop.
pub struct TempPathBuf(std::path::PathBuf);
impl TempPathBuf {
/// A named socket that results in `<tempdir>/<name>`
pub fn new<S: AsRef<str>>(name: S) -> Self {
let mut sock = std::env::temp_dir();
sock.push(name.as_ref());
Self(sock)
}
/// Appends a random hex string to the socket name resulting in
/// `<tempdir>/<name>_<xxxxx>`
pub fn with_random<S: AsRef<str>>(name: S) -> Self {
Self::new(format!("{}_{:016x}", name.as_ref(), rand::random::<u64>()))
}
}
impl AsRef<std::path::Path> for TempPathBuf {
fn as_ref(&self) -> &std::path::Path {
self.0.as_path()
}
}
impl Drop for TempPathBuf {
fn drop(&mut self) {
// This will remove the file pointed to by this PathBuf if it exists, however Err's can
// be returned such as attempting to remove a non-existing file, or one which we don't
// have permission to remove. In these cases the Err is swallowed
let _ = std::fs::remove_file(&self.0);
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio_serde::formats::SymmetricalJson;
#[test]
fn temp_path_buf_non_random() {
let sock = TempPathBuf::new("test");
let mut good = std::env::temp_dir();
good.push("test");
assert_eq!(sock.as_ref(), good);
assert_eq!(sock.as_ref().file_name().unwrap(), "test");
}
#[test]
fn temp_path_buf_random() {
let sock = TempPathBuf::with_random("test");
let good = std::env::temp_dir();
assert!(sock.as_ref().starts_with(good));
// Since there are 16 random characters we just assert the file_name has the right name
// and starts with the correct string 'test_'
// file name: test_xxxxxxxxxxxxxxxx
// test = 4
// _ = 1
// <hex> = 16
// total = 21
let fname = sock.as_ref().file_name().unwrap().to_string_lossy();
assert!(fname.starts_with("test_"));
assert_eq!(fname.len(), 21);
}
#[test]
fn temp_path_buf_non_existing() {
let sock = TempPathBuf::with_random("test");
let sock_path = std::path::PathBuf::from(sock.as_ref());
// No actual file has been created yet
assert!(!sock_path.exists());
// Should not panic
std::mem::drop(sock);
assert!(!sock_path.exists());
}
#[test]
fn temp_path_buf_existing_file() {
let sock = TempPathBuf::with_random("test");
let sock_path = std::path::PathBuf::from(sock.as_ref());
let _file = std::fs::File::create(&sock).unwrap();
assert!(sock_path.exists());
std::mem::drop(sock);
assert!(!sock_path.exists());
}
#[test]
fn temp_path_buf_preexisting_file() {
let mut pre_existing = std::env::temp_dir();
pre_existing.push("test");
let _file = std::fs::File::create(&pre_existing).unwrap();
let sock = TempPathBuf::new("test");
let sock_path = std::path::PathBuf::from(sock.as_ref());
assert!(sock_path.exists());
std::mem::drop(sock);
assert!(!sock_path.exists());
}
#[tokio::test]
async fn temp_path_buf_for_socket() {
let sock = TempPathBuf::with_random("test");
// Save path for testing after drop
let sock_path = std::path::PathBuf::from(sock.as_ref());
// create the actual socket
let _ = listen(&sock, SymmetricalJson::<String>::default).await;
assert!(sock_path.exists());
std::mem::drop(sock);
assert!(!sock_path.exists());
}
}
}
#[cfg(test)]
mod tests {
use super::Transport;
use assert_matches::assert_matches;
use futures::{task::*, Sink, Stream};
use pin_utils::pin_mut;
use std::{
io::{self, Cursor},
pin::Pin,
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_serde::formats::SymmetricalJson;
fn ctx() -> Context<'static> {
Context::from_waker(noop_waker_ref())
}
struct TestIo(Cursor<Vec<u8>>);
impl AsyncRead for TestIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
}
}
impl AsyncWrite for TestIo {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
}
}
#[test]
fn close() {
let (tx, _rx) = crate::transport::channel::bounded::<(), ()>(0);
pin_mut!(tx);
assert_matches!(tx.as_mut().poll_close(&mut ctx()), Poll::Ready(Ok(())));
assert_matches!(tx.as_mut().start_send(()), Err(_));
}
#[test]
fn test_stream() {
let data: &[u8] = b"\x00\x00\x00\x18\"Test one, check check.\"";
let transport = Transport::from((
TestIo(Cursor::new(Vec::from(data))),
SymmetricalJson::<String>::default(),
));
pin_mut!(transport);
assert_matches!(
transport.as_mut().poll_next(&mut ctx()),
Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
assert_matches!(transport.as_mut().poll_next(&mut ctx()), Poll::Ready(None));
}
#[test]
fn test_sink() {
let writer = Cursor::new(vec![]);
let mut transport = Box::pin(Transport::from((
TestIo(writer),
SymmetricalJson::<String>::default(),
)));
assert_matches!(
transport.as_mut().poll_ready(&mut ctx()),
Poll::Ready(Ok(()))
);
assert_matches!(
transport
.as_mut()
.start_send("Test one, check check.".into()),
Ok(())
);
assert_matches!(
transport.as_mut().poll_flush(&mut ctx()),
Poll::Ready(Ok(()))
);
assert_eq!(
transport.get_ref().0.get_ref(),
b"\x00\x00\x00\x18\"Test one, check check.\""
);
}
#[cfg(tcp)]
#[tokio::test]
async fn tcp() -> io::Result<()> {
use super::tcp;
let mut listener = tcp::listen("0.0.0.0:0", SymmetricalJson::<String>::default).await?;
let addr = listener.local_addr();
tokio::spawn(async move {
let mut transport = listener.next().await.unwrap().unwrap();
let message = transport.next().await.unwrap().unwrap();
transport.send(message).await.unwrap();
});
let mut transport = tcp::connect(addr, SymmetricalJson::<String>::default).await?;
transport.send(String::from("test")).await?;
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
assert_matches!(transport.next().await, None);
Ok(())
}
#[cfg(all(unix, feature = "unix"))]
#[tokio::test]
async fn uds() -> io::Result<()> {
use super::unix;
use super::*;
let sock = unix::TempPathBuf::with_random("uds");
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
tokio::spawn(async move {
let mut transport = listener.next().await.unwrap().unwrap();
let message = transport.next().await.unwrap().unwrap();
transport.send(message).await.unwrap();
});
let mut transport = unix::connect(&sock, SymmetricalJson::<String>::default).await?;
transport.send(String::from("test")).await?;
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
assert_matches!(transport.next().await, None);
Ok(())
}
}

1592
tarpc/src/server.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,221 @@
use crate::util::{Compact, TimeUntil};
use fnv::FnvHashMap;
use futures::future::{AbortHandle, AbortRegistration};
use std::{
collections::hash_map,
task::{Context, Poll},
time::SystemTime,
};
use tokio_util::time::delay_queue::{self, DelayQueue};
use tracing::Span;
/// A data structure that tracks in-flight requests. It aborts requests,
/// either on demand or when a request deadline expires.
#[derive(Debug, Default)]
pub struct InFlightRequests {
request_data: FnvHashMap<u64, RequestData>,
deadlines: DelayQueue<u64>,
}
/// Data needed to clean up a single in-flight request.
#[derive(Debug)]
struct RequestData {
/// Aborts the response handler for the associated request.
abort_handle: AbortHandle,
/// The key to remove the timer for the request's deadline.
deadline_key: delay_queue::Key,
/// The client span.
span: Span,
}
/// An error returned when a request attempted to start with the same ID as a request already
/// in flight.
#[derive(Debug)]
pub struct AlreadyExistsError;
impl InFlightRequests {
/// Returns the number of in-flight requests.
pub fn len(&self) -> usize {
self.request_data.len()
}
/// Starts a request, unless a request with the same ID is already in flight.
pub fn start_request(
&mut self,
request_id: u64,
deadline: SystemTime,
span: Span,
) -> Result<AbortRegistration, AlreadyExistsError> {
match self.request_data.entry(request_id) {
hash_map::Entry::Vacant(vacant) => {
let timeout = deadline.time_until();
let (abort_handle, abort_registration) = AbortHandle::new_pair();
let deadline_key = self.deadlines.insert(request_id, timeout);
vacant.insert(RequestData {
abort_handle,
deadline_key,
span,
});
Ok(abort_registration)
}
hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
}
}
/// Cancels an in-flight request. Returns true iff the request was found.
pub fn cancel_request(&mut self, request_id: u64) -> bool {
if let Some(RequestData {
span,
abort_handle,
deadline_key,
}) = self.request_data.remove(&request_id)
{
let _entered = span.enter();
self.request_data.compact(0.1);
abort_handle.abort();
self.deadlines.remove(&deadline_key);
tracing::info!("ReceiveCancel");
true
} else {
false
}
}
/// Removes a request without aborting. Returns true iff the request was found.
/// This method should be used when a response is being sent.
pub fn remove_request(&mut self, request_id: u64) -> Option<Span> {
if let Some(request_data) = self.request_data.remove(&request_id) {
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
Some(request_data.span)
} else {
None
}
}
/// Yields a request that has expired, aborting any ongoing processing of that request.
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
if self.deadlines.is_empty() {
// TODO(https://github.com/tokio-rs/tokio/issues/4161)
// This is a workaround for DelayQueue not always treating this case correctly.
return Poll::Ready(None);
}
self.deadlines.poll_expired(cx).map(|expired| {
let expired = expired?;
if let Some(RequestData {
abort_handle, span, ..
}) = self.request_data.remove(expired.get_ref())
{
let _entered = span.enter();
self.request_data.compact(0.1);
abort_handle.abort();
tracing::error!("DeadlineExceeded");
}
Some(expired.into_inner())
})
}
}
/// When InFlightRequests is dropped, any outstanding requests are aborted.
impl Drop for InFlightRequests {
fn drop(&mut self) {
self.request_data
.values()
.for_each(|request_data| request_data.abort_handle.abort())
}
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use futures::{
future::{pending, Abortable},
FutureExt,
};
use futures_test::task::noop_context;
#[tokio::test]
async fn start_request_increases_len() {
let mut in_flight_requests = InFlightRequests::default();
assert_eq!(in_flight_requests.len(), 0);
in_flight_requests
.start_request(0, SystemTime::now(), Span::current())
.unwrap();
assert_eq!(in_flight_requests.len(), 1);
}
#[tokio::test]
async fn polling_expired_aborts() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now(), Span::current())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
tokio::time::pause();
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Ready(Some(_))
);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))
);
assert_eq!(in_flight_requests.len(), 0);
}
#[tokio::test]
async fn cancel_request_aborts() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now(), Span::current())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
assert!(in_flight_requests.cancel_request(0));
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))
);
assert_eq!(in_flight_requests.len(), 0);
}
#[tokio::test]
async fn remove_request_doesnt_abort() {
let mut in_flight_requests = InFlightRequests::default();
assert!(in_flight_requests.deadlines.is_empty());
let abort_registration = in_flight_requests
.start_request(
0,
SystemTime::now() + std::time::Duration::from_secs(10),
Span::current(),
)
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
// Precondition: Pending expiration
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Pending
);
assert!(!in_flight_requests.deadlines.is_empty());
assert_matches!(in_flight_requests.remove_request(0), Some(_));
// Postcondition: No pending expirations
assert!(in_flight_requests.deadlines.is_empty());
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Ready(None)
);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Pending
);
assert_eq!(in_flight_requests.len(), 0);
}
}

View File

@@ -0,0 +1,92 @@
use super::{
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
Channel, Serve,
};
use futures::prelude::*;
use std::{fmt, hash::Hash};
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
pub trait Incoming<C>
where
Self: Sized + Stream<Item = C>,
C: Channel,
{
/// Enforces channel per-key limits.
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> MaxChannelsPerKey<Self, K, KF>
where
K: fmt::Display + Eq + Hash + Clone + Unpin,
KF: Fn(&C) -> K,
{
MaxChannelsPerKey::new(self, n, keymaker)
}
/// Caps the number of concurrent requests per channel.
fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel<Self> {
MaxRequestsPerChannel::new(self, n)
}
/// Returns a stream of channels in execution. Each channel in execution is a stream of
/// futures, where each future is an in-flight request being rsponded to.
fn execute<S>(
self,
serve: S,
) -> impl Stream<Item = impl Stream<Item = impl Future<Output = ()>>>
where
S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
{
self.map(move |channel| channel.execute(serve.clone()))
}
}
#[cfg(feature = "tokio1")]
/// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion.
/// Each channel is spawned, and each request from each channel is spawned.
/// Note that this function is generic over any stream-of-streams-of-futures, but it is intended
/// for spawning streams of channels.
///
/// # Example
/// ```rust
/// use tarpc::{
/// context,
/// client::{self, NewClient},
/// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve},
/// transport,
/// };
/// use futures::prelude::*;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = transport::channel::unbounded();
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
/// tokio::spawn(dispatch);
///
/// let incoming = stream::once(async move {
/// BaseChannel::new(server::Config::default(), rx)
/// }).execute(serve(|_, i| async move { Ok(i + 1) }));
/// tokio::spawn(spawn_incoming(incoming));
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
/// }
/// ```
pub async fn spawn_incoming(
incoming: impl Stream<
Item = impl Stream<Item = impl Future<Output = ()> + Send + 'static> + Send + 'static,
>,
) {
use futures::pin_mut;
pin_mut!(incoming);
while let Some(channel) = incoming.next().await {
tokio::spawn(async move {
pin_mut!(channel);
while let Some(request) = channel.next().await {
tokio::spawn(request);
}
});
}
}
impl<S, C> Incoming<C> for S
where
S: Sized + Stream<Item = C>,
C: Channel,
{
}

View File

@@ -0,0 +1,5 @@
/// Provides functionality to limit the number of active channels.
pub mod channels_per_key;
/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests.
pub mod requests_per_channel;

View File

@@ -0,0 +1,480 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
server::{self, Channel},
util::Compact,
};
use fnv::FnvHashMap;
use futures::{prelude::*, ready, stream::Fuse, task::*};
use pin_project::pin_project;
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
};
use tokio::sync::mpsc;
use tracing::{debug, info, trace};
/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
/// per-key limits.
///
/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
/// Channel is yielded, it will not be prematurely dropped.
#[pin_project]
#[derive(Debug)]
pub struct MaxChannelsPerKey<S, K, F>
where
K: Eq + Hash,
{
#[pin]
listener: Fuse<S>,
channels_per_key: u32,
dropped_keys: mpsc::UnboundedReceiver<K>,
dropped_keys_tx: mpsc::UnboundedSender<K>,
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
keymaker: F,
}
/// A channel that is tracked by [`MaxChannelsPerKey`].
#[pin_project]
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
#[pin]
inner: C,
tracker: Arc<Tracker<K>>,
}
#[derive(Debug)]
struct Tracker<K> {
key: Option<K>,
dropped_keys: mpsc::UnboundedSender<K>,
}
impl<K> Drop for Tracker<K> {
fn drop(&mut self) {
// Don't care if the listener is dropped.
let _ = self.dropped_keys.send(self.key.take().unwrap());
}
}
impl<C, K> Stream for TrackedChannel<C, K>
where
C: Stream,
{
type Item = <C as Stream>::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.inner_pin_mut().poll_next(cx)
}
}
impl<C, I, K> Sink<I> for TrackedChannel<C, K>
where
C: Sink<I>,
{
type Error = C::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner_pin_mut().poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
self.inner_pin_mut().start_send(item)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner_pin_mut().poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner_pin_mut().poll_close(cx)
}
}
impl<C, K> AsRef<C> for TrackedChannel<C, K> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C, K> Channel for TrackedChannel<C, K>
where
C: Channel,
{
type Req = C::Req;
type Resp = C::Resp;
type Transport = C::Transport;
fn config(&self) -> &server::Config {
self.inner.config()
}
fn in_flight_requests(&self) -> usize {
self.inner.in_flight_requests()
}
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}
}
impl<C, K> TrackedChannel<C, K> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
/// Returns the pinned inner channel.
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
self.as_mut().project().inner
}
}
impl<S, K, F> MaxChannelsPerKey<S, K, F>
where
K: Eq + Hash,
S: Stream,
F: Fn(&S::Item) -> K,
{
/// Sheds new channels to stay under configured limits.
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
MaxChannelsPerKey {
listener: listener.fuse(),
channels_per_key,
dropped_keys,
dropped_keys_tx,
key_counts: FnvHashMap::default(),
keymaker,
}
}
}
impl<S, K, F> MaxChannelsPerKey<S, K, F>
where
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&S::Item) -> K,
{
fn listener_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<S>> {
self.as_mut().project().listener
}
fn handle_new_channel(
mut self: Pin<&mut Self>,
stream: S::Item,
) -> Result<TrackedChannel<S::Item, K>, K> {
let key = (self.as_mut().keymaker)(&stream);
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
trace!(
channel_filter_key = %key,
open_channels = Arc::strong_count(&tracker),
max_open_channels = self.channels_per_key,
"Opening channel");
Ok(TrackedChannel {
tracker,
inner: stream,
})
}
fn increment_channels_for_key(self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
let self_ = self.project();
let dropped_keys = self_.dropped_keys_tx;
match self_.key_counts.entry(key.clone()) {
Entry::Vacant(vacant) => {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys: dropped_keys.clone(),
});
vacant.insert(Arc::downgrade(&tracker));
Ok(tracker)
}
Entry::Occupied(mut o) => {
let count = o.get().strong_count();
if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
info!(
channel_filter_key = %key,
open_channels = count,
max_open_channels = *self_.channels_per_key,
"At open channel limit");
Err(key)
} else {
Ok(o.get().upgrade().unwrap_or_else(|| {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys: dropped_keys.clone(),
});
*o.get_mut() = Arc::downgrade(&tracker);
tracker
}))
}
}
}
}
fn poll_listener(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
match ready!(self.listener_pin_mut().poll_next_unpin(cx)) {
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
None => Poll::Ready(None),
}
}
fn poll_closed_channels(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let self_ = self.project();
match ready!(self_.dropped_keys.poll_recv(cx)) {
Some(key) => {
debug!(
channel_filter_key = %key,
"All channels dropped");
self_.key_counts.remove(&key);
self_.key_counts.compact(0.1);
Poll::Ready(())
}
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
}
}
}
impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
where
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&S::Item) -> K,
{
type Item = TrackedChannel<S::Item, K>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<TrackedChannel<S::Item, K>>> {
loop {
match (
self.as_mut().poll_listener(cx),
self.as_mut().poll_closed_channels(cx),
) {
(Poll::Ready(Some(Ok(channel))), _) => {
return Poll::Ready(Some(channel));
}
(Poll::Ready(Some(Err(_))), _) => {
continue;
}
(_, Poll::Ready(())) => continue,
(Poll::Pending, Poll::Pending) => return Poll::Pending,
(Poll::Ready(None), Poll::Pending) => {
trace!("Shutting down listener.");
return Poll::Ready(None);
}
}
}
}
}
#[cfg(test)]
fn ctx() -> Context<'static> {
use futures::task::*;
Context::from_waker(noop_waker_ref())
}
#[test]
fn tracker_drop() {
use assert_matches::assert_matches;
let (tx, mut rx) = mpsc::unbounded_channel();
Tracker {
key: Some(1),
dropped_keys: tx,
};
assert_matches!(rx.poll_recv(&mut ctx()), Poll::Ready(Some(1)));
}
#[test]
fn tracked_channel_stream() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
let (chan_tx, chan) = futures::channel::mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded_channel();
let channel = TrackedChannel {
inner: chan,
tracker: Arc::new(Tracker {
key: Some(1),
dropped_keys,
}),
};
chan_tx.unbounded_send("test").unwrap();
pin_mut!(channel);
assert_matches!(channel.poll_next(&mut ctx()), Poll::Ready(Some("test")));
}
#[test]
fn tracked_channel_sink() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
let (chan, mut chan_rx) = futures::channel::mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded_channel();
let channel = TrackedChannel {
inner: chan,
tracker: Arc::new(Tracker {
key: Some(1),
dropped_keys,
}),
};
pin_mut!(channel);
assert_matches!(channel.as_mut().poll_ready(&mut ctx()), Poll::Ready(Ok(())));
assert_matches!(channel.as_mut().start_send("test"), Ok(()));
assert_matches!(channel.as_mut().poll_flush(&mut ctx()), Poll::Ready(Ok(())));
assert_matches!(chan_rx.try_next(), Ok(Some("test")));
}
#[test]
fn channel_filter_increment_channels_for_key() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
struct TestChannel {
key: &'static str,
}
let (_, listener) = futures::channel::mpsc::unbounded();
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 1);
let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 2);
assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
drop(tracker2);
assert_eq!(Arc::strong_count(&tracker1), 1);
}
#[test]
fn channel_filter_handle_new_channel() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (_, listener) = futures::channel::mpsc::unbounded();
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let channel1 = filter
.as_mut()
.handle_new_channel(TestChannel { key: "key" })
.unwrap();
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
let channel2 = filter
.as_mut()
.handle_new_channel(TestChannel { key: "key" })
.unwrap();
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
assert_matches!(
filter.handle_new_channel(TestChannel { key: "key" }),
Err("key")
);
drop(channel2);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
}
#[test]
fn channel_filter_poll_listener() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let channel1 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let _channel2 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let key =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
assert_eq!(key, "key");
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
}
#[test]
fn channel_filter_poll_closed_channels() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let channel =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(filter.key_counts.len(), 1);
drop(channel);
assert_matches!(
filter.as_mut().poll_closed_channels(&mut ctx()),
Poll::Ready(())
);
assert!(filter.key_counts.is_empty());
}
#[test]
fn channel_filter_stream() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let channel = assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Ready(Some(c)) => c);
assert_eq!(filter.key_counts.len(), 1);
drop(channel);
assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Pending);
assert!(filter.key_counts.is_empty());
}

View File

@@ -0,0 +1,349 @@
// Copyright 2020 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
server::{Channel, Config},
Response, ServerError,
};
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use std::{io, pin::Pin};
/// A [`Channel`] that limits the number of concurrent requests by throttling.
///
/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
/// for the resources available to the server. For production use cases, a more advanced throttler
/// is likely needed.
#[pin_project]
#[derive(Debug)]
pub struct MaxRequests<C> {
max_in_flight_requests: usize,
#[pin]
inner: C,
}
impl<C> MaxRequests<C> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
}
impl<C> MaxRequests<C>
where
C: Channel,
{
/// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
/// `max_in_flight_requests`.
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
MaxRequests {
max_in_flight_requests,
inner,
}
}
}
impl<C> Stream for MaxRequests<C>
where
C: Channel,
{
type Item = <C as Stream>::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
{
ready!(self.as_mut().project().inner.poll_ready(cx)?);
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
Some(r) => {
let _entered = r.span.enter();
tracing::info!(
in_flight_requests = self.as_mut().in_flight_requests(),
"ThrottleRequest",
);
self.as_mut().start_send(Response {
request_id: r.request.id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: "server throttled the request.".into(),
}),
})?;
}
None => return Poll::Ready(None),
}
}
self.project().inner.poll_next(cx)
}
}
impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
where
C: Channel,
{
type Error = C::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx)
}
fn start_send(
self: Pin<&mut Self>,
item: Response<<C as Channel>::Resp>,
) -> Result<(), Self::Error> {
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx)
}
}
impl<C> AsRef<C> for MaxRequests<C> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C> Channel for MaxRequests<C>
where
C: Channel,
{
type Req = <C as Channel>::Req;
type Resp = <C as Channel>::Resp;
type Transport = <C as Channel>::Transport;
fn in_flight_requests(&self) -> usize {
self.inner.in_flight_requests()
}
fn config(&self) -> &Config {
self.inner.config()
}
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}
}
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
/// the number of in-flight requests.
#[pin_project]
#[derive(Debug)]
pub struct MaxRequestsPerChannel<S> {
#[pin]
inner: S,
max_in_flight_requests: usize,
}
impl<S> MaxRequestsPerChannel<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
Self {
inner,
max_in_flight_requests,
}
}
}
impl<S> Stream for MaxRequestsPerChannel<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
type Item = MaxRequests<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().project().inner.poll_next(cx)) {
Some(channel) => Poll::Ready(Some(MaxRequests::new(
channel,
*self.project().max_in_flight_requests,
))),
None => Poll::Ready(None),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::server::{
testing::{self, FakeChannel, PollExt},
TrackedRequest,
};
use pin_utils::pin_mut;
use std::{
marker::PhantomData,
time::{Duration, SystemTime},
};
use tracing::Span;
#[tokio::test]
async fn throttler_in_flight_requests() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
for i in 0..5 {
throttler
.inner
.in_flight_requests
.start_request(
i,
SystemTime::now() + Duration::from_secs(1),
Span::current(),
)
.unwrap();
}
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
#[test]
fn throttler_poll_next_done() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
}
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
let throttler = MaxRequests {
max_in_flight_requests: 1,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(0, 1);
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
assert_eq!(
throttler
.as_mut()
.poll_next(&mut testing::cx())?
.map(|r| r.map(|r| (r.request.id, r.request.message))),
Poll::Ready(Some((0, 1)))
);
Ok(())
}
#[test]
fn throttler_poll_next_throttled() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(1, 1);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
assert_eq!(throttler.inner.sink.len(), 1);
let resp = throttler.inner.sink.get(0).unwrap();
assert_eq!(resp.request_id, 1);
assert!(resp.message.is_err());
}
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: PendingSink::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
struct PendingSink<In, Out> {
ghost: PhantomData<fn(Out) -> In>,
}
impl PendingSink<(), ()> {
pub fn default<Req, Resp>(
) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
PendingSink { ghost: PhantomData }
}
}
impl<In, Out> Stream for PendingSink<In, Out> {
type Item = In;
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
unimplemented!()
}
}
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
}
impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
type Req = Req;
type Resp = Resp;
type Transport = ();
fn config(&self) -> &Config {
unimplemented!()
}
fn in_flight_requests(&self) -> usize {
0
}
fn transport(&self) -> &() {
&()
}
}
}
#[tokio::test]
async fn throttler_start_send() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler
.inner
.in_flight_requests
.start_request(
0,
SystemTime::now() + Duration::from_secs(1),
Span::current(),
)
.unwrap();
throttler
.as_mut()
.start_send(Response {
request_id: 0,
message: Ok(1),
})
.unwrap();
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
assert_eq!(
throttler.inner.sink.get(0),
Some(&Response {
request_id: 0,
message: Ok(1),
})
);
}
}

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Hooks for horizontal functionality that can run either before or after a request is executed.
/// A request hook that runs before a request is executed.
mod before;
/// A request hook that runs after a request is completed.
mod after;
/// A request hook that runs both before a request is executed and after it is completed.
mod before_and_after;
pub use {
after::{AfterRequest, ServeThenHook},
before::{
before, BeforeRequest, BeforeRequestCons, BeforeRequestList, BeforeRequestNil,
HookThenServe,
},
before_and_after::HookThenServeThenHook,
};

View File

@@ -0,0 +1,72 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a hook that runs after request execution.
use crate::{context, server::Serve, ServerError};
use futures::prelude::*;
/// A hook that runs after request execution.
#[allow(async_fn_in_trait)]
pub trait AfterRequest<Resp> {
/// The function that is called after request execution.
///
/// The hook can modify the request context and the response.
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>);
}
impl<F, Fut, Resp> AfterRequest<Resp> for F
where
F: FnMut(&mut context::Context, &mut Result<Resp, ServerError>) -> Fut,
Fut: Future<Output = ()>,
{
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>) {
self(ctx, resp).await
}
}
/// A Service function that runs a hook after request execution.
pub struct ServeThenHook<Serv, Hook> {
serve: Serv,
hook: Hook,
}
impl<Serv, Hook> ServeThenHook<Serv, Hook> {
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
Self { serve, hook }
}
}
impl<Serv: Clone, Hook: Clone> Clone for ServeThenHook<Serv, Hook> {
fn clone(&self) -> Self {
Self {
serve: self.serve.clone(),
hook: self.hook.clone(),
}
}
}
impl<Serv, Hook> Serve for ServeThenHook<Serv, Hook>
where
Serv: Serve,
Hook: AfterRequest<Serv::Resp>,
{
type Req = Serv::Req;
type Resp = Serv::Resp;
async fn serve(
self,
mut ctx: context::Context,
req: Serv::Req,
) -> Result<Serv::Resp, ServerError> {
let ServeThenHook {
serve, mut hook, ..
} = self;
let mut resp = serve.serve(ctx, req).await;
hook.after(&mut ctx, &mut resp).await;
resp
}
}

View File

@@ -0,0 +1,210 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a hook that runs before request execution.
use crate::{context, server::Serve, ServerError};
use futures::prelude::*;
/// A hook that runs before request execution.
#[allow(async_fn_in_trait)]
pub trait BeforeRequest<Req> {
/// The function that is called before request execution.
///
/// If this function returns an error, the request will not be executed and the error will be
/// returned instead.
///
/// This function can also modify the request context. This could be used, for example, to
/// enforce a maximum deadline on all requests.
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>;
}
/// A list of hooks that run in order before request execution.
pub trait BeforeRequestList<Req>: BeforeRequest<Req> {
/// The hook returned by `BeforeRequestList::then`.
type Then<Next>: BeforeRequest<Req>
where
Next: BeforeRequest<Req>;
/// Returns a hook that, when run, runs two hooks, first `self` and then `next`.
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next>;
/// Same as `then`, but helps the compiler with type inference when Next is a closure.
fn then_fn<
Next: FnMut(&mut context::Context, &Req) -> Fut,
Fut: Future<Output = Result<(), ServerError>>,
>(
self,
next: Next,
) -> Self::Then<Next>
where
Self: Sized,
{
self.then(next)
}
/// The service fn returned by `BeforeRequestList::serving`.
type Serve<S: Serve<Req = Req>>: Serve<Req = Req>;
/// Runs the list of request hooks before execution of the given serve fn.
/// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer.
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S>;
}
impl<F, Fut, Req> BeforeRequest<Req> for F
where
F: FnMut(&mut context::Context, &Req) -> Fut,
Fut: Future<Output = Result<(), ServerError>>,
{
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
self(ctx, req).await
}
}
/// A Service function that runs a hook before request execution.
#[derive(Clone)]
pub struct HookThenServe<Serv, Hook> {
serve: Serv,
hook: Hook,
}
impl<Serv, Hook> HookThenServe<Serv, Hook> {
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
Self { serve, hook }
}
}
impl<Serv, Hook> Serve for HookThenServe<Serv, Hook>
where
Serv: Serve,
Hook: BeforeRequest<Serv::Req>,
{
type Req = Serv::Req;
type Resp = Serv::Resp;
async fn serve(
self,
mut ctx: context::Context,
req: Self::Req,
) -> Result<Serv::Resp, ServerError> {
let HookThenServe {
serve, mut hook, ..
} = self;
hook.before(&mut ctx, &req).await?;
serve.serve(ctx, req).await
}
}
/// Returns a request hook builder that runs a series of hooks before request execution.
///
/// Example
///
/// ```rust
/// use futures::{executor::block_on, future};
/// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self,
/// BeforeRequest, BeforeRequestList}}};
/// use std::{cell::Cell, io};
///
/// let i = Cell::new(0);
/// let serve = request_hook::before()
/// .then_fn(|_, _| async {
/// assert!(i.get() == 0);
/// i.set(1);
/// Ok(())
/// })
/// .then_fn(|_, _| async {
/// assert!(i.get() == 1);
/// i.set(2);
/// Ok(())
/// })
/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }));
/// let response = serve.clone().serve(context::current(), 1);
/// assert!(block_on(response).is_ok());
/// assert!(i.get() == 2);
/// ```
pub fn before() -> BeforeRequestNil {
BeforeRequestNil
}
/// A list of hooks that run in order before a request is executed.
#[derive(Clone, Copy)]
pub struct BeforeRequestCons<First, Rest>(First, Rest);
/// A noop hook that runs before a request is executed.
#[derive(Clone, Copy)]
pub struct BeforeRequestNil;
impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequest<Req>> BeforeRequest<Req>
for BeforeRequestCons<First, Rest>
{
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
let BeforeRequestCons(first, rest) = self;
first.before(ctx, req).await?;
rest.before(ctx, req).await?;
Ok(())
}
}
impl<Req> BeforeRequest<Req> for BeforeRequestNil {
async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
Ok(())
}
}
impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequestList<Req>> BeforeRequestList<Req>
for BeforeRequestCons<First, Rest>
{
type Then<Next> = BeforeRequestCons<First, Rest::Then<Next>> where Next: BeforeRequest<Req>;
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
let BeforeRequestCons(first, rest) = self;
BeforeRequestCons(first, rest.then(next))
}
type Serve<S: Serve<Req = Req>> = HookThenServe<S, Self>;
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S> {
HookThenServe::new(serve, self)
}
}
impl<Req> BeforeRequestList<Req> for BeforeRequestNil {
type Then<Next> = BeforeRequestCons<Next, BeforeRequestNil> where Next: BeforeRequest<Req>;
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
BeforeRequestCons(next, BeforeRequestNil)
}
type Serve<S: Serve<Req = Req>> = S;
fn serving<S: Serve<Req = Req>>(self, serve: S) -> S {
serve
}
}
#[test]
fn before_request_list() {
use crate::server::serve;
use futures::executor::block_on;
use std::cell::Cell;
let i = Cell::new(0);
let serve = before()
.then_fn(|_, _| async {
assert!(i.get() == 0);
i.set(1);
Ok(())
})
.then_fn(|_, _| async {
assert!(i.get() == 1);
i.set(2);
Ok(())
})
.serving(serve(|_ctx, i| async move { Ok(i + 1) }));
let response = serve.clone().serve(context::current(), 1);
assert!(block_on(response).is_ok());
assert!(i.get() == 2);
}

View File

@@ -0,0 +1,57 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a hook that runs both before and after request execution.
use super::{after::AfterRequest, before::BeforeRequest};
use crate::{context, server::Serve, ServerError};
use std::marker::PhantomData;
/// A Service function that runs a hook both before and after request execution.
pub struct HookThenServeThenHook<Req, Resp, Serv, Hook> {
serve: Serv,
hook: Hook,
fns: PhantomData<(fn(Req), fn(Resp))>,
}
impl<Req, Resp, Serv, Hook> HookThenServeThenHook<Req, Resp, Serv, Hook> {
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
Self {
serve,
hook,
fns: PhantomData,
}
}
}
impl<Req, Resp, Serv: Clone, Hook: Clone> Clone for HookThenServeThenHook<Req, Resp, Serv, Hook> {
fn clone(&self) -> Self {
Self {
serve: self.serve.clone(),
hook: self.hook.clone(),
fns: PhantomData,
}
}
}
impl<Req, Resp, Serv, Hook> Serve for HookThenServeThenHook<Req, Resp, Serv, Hook>
where
Serv: Serve<Req = Req, Resp = Resp>,
Hook: BeforeRequest<Req> + AfterRequest<Resp>,
{
type Req = Req;
type Resp = Resp;
async fn serve(self, mut ctx: context::Context, req: Req) -> Result<Serv::Resp, ServerError> {
let HookThenServeThenHook {
serve, mut hook, ..
} = self;
hook.before(&mut ctx, &req).await?;
let mut resp = serve.serve(ctx, req).await;
hook.after(&mut ctx, &mut resp).await;
resp
}
}

139
tarpc/src/server/testing.rs Normal file
View File

@@ -0,0 +1,139 @@
// Copyright 2020 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context,
server::{Channel, Config, ResponseGuard, TrackedRequest},
Request, Response,
};
use futures::{task::*, Sink, Stream};
use pin_project::pin_project;
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
use tracing::Span;
#[pin_project]
pub(crate) struct FakeChannel<In, Out> {
#[pin]
pub stream: VecDeque<In>,
#[pin]
pub sink: VecDeque<Out>,
pub config: Config,
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
pub request_cancellation: RequestCancellation,
pub canceled_requests: CanceledRequests,
}
impl<In, Out> Stream for FakeChannel<In, Out>
where
In: Unpin,
{
type Item = In;
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Option<Self::Item>> {
Poll::Ready(self.project().stream.pop_front())
}
}
impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().sink.poll_ready(cx).map_err(|e| match e {})
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
self.as_mut()
.project()
.in_flight_requests
.remove_request(response.request_id);
self.project()
.sink
.start_send(response)
.map_err(|e| match e {})
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().sink.poll_flush(cx).map_err(|e| match e {})
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().sink.poll_close(cx).map_err(|e| match e {})
}
}
impl<Req, Resp> Channel for FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>>
where
Req: Unpin,
{
type Req = Req;
type Resp = Resp;
type Transport = ();
fn config(&self) -> &Config {
&self.config
}
fn in_flight_requests(&self) -> usize {
self.in_flight_requests.len()
}
fn transport(&self) -> &() {
&()
}
}
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
pub fn push_req(&mut self, id: u64, message: Req) {
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
let (request_cancellation, _) = cancellations();
self.stream.push_back(Ok(TrackedRequest {
request: Request {
context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
},
id,
message,
},
abort_registration,
span: Span::none(),
response_guard: ResponseGuard {
request_cancellation,
request_id: id,
cancel: false,
},
}));
}
}
impl FakeChannel<(), ()> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
let (request_cancellation, canceled_requests) = cancellations();
FakeChannel {
stream: Default::default(),
sink: Default::default(),
config: Default::default(),
in_flight_requests: Default::default(),
request_cancellation,
canceled_requests,
}
}
}
pub trait PollExt {
fn is_done(&self) -> bool;
}
impl<T> PollExt for Poll<Option<T>> {
fn is_done(&self) -> bool {
matches!(self, Poll::Ready(None))
}
}
pub fn cx() -> Context<'static> {
Context::from_waker(noop_waker_ref())
}

261
tarpc/src/trace.rs Normal file
View File

@@ -0,0 +1,261 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![deny(missing_docs, missing_debug_implementations)]
//! Provides building blocks for tracing distributed programs.
//!
//! A trace is logically a tree of causally-related events called spans. Traces are tracked via a
//! [context](Context) that identifies the current trace, span, and parent of the current span. In
//! distributed systems, a context can be sent from client to server to connect events occurring on
//! either side.
//!
//! This crate's design is based on [opencensus
//! tracing](https://opencensus.io/core-concepts/tracing/).
use opentelemetry::trace::TraceContextExt;
use rand::Rng;
use std::{
convert::TryFrom,
fmt::{self, Formatter},
num::{NonZeroU128, NonZeroU64},
};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// A context for tracing the execution of processes, distributed or otherwise.
///
/// Consists of a span identifying an event, an optional parent span identifying a causal event
/// that triggered the current span, and a trace with which all related spans are associated.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Context {
/// An identifier of the trace associated with the current context. A trace ID is typically
/// created at a root span and passed along through all causal events.
pub trace_id: TraceId,
/// An identifier of the current span. In typical RPC usage, a span is created by a client
/// before making an RPC, and the span ID is sent to the server. The server is free to create
/// its own spans, for which it sets the client's span as the parent span.
pub span_id: SpanId,
/// Indicates whether a sampler has already decided whether or not to sample the trace
/// associated with the Context. If `sampling_decision` is None, then a decision has not yet
/// been made. Downstream samplers do not need to abide by "no sample" decisions--for example,
/// an upstream client may choose to never sample, which may not make sense for the client's
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace,
/// then the downstream samplers are expected to respect that decision and also sample the
/// trace. Otherwise, the full trace would not be able to be reconstructed.
pub sampling_decision: SamplingDecision,
}
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
/// same trace ID.
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct TraceId(#[cfg_attr(feature = "serde1", serde(with = "u128_serde"))] u128);
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct SpanId(u64);
/// Indicates whether a sampler has decided whether or not to sample the trace associated with the
/// Context. Downstream samplers do not need to abide by "no sample" decisions--for example, an
/// upstream client may choose to never sample, which may not make sense for the client's
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then
/// the downstream samplers are expected to respect that decision and also sample the trace.
/// Otherwise, the full trace would not be able to be reconstructed reliably.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum SamplingDecision {
/// The associated span was sampled by its creating process. Child spans must also be sampled.
Sampled,
/// The associated span was not sampled by its creating process.
Unsampled,
}
impl Context {
/// Constructs a new context with the trace ID and sampling decision inherited from the parent.
pub(crate) fn new_child(&self) -> Self {
Self {
trace_id: self.trace_id,
span_id: SpanId::random(&mut rand::thread_rng()),
sampling_decision: self.sampling_decision,
}
}
}
impl TraceId {
/// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
/// actually-random numbers.
pub fn random<R: Rng>(rng: &mut R) -> Self {
TraceId(rng.gen::<NonZeroU128>().get())
}
/// Returns true iff the trace ID is 0.
pub fn is_none(&self) -> bool {
self.0 == 0
}
}
impl SpanId {
/// Returns a random span ID that can be assumed to be unique within a single trace.
pub fn random<R: Rng>(rng: &mut R) -> Self {
SpanId(rng.gen::<NonZeroU64>().get())
}
/// Returns true iff the span ID is 0.
pub fn is_none(&self) -> bool {
self.0 == 0
}
}
impl From<TraceId> for u128 {
fn from(trace_id: TraceId) -> Self {
trace_id.0
}
}
impl From<u128> for TraceId {
fn from(trace_id: u128) -> Self {
Self(trace_id)
}
}
impl From<SpanId> for u64 {
fn from(span_id: SpanId) -> Self {
span_id.0
}
}
impl From<u64> for SpanId {
fn from(span_id: u64) -> Self {
Self(span_id)
}
}
impl From<opentelemetry::trace::TraceId> for TraceId {
fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
Self::from(u128::from_be_bytes(trace_id.to_bytes()))
}
}
impl From<TraceId> for opentelemetry::trace::TraceId {
fn from(trace_id: TraceId) -> Self {
Self::from_bytes(u128::from(trace_id).to_be_bytes())
}
}
impl From<opentelemetry::trace::SpanId> for SpanId {
fn from(span_id: opentelemetry::trace::SpanId) -> Self {
Self::from(u64::from_be_bytes(span_id.to_bytes()))
}
}
impl From<SpanId> for opentelemetry::trace::SpanId {
fn from(span_id: SpanId) -> Self {
Self::from_bytes(u64::from(span_id).to_be_bytes())
}
}
impl TryFrom<&tracing::Span> for Context {
type Error = NoActiveSpan;
fn try_from(span: &tracing::Span) -> Result<Self, NoActiveSpan> {
let context = span.context();
if context.has_active_span() {
Ok(Self::from(context.span()))
} else {
Err(NoActiveSpan)
}
}
}
impl From<opentelemetry::trace::SpanRef<'_>> for Context {
fn from(span: opentelemetry::trace::SpanRef<'_>) -> Self {
let otel_ctx = span.span_context();
Self {
trace_id: TraceId::from(otel_ctx.trace_id()),
span_id: SpanId::from(otel_ctx.span_id()),
sampling_decision: SamplingDecision::from(otel_ctx),
}
}
}
impl From<SamplingDecision> for opentelemetry::trace::TraceFlags {
fn from(decision: SamplingDecision) -> Self {
match decision {
SamplingDecision::Sampled => opentelemetry::trace::TraceFlags::SAMPLED,
SamplingDecision::Unsampled => opentelemetry::trace::TraceFlags::default(),
}
}
}
impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
fn from(context: &opentelemetry::trace::SpanContext) -> Self {
if context.is_sampled() {
SamplingDecision::Sampled
} else {
SamplingDecision::Unsampled
}
}
}
impl Default for SamplingDecision {
fn default() -> Self {
Self::Unsampled
}
}
/// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span).
#[derive(Debug)]
pub struct NoActiveSpan;
impl fmt::Display for TraceId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}
impl fmt::Debug for TraceId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}
impl fmt::Display for SpanId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}
impl fmt::Debug for SpanId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}
#[cfg(feature = "serde1")]
mod u128_serde {
pub fn serialize<S>(u: &u128, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serde::Serialize::serialize(&u.to_le_bytes(), serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<u128, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(u128::from_le_bytes(serde::Deserialize::deserialize(
deserializer,
)?))
}
}

40
tarpc/src/transport.rs Normal file
View File

@@ -0,0 +1,40 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a [`Transport`](sealed::Transport) trait as well as implementations.
//!
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`](sealed::Transport)
//! can be plugged in, using whatever protocol it wants.
pub mod channel;
pub(crate) mod sealed {
use futures::prelude::*;
use std::error::Error;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport<SinkItem, Item>
where
Self: Stream<Item = Result<Item, <Self as Sink<SinkItem>>::Error>>,
Self: Sink<SinkItem, Error = <Self as Transport<SinkItem, Item>>::TransportError>,
<Self as Sink<SinkItem>>::Error: Error,
{
/// Associated type where clauses are not elaborated; this associated type allows users
/// bounding types by Transport to avoid having to explicitly add `T::Error: Error` to their
/// bounds.
type TransportError: Error + Send + Sync + 'static;
}
impl<T, SinkItem, Item, E> Transport<SinkItem, Item> for T
where
T: ?Sized,
T: Stream<Item = Result<Item, E>>,
T: Sink<SinkItem, Error = E>,
T::Error: Error + Send + Sync + 'static,
{
type TransportError = E;
}
}

View File

@@ -0,0 +1,219 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Transports backed by in-memory channels.
use futures::{task::*, Sink, Stream};
use pin_project::pin_project;
use std::{error::Error, pin::Pin};
use tokio::sync::mpsc;
/// Errors that occur in the sending or receiving of messages over a channel.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
/// An error occurred readying to send into the channel.
#[error("an error occurred readying to send into the channel")]
Ready(#[source] Box<dyn Error + Send + Sync + 'static>),
/// An error occurred sending into the channel.
#[error("an error occurred sending into the channel")]
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
/// An error occurred receiving from the channel.
#[error("an error occurred receiving from the channel")]
Receive(#[source] Box<dyn Error + Send + Sync + 'static>),
}
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
/// [`Sink`].
pub fn unbounded<SinkItem, Item>() -> (
UnboundedChannel<SinkItem, Item>,
UnboundedChannel<Item, SinkItem>,
) {
let (tx1, rx2) = mpsc::unbounded_channel();
let (tx2, rx1) = mpsc::unbounded_channel();
(
UnboundedChannel { tx: tx1, rx: rx1 },
UnboundedChannel { tx: tx2, rx: rx2 },
)
}
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
#[derive(Debug)]
pub struct UnboundedChannel<Item, SinkItem> {
rx: mpsc::UnboundedReceiver<Item>,
tx: mpsc::UnboundedSender<SinkItem>,
}
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
type Item = Result<Item, ChannelError>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, ChannelError>>> {
self.rx
.poll_recv(cx)
.map(|option| option.map(Ok))
.map_err(ChannelError::Receive)
}
}
const CLOSED_MESSAGE: &str = "the channel is closed and cannot accept new items for sending";
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
type Error = ChannelError;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(if self.tx.is_closed() {
Err(ChannelError::Ready(CLOSED_MESSAGE.into()))
} else {
Ok(())
})
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
self.tx
.send(item)
.map_err(|_| ChannelError::Send(CLOSED_MESSAGE.into()))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// UnboundedSender requires no flushing.
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// UnboundedSender can't initiate closure.
Poll::Ready(Ok(()))
}
}
/// Returns two channel peers with buffer equal to `capacity`. Each [`Stream`] yields items sent
/// through the other's [`Sink`].
pub fn bounded<SinkItem, Item>(
capacity: usize,
) -> (Channel<SinkItem, Item>, Channel<Item, SinkItem>) {
let (tx1, rx2) = futures::channel::mpsc::channel(capacity);
let (tx2, rx1) = futures::channel::mpsc::channel(capacity);
(Channel { tx: tx1, rx: rx1 }, Channel { tx: tx2, rx: rx2 })
}
/// A bi-directional channel backed by a [`Sender`](futures::channel::mpsc::Sender)
/// and [`Receiver`](futures::channel::mpsc::Receiver).
#[pin_project]
#[derive(Debug)]
pub struct Channel<Item, SinkItem> {
#[pin]
rx: futures::channel::mpsc::Receiver<Item>,
#[pin]
tx: futures::channel::mpsc::Sender<SinkItem>,
}
impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
type Item = Result<Item, ChannelError>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, ChannelError>>> {
self.project()
.rx
.poll_next(cx)
.map(|option| option.map(Ok))
.map_err(ChannelError::Receive)
}
}
impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
type Error = ChannelError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_ready(cx)
.map_err(|e| ChannelError::Ready(Box::new(e)))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
self.project()
.tx
.start_send(item)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_flush(cx)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_close(cx)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
}
#[cfg(all(test, feature = "tokio1"))]
mod tests {
use crate::{
client::{self, RpcError},
context,
server::{incoming::Incoming, serve, BaseChannel},
transport::{
self,
channel::{Channel, UnboundedChannel},
},
ServerError,
};
use assert_matches::assert_matches;
use futures::{prelude::*, stream};
use std::io;
use tracing::trace;
#[test]
fn ensure_is_transport() {
fn is_transport<SinkItem, Item, T: crate::Transport<SinkItem, Item>>() {}
is_transport::<(), (), UnboundedChannel<(), ()>>();
is_transport::<(), (), Channel<(), ()>>();
}
#[tokio::test]
async fn integration() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (client_channel, server_channel) = transport::channel::unbounded();
tokio::spawn(
stream::once(future::ready(server_channel))
.map(BaseChannel::with_defaults)
.execute(serve(|_ctx, request: String| async move {
request.parse::<u64>().map_err(|_| {
ServerError::new(
io::ErrorKind::InvalidInput,
format!("{request:?} is not an int"),
)
})
}))
.for_each(|channel| async move {
tokio::spawn(channel.for_each(|response| response));
}),
);
let client = client::new(client::Config::default(), client_channel).spawn();
let response1 = client.call(context::current(), "", "123".into()).await;
let response2 = client.call(context::current(), "", "abc".into()).await;
trace!("response1: {:?}, response2: {:?}", response1, response2);
assert_matches!(response1, Ok(123));
assert_matches!(response2, Err(RpcError::Server(e)) if e.kind == io::ErrorKind::InvalidInput);
Ok(())
}
}

71
tarpc/src/util.rs Normal file
View File

@@ -0,0 +1,71 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use std::{
collections::HashMap,
hash::{BuildHasher, Hash},
time::{Duration, SystemTime},
};
#[cfg(feature = "serde1")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde1")))]
pub mod serde;
/// Extension trait for [SystemTimes](SystemTime) in the future, i.e. deadlines.
pub trait TimeUntil {
/// How much time from now until this time is reached.
fn time_until(&self) -> Duration;
}
impl TimeUntil for SystemTime {
fn time_until(&self) -> Duration {
self.duration_since(SystemTime::now()).unwrap_or_default()
}
}
/// Collection compaction; configurable `shrink_to_fit`.
pub trait Compact {
/// Compacts space if the ratio of length : capacity is less than `usage_ratio_threshold`.
fn compact(&mut self, usage_ratio_threshold: f64);
}
impl<K, V, H> Compact for HashMap<K, V, H>
where
K: Eq + Hash,
H: BuildHasher,
{
fn compact(&mut self, usage_ratio_threshold: f64) {
let usage_ratio_threshold = usage_ratio_threshold.clamp(f64::MIN_POSITIVE, 1.);
let cap = f64::max(1000., self.len() as f64 / usage_ratio_threshold);
self.shrink_to(cap as usize);
}
}
#[test]
fn test_compact() {
let mut map = HashMap::with_capacity(2048);
assert_eq!(map.capacity(), 3584);
// Make usage ratio 25%
for i in 0..896 {
map.insert(format!("k{i}"), "v");
}
map.compact(-1.0);
assert_eq!(map.capacity(), 3584);
map.compact(0.25);
assert_eq!(map.capacity(), 3584);
map.compact(0.50);
assert_eq!(map.capacity(), 1792);
map.compact(1.0);
assert_eq!(map.capacity(), 1792);
map.compact(2.0);
assert_eq!(map.capacity(), 1792);
}

View File

@@ -5,32 +5,10 @@
// https://opensource.org/licenses/MIT.
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::{
io,
time::{Duration, SystemTime},
};
/// Serializes `system_time` as a `u64` equal to the number of seconds since the epoch.
pub fn serialize_epoch_secs<S>(system_time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
system_time
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_secs() // Only care about second precision
.serialize(serializer)
}
/// Deserializes [`SystemTime`] from a `u64` equal to the number of seconds since the epoch.
pub fn deserialize_epoch_secs<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
where
D: Deserializer<'de>,
{
Ok(SystemTime::UNIX_EPOCH + Duration::from_secs(u64::deserialize(deserializer)?))
}
use std::io;
/// Serializes [`io::ErrorKind`] as a `u32`.
#[allow(clippy::trivially_copy_pass_by_ref)] // Exact fn signature required by serde derive
pub fn serialize_io_error_kind_as_u32<S>(
kind: &io::ErrorKind,
serializer: S,
@@ -59,7 +37,8 @@ where
Other => 16,
UnexpectedEof => 17,
_ => 16,
}.serialize(serializer)
}
.serialize(serializer)
}
/// Deserializes [`io::ErrorKind`] from a `u32`.

View File

@@ -0,0 +1,7 @@
#[test]
fn ui() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/compile_fail/*.rs");
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
t.compile_fail("tests/compile_fail/serde_transport/*.rs");
}

View File

@@ -0,0 +1,15 @@
use tarpc::client;
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
fn main() {
let (client_transport, _) = tarpc::transport::channel::unbounded();
#[deny(unused_must_use)]
{
WorldClient::new(client::Config::default(), client_transport).dispatch;
}
}

View File

@@ -0,0 +1,15 @@
error: unused `RequestDispatch` that must be used
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
|
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/must_use_request_dispatch.rs:11:12
|
11 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^
help: use `let _ = ...` to ignore the resulting value
|
13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch;
| +++++++

View File

@@ -0,0 +1,9 @@
use tarpc::serde_transport;
use tokio_serde::formats::Json;
fn main() {
#[deny(unused_must_use)]
{
serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
}
}

View File

@@ -0,0 +1,15 @@
error: unused `tarpc::serde_transport::tcp::Connect` that must be used
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12
|
5 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^
help: use `let _ = ...` to ignore the resulting value
|
7 | let _ = serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
| +++++++

View File

@@ -0,0 +1,6 @@
#[tarpc::service]
trait World {
async fn pat((a, b): (u8, u32));
}
fn main() {}

View File

@@ -0,0 +1,5 @@
error: patterns aren't allowed in RPC args
--> $DIR/tarpc_service_arg_pat.rs:3:18
|
3 | async fn pat((a, b): (u8, u32));
| ^^^^^^

View File

@@ -0,0 +1,6 @@
#[tarpc::service]
trait World {
async fn new();
}
fn main() {}

View File

@@ -0,0 +1,5 @@
error: method name conflicts with generated fn `WorldClient::new`
--> $DIR/tarpc_service_fn_new.rs:3:14
|
3 | async fn new();
| ^^^

View File

@@ -0,0 +1,6 @@
#[tarpc::service]
trait World {
async fn serve();
}
fn main() {}

View File

@@ -0,0 +1,5 @@
error: method name conflicts with generated fn `World::serve`
--> $DIR/tarpc_service_fn_serve.rs:3:14
|
3 | async fn serve();
| ^^^^^

View File

@@ -0,0 +1,61 @@
use futures::prelude::*;
use tarpc::serde_transport;
use tarpc::{
client, context,
server::{incoming::Incoming, BaseChannel},
};
use tokio_serde::formats::Json;
#[tarpc::derive_serde]
#[derive(Debug, PartialEq, Eq)]
pub enum TestData {
Black,
White,
}
#[tarpc::service]
pub trait ColorProtocol {
async fn get_opposite_color(color: TestData) -> TestData;
}
#[derive(Clone)]
struct ColorServer;
impl ColorProtocol for ColorServer {
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
match color {
TestData::White => TestData::Black,
TestData::Black => TestData::White,
}
}
}
#[cfg(test)]
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::test]
async fn test_call() -> anyhow::Result<()> {
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
let addr = transport.local_addr();
tokio::spawn(
transport
.take(1)
.filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults)
.execute(ColorServer.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
);
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
let client = ColorProtocolClient::new(client::Config::default(), transport).spawn();
let color = client
.get_opposite_color(context::current(), TestData::White)
.await?;
assert_eq!(color, TestData::Black);
Ok(())
}

View File

@@ -1,131 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(
test,
arbitrary_self_types,
pin,
integer_atomics,
futures_api,
generators,
await_macro,
async_await,
proc_macro_hygiene,
)]
extern crate test;
use self::test::stats::Stats;
use futures::{compat::TokioDefaultSpawner, future, prelude::*};
use rpc::{
client, context,
server::{self, Handler, Server},
};
use std::{
io,
time::{Duration, Instant},
};
mod ack {
tarpc::service! {
rpc ack();
}
}
#[derive(Clone)]
struct Serve;
impl ack::Service for Serve {
type AckFut = future::Ready<()>;
fn ack(&self, _: context::Context) -> Self::AckFut {
future::ready(())
}
}
async fn bench() -> io::Result<()> {
let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
tokio_executor::spawn(
Server::new(server::Config::default())
.incoming(listener)
.take(1)
.respond_with(ack::serve(Serve))
.unit_error()
.boxed()
.compat()
);
let conn = await!(bincode_transport::connect(&addr))?;
let mut client = await!(ack::new_stub(client::Config::default(), conn))?;
let total = 10_000usize;
let mut successful = 0u32;
let mut unsuccessful = 0u32;
let mut durations = vec![];
for _ in 1..=total {
let now = Instant::now();
let response = await!(client.ack(context::current()));
let elapsed = now.elapsed();
match response {
Ok(_) => successful += 1,
Err(_) => unsuccessful += 1,
};
durations.push(elapsed);
}
let durations_nanos = durations
.iter()
.map(|duration| duration.as_secs() as f64 * 1E9 + duration.subsec_nanos() as f64)
.collect::<Vec<_>>();
let (lower, median, upper) = durations_nanos.quartiles();
println!("Of {:?} runs:", durations_nanos.len());
println!("\tSuccessful: {:?}", successful);
println!("\tUnsuccessful: {:?}", unsuccessful);
println!(
"\tMean: {:?}",
Duration::from_nanos(durations_nanos.mean() as u64)
);
println!("\tMedian: {:?}", Duration::from_nanos(median as u64));
println!(
"\tStd Dev: {:?}",
Duration::from_nanos(durations_nanos.std_dev() as u64)
);
println!(
"\tMin: {:?}",
Duration::from_nanos(durations_nanos.min() as u64)
);
println!(
"\tMax: {:?}",
Duration::from_nanos(durations_nanos.max() as u64)
);
println!(
"\tQuartiles: ({:?}, {:?}, {:?})",
Duration::from_nanos(lower as u64),
Duration::from_nanos(median as u64),
Duration::from_nanos(upper as u64)
);
println!("done");
Ok(())
}
#[test]
fn bench_small_packet() {
env_logger::init();
tarpc::init(TokioDefaultSpawner);
tokio::run(
bench()
.map_err(|e| panic!(e.to_string()))
.boxed()
.compat(),
)
}

View File

@@ -0,0 +1,275 @@
use assert_matches::assert_matches;
use futures::{
future::{join_all, ready},
prelude::*,
};
use std::time::{Duration, SystemTime};
use tarpc::{
client::{self},
context,
server::{incoming::Incoming, BaseChannel, Channel},
transport::channel,
};
use tokio::join;
#[tarpc_plugins::service]
trait Service {
async fn add(x: i32, y: i32) -> i32;
async fn hey(name: String) -> String;
}
#[derive(Clone)]
struct Server;
impl Service for Server {
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
x + y
}
async fn hey(self, _: context::Context, name: String) -> String {
format!("Hey, {name}.")
}
}
#[tokio::test]
async fn sequential() {
let (tx, rx) = tarpc::transport::channel::unbounded();
let client = client::new(client::Config::default(), tx).spawn();
let channel = BaseChannel::with_defaults(rx);
tokio::spawn(
channel
.execute(tarpc::server::serve(|_, i| async move { Ok(i + 1) }))
.for_each(|response| response),
);
assert_eq!(
client.call(context::current(), "AddOne", 1).await.unwrap(),
2
);
}
#[tokio::test]
async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
#[tarpc_plugins::service]
trait Loop {
async fn r#loop();
}
#[derive(Clone)]
struct LoopServer;
#[derive(Debug)]
struct AllHandlersComplete;
impl Loop for LoopServer {
async fn r#loop(self, _: context::Context) {
loop {
futures::pending!();
}
}
}
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
// Set up a client that initiates a long-lived request.
// The request will complete in error when the server drops the connection.
tokio::spawn(async move {
let client = LoopClient::new(client::Config::default(), tx).spawn();
let mut ctx = context::current();
ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60);
let _ = client.r#loop(ctx).await;
});
let mut requests = BaseChannel::with_defaults(rx).requests();
// Reading a request should trigger the request being registered with BaseChannel.
let first_request = requests.next().await.unwrap()?;
// Dropping the channel should trigger cleanup of outstanding requests.
drop(requests);
// In-flight requests should be aborted by channel cleanup.
// The first and only request sent by the client is `loop`, which is an infinite loop
// on the server side, so if cleanup was not triggered, this line should hang indefinitely.
first_request.execute(LoopServer.serve()).await;
Ok(())
}
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
#[tokio::test]
async fn serde_tcp() -> anyhow::Result<()> {
use tarpc::serde_transport;
use tokio_serde::formats::Json;
let _ = tracing_subscriber::fmt::try_init();
let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?;
let addr = transport.local_addr();
tokio::spawn(
transport
.take(1)
.filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults)
.execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
);
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
let client = ServiceClient::new(client::Config::default(), transport).spawn();
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
client.hey(context::current(), "Tim".to_string()).await,
Ok(ref s) if s == "Hey, Tim."
);
Ok(())
}
#[cfg(all(feature = "serde-transport", feature = "unix", unix))]
#[tokio::test]
async fn serde_uds() -> anyhow::Result<()> {
use tarpc::serde_transport;
use tokio_serde::formats::Json;
let _ = tracing_subscriber::fmt::try_init();
let sock = tarpc::serde_transport::unix::TempPathBuf::with_random("uds");
let transport = tarpc::serde_transport::unix::listen(&sock, Json::default).await?;
tokio::spawn(
transport
.take(1)
.filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults)
.execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
);
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
let client = ServiceClient::new(client::Config::default(), transport).spawn();
// Save results using socket so we can clean the socket even if our test assertions fail
let res1 = client.add(context::current(), 1, 2).await;
let res2 = client.hey(context::current(), "Tim".to_string()).await;
assert_matches!(res1, Ok(3));
assert_matches!(res2, Ok(ref s) if s == "Hey, Tim.");
Ok(())
}
#[tokio::test]
async fn concurrent() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
tokio::spawn(
stream::once(ready(rx))
.map(BaseChannel::with_defaults)
.execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn();
let req1 = client.add(context::current(), 1, 2);
let req2 = client.add(context::current(), 3, 4);
let req3 = client.hey(context::current(), "Tim".to_string());
assert_matches!(req1.await, Ok(3));
assert_matches!(req2.await, Ok(7));
assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim.");
Ok(())
}
#[tokio::test]
async fn concurrent_join() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
tokio::spawn(
stream::once(ready(rx))
.map(BaseChannel::with_defaults)
.execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn();
let req1 = client.add(context::current(), 1, 2);
let req2 = client.add(context::current(), 3, 4);
let req3 = client.hey(context::current(), "Tim".to_string());
let (resp1, resp2, resp3) = join!(req1, req2, req3);
assert_matches!(resp1, Ok(3));
assert_matches!(resp2, Ok(7));
assert_matches!(resp3, Ok(ref s) if s == "Hey, Tim.");
Ok(())
}
#[cfg(test)]
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::test]
async fn concurrent_join_all() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
tokio::spawn(
BaseChannel::with_defaults(rx)
.execute(Server.serve())
.for_each(spawn),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn();
let req1 = client.add(context::current(), 1, 2);
let req2 = client.add(context::current(), 3, 4);
let responses = join_all(vec![req1, req2]).await;
assert_matches!(responses[0], Ok(3));
assert_matches!(responses[1], Ok(7));
Ok(())
}
#[tokio::test]
async fn counter() -> anyhow::Result<()> {
#[tarpc::service]
trait Counter {
async fn count() -> u32;
}
struct CountService(u32);
impl Counter for &mut CountService {
async fn count(self, _: context::Context) -> u32 {
self.0 += 1;
self.0
}
}
let (tx, rx) = channel::unbounded();
tokio::spawn(async {
let mut requests = BaseChannel::with_defaults(rx).requests();
let mut counter = CountService(0);
while let Some(Ok(request)) = requests.next().await {
request.execute(counter.serve()).await;
}
});
let client = CounterClient::new(client::Config::default(), tx).spawn();
assert_matches!(client.count(context::current()).await, Ok(1));
assert_matches!(client.count(context::current()).await, Ok(2));
Ok(())
}

View File

@@ -1,21 +0,0 @@
[package]
name = "tarpc-trace"
version = "0.1.0"
authors = ["tikue <tikue@google.com>"]
edition = '2018'
license = "MIT"
documentation = "https://docs.rs/tarpc-trace"
homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "server", "api", "tls"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
description = "foundations for tracing in tarpc"
[dependencies]
rand = "0.5"
[dependencies.serde]
version = "1.0"
optional = true
features = ["derive"]

View File

@@ -1 +0,0 @@
edition = "Edition2018"

View File

@@ -1,106 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![deny(missing_docs, missing_debug_implementations)]
//! Provides building blocks for tracing distributed programs.
//!
//! A trace is logically a tree of causally-related events called spans. Traces are tracked via a
//! [context](Context) that identifies the current trace, span, and parent of the current span. In
//! distributed systems, a context can be sent from client to server to connect events occurring on
//! either side.
//!
//! This crate's design is based on [opencensus
//! tracing](https://opencensus.io/core-concepts/tracing/).
use rand::Rng;
use std::{
fmt::{self, Formatter},
mem,
};
/// A context for tracing the execution of processes, distributed or otherwise.
///
/// Consists of a span identifying an event, an optional parent span identifying a causal event
/// that triggered the current span, and a trace with which all related spans are associated.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
pub struct Context {
/// An identifier of the trace associated with the current context. A trace ID is typically
/// created at a root span and passed along through all causal events.
pub trace_id: TraceId,
/// An identifier of the current span. In typical RPC usage, a span is created by a client
/// before making an RPC, and the span ID is sent to the server. The server is free to create
/// its own spans, for which it sets the client's span as the parent span.
pub span_id: SpanId,
/// An identifier of the span that originated the current span. For example, if a server sends
/// an RPC in response to a client request that included a span, the server would create a span
/// for the RPC and set its parent to the span_id in the incoming request's context.
///
/// If `parent_id` is `None`, then this is a root context.
pub parent_id: Option<SpanId>,
}
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
/// same trace ID.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
pub struct TraceId(u128);
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
pub struct SpanId(u64);
impl Context {
/// Constructs a new root context. A root context is one with no parent span.
pub fn new_root() -> Self {
let rng = &mut rand::thread_rng();
Context {
trace_id: TraceId::random(rng),
span_id: SpanId::random(rng),
parent_id: None,
}
}
}
impl TraceId {
/// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
/// actually-random numbers.
pub fn random<R: Rng>(rng: &mut R) -> Self {
TraceId((rng.next_u64() as u128) << mem::size_of::<u64>() | rng.next_u64() as u128)
}
}
impl SpanId {
/// Returns a random span ID that can be assumed to be unique within a single trace.
pub fn random<R: Rng>(rng: &mut R) -> Self {
SpanId(rng.next_u64())
}
}
impl fmt::Display for TraceId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}
impl fmt::Display for SpanId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}