Remove deprecated tokio-proto and replace with homegrown rpc framework (#199)

# New Crates

- crate rpc contains the core client/server request-response framework, as well as a transport trait.
- crate bincode-transport implements a transport that works almost exactly like tarpc does today (which is not to say it's wire-compatible).
- crate trace has some foundational types for tracing. It isn't fully fleshed out yet, but it's useful for in-process log tracing, at least.

All crates are now at the top level: for example, tarpc-plugins is now tarpc/plugins rather than tarpc/src/plugins. tarpc itself now has a *very* small code surface, as most functionality has moved into the other, more granular crates.

# New Features
- deadlines: every request specifies a deadline, and a server stops processing a request once its deadline has passed (see the sketch after this list).
- client cancellation propagation: when a client drops a request, it sends a message informing the server to cancel its response. This means cancellations can propagate across multiple server hops.
- trace context support, as mentioned above, for correlating causally related requests.
- more server configuration for total connection limits, per-connection request limits, etc.
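
As a minimal sketch of the deadline feature — assuming the `rpc` crate added in this commit and an already-connected `Client<String, u64>` — the default context carries a 10-second deadline, which a caller can tighten before issuing the request:

```rust
// Sketch only: assumes an existing, connected rpc::Client<String, u64>.
use std::time::{Duration, SystemTime};

async fn call_with_deadline(client: &mut rpc::Client<String, u64>) -> std::io::Result<u64> {
    // context::current() yields a Context with deadline = now + 10s and a fresh trace root.
    let mut ctx = rpc::context::current();
    // Tighten the deadline; the server cancels the handler if it can't respond in time.
    ctx.deadline = SystemTime::now() + Duration::from_secs(1);
    await!(client.call(ctx, "123".to_string()))
}
```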

# Removals
- no more shutdown handle. I left it out for now because of time constraints and because I'm not sure what the right solution is.
- all async now; there is no blocking stub or server interface. This helps with maintainability, and async/await makes async code much more usable. The service trait is accordingly renamed Service, and the client is renamed Client (see the sketch after this list).
- no built-in transport. Tarpc is now transport-agnostic (see bincode-transport for transitioning existing uses).
- going along with the previous bullet, having no preferred transport means no TLS support at this time. We could make a TLS transport or make bincode-transport compatible with TLS.
- a lot of examples were removed because I couldn't keep up with maintaining all of them. Hopefully the ones I kept are still illustrative.
- no more plugins!
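
To give a feel for the all-async, transport-agnostic shape, here is a rough sketch adapted from the in-memory channel integration test added in this commit (rpc/src/transport/channel.rs). The types and setup are taken from that test; it is not a drop-in example for a generated tarpc service.

```rust
// Sketch adapted from the integration test in rpc/src/transport/channel.rs;
// assumes a tokio runtime driving the combined future, as that test does.
use futures::{compat::TokioDefaultSpawner, prelude::*, stream};
use rpc::{client::{self, Client}, context, server::{self, Handler, Server}};
use std::io;

fn demo() -> impl Future<Output = ((), io::Result<u64>)> {
    rpc::init(TokioDefaultSpawner); // seed the spawner before creating clients/servers

    // Two in-memory transports wired back-to-back, one for each side.
    let (client_channel, server_channel) = rpc::transport::channel::unbounded();

    // A server is a stream of incoming channels plus an async request handler.
    let server = Server::<String, u64>::new(server::Config::default())
        .incoming(stream::once(future::ready(Ok(server_channel))))
        .respond_with(|_ctx, request| {
            future::ready(request.parse::<u64>().map_err(|_| {
                io::Error::new(io::ErrorKind::InvalidInput, "not an int")
            }))
        });

    // The client wraps any Transport; requests are multiplexed over it.
    let response = async {
        let mut client = await!(Client::new(client::Config::default(), client_channel))?;
        let resp = await!(client.call(context::current(), "123".to_string()))?;
        Ok::<_, io::Error>(resp)
    };

    // Drive both halves together, as the test does with tokio + the compat layer.
    server.join(response)
}
```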

# Open Questions

1. Should client.send() return `Future<Response>` or `Future<Future<Response>>`? The former appears more ergonomic, but it doesn't allow concurrent requests with a single client handle. The latter is less ergonomic but yields back control of the client once the request has been successfully sent. Should we offer fns for both? (See the sketch after this list.)
2. Should rpc service! fns take &mut self, &self, or self? The service needs to impl Clone anyway; technically we only need to clone it once per connection, and we can then leave it up to the user to decide whether to clone it per RPC. In practice, I think everyone doing nontrivial work will need to clone it per RPC.
3. Do the request/response structs look ok?
4. Is supporting server shutdown/lameduck important?
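
For question 1, the crate-internal dispatch layer already exposes both shapes (rpc/src/client/dispatch.rs): `Channel::send` resolves once the request is enqueued and yields a future of the response, while `Channel::call` awaits the response directly; the public Client currently surfaces only the `call` shape. A sketch of how the two styles differ for a caller (`two_styles` is a hypothetical helper, and `Channel` is pub(crate), so this is illustrative only):

```rust
// Illustrative only: Channel lives in the crate-private dispatch module, so this
// sketch is written as if it sat inside rpc/src/client/mod.rs.
use crate::context;
use self::dispatch::Channel;
use std::io;

async fn two_styles(chan: &mut Channel<String, u64>) -> io::Result<()> {
    // Shape 1: Future<Future<Response>>. `send` resolves once the request is
    // enqueued, handing back a response future, so one handle can keep several
    // requests in flight concurrently.
    let pending_a = await!(chan.send(context::current(), "1".to_string()))?;
    let pending_b = await!(chan.send(context::current(), "2".to_string()))?;
    let (a, b) = (await!(pending_a)?, await!(pending_b)?);

    // Shape 2: Future<Response>. `call` is just send-then-await: more ergonomic,
    // but it borrows the channel until the response arrives.
    let c = await!(chan.call(context::current(), "3".to_string()))?;
    let _total = a + b + c; // all three responses received
    Ok(())
}
```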

Fixes #178 #155 #124 #104 #83 #38
Tim committed 2018-10-16 11:26:27 -07:00 (committed by GitHub)
parent 5e4b97e589, commit 905e5be8bb
73 changed files with 4690 additions and 5143 deletions

rpc/src/client/dispatch.rs Normal file

@@ -0,0 +1,708 @@
use crate::{
context,
util::{deadline_compat, AsDuration, Compact},
ClientMessage, ClientMessageKind, Request, Response, Transport,
};
use fnv::FnvHashMap;
use futures::{
Poll,
channel::{mpsc, oneshot},
prelude::*,
ready,
stream::Fuse,
task::LocalWaker,
};
use humantime::format_rfc3339;
use log::{debug, error, info, trace};
use pin_utils::unsafe_pinned;
use std::{
io,
net::SocketAddr,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Instant,
};
use trace::SpanId;
use super::Config;
/// Handles communication from the client to request dispatch.
#[derive(Debug)]
pub(crate) struct Channel<Req, Resp> {
to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
/// Channel to send a cancel message to the dispatcher.
cancellation: RequestCancellation,
/// The ID to use for the next request to stage.
next_request_id: Arc<AtomicU64>,
server_addr: SocketAddr,
}
impl<Req, Resp> Clone for Channel<Req, Resp> {
fn clone(&self) -> Self {
Self {
to_dispatch: self.to_dispatch.clone(),
cancellation: self.cancellation.clone(),
next_request_id: self.next_request_id.clone(),
server_addr: self.server_addr,
}
}
}
impl<Req, Resp> Channel<Req, Resp> {
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves when the request is sent (not when the response is received).
pub(crate) async fn send(
&mut self,
mut ctx: context::Context,
request: Req,
) -> io::Result<DispatchResponse<Resp>> {
// Convert the context to the call context.
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
let timeout = ctx.deadline.as_duration();
let deadline = Instant::now() + timeout;
trace!(
"[{}/{}] Queuing request with deadline {} (timeout {:?}).",
ctx.trace_id(),
self.server_addr,
format_rfc3339(ctx.deadline),
timeout,
);
let (response_completion, response) = oneshot::channel();
let cancellation = self.cancellation.clone();
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
await!(self.to_dispatch.send(DispatchRequest {
ctx,
request_id,
request,
response_completion,
})).map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset))?;
Ok(DispatchResponse {
response: deadline_compat::Deadline::new(response, deadline),
complete: false,
request_id,
cancellation,
ctx,
server_addr: self.server_addr,
})
}
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves to the response.
pub(crate) async fn call(
&mut self,
context: context::Context,
request: Req,
) -> io::Result<Resp> {
let response_future = await!(self.send(context, request))?;
await!(response_future)
}
}
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
#[derive(Debug)]
pub struct DispatchResponse<Resp> {
response: deadline_compat::Deadline<oneshot::Receiver<Response<Resp>>>,
ctx: context::Context,
complete: bool,
cancellation: RequestCancellation,
request_id: u64,
server_addr: SocketAddr,
}
impl<Resp> DispatchResponse<Resp> {
unsafe_pinned!(server_addr: SocketAddr);
unsafe_pinned!(ctx: context::Context);
}
impl<Resp> Future for DispatchResponse<Resp> {
type Output = io::Result<Resp>;
fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<Resp>> {
let resp = ready!(self.response.poll_unpin(waker));
self.complete = true;
Poll::Ready(match resp {
Ok(resp) => Ok(resp.message?),
Err(e) => Err({
let trace_id = *self.ctx().trace_id();
let server_addr = *self.server_addr();
if e.is_elapsed() {
io::Error::new(
io::ErrorKind::TimedOut,
"Client dropped expired request.".to_string(),
)
} else if e.is_timer() {
let e = e.into_timer().unwrap();
if e.is_at_capacity() {
io::Error::new(
io::ErrorKind::Other,
"Cancelling request because an expiration could not be set \
due to the timer being at capacity."
.to_string(),
)
} else if e.is_shutdown() {
panic!("[{}/{}] Timer was shutdown", trace_id, server_addr)
} else {
panic!(
"[{}/{}] Unrecognized timer error: {}",
trace_id, server_addr, e
)
}
} else if e.is_inner() {
// The oneshot is Canceled when the dispatch task ends.
io::Error::from(io::ErrorKind::ConnectionReset)
} else {
panic!(
"[{}/{}] Unrecognized deadline error: {}",
trace_id, server_addr, e
)
}
}),
})
}
}
// Cancels the request when dropped, if not already complete.
impl<Resp> Drop for DispatchResponse<Resp> {
fn drop(&mut self) {
if !self.complete {
// The receiver needs to be closed to handle the edge case that the request has not
// yet been received by the dispatch task. It is possible for the cancel message to
// arrive before the request itself, in which case the request could get stuck in the
// dispatch map forever if the server never responds (e.g. if the server dies while
// responding). Even if the server does respond, it will have unnecessarily done work
// for a client no longer waiting for a response. To avoid this, the dispatch task
// checks if the receiver is closed before inserting the request in the map. By
// closing the receiver before sending the cancel message, it is guaranteed that if the
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self.response.get_mut().close();
self.cancellation.cancel(self.request_id);
}
}
}
/// Spawns a dispatch task on the default executor that manages the lifecycle of requests initiated
/// by the returned [`Channel`].
pub async fn spawn<Req, Resp, C>(
config: Config,
transport: C,
server_addr: SocketAddr,
) -> io::Result<Channel<Req, Resp>>
where
Req: Send,
Resp: Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>> + Send,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
crate::spawn(
RequestDispatch {
config,
server_addr,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
pending_requests: pending_requests.fuse(),
}.unwrap_or_else(move |e| error!("[{}] Connection broken: {}", server_addr, e))
).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
"Could not spawn client dispatch task. Is shutdown: {}",
e.is_shutdown()
),
)
})?;
Ok(Channel {
to_dispatch,
cancellation,
server_addr,
next_request_id: Arc::new(AtomicU64::new(0)),
})
}
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
struct RequestDispatch<Req, Resp, C> {
/// Writes requests to the wire and reads responses off the wire.
transport: Fuse<C>,
/// Requests waiting to be written to the wire.
pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
/// Requests that were dropped.
canceled_requests: CanceledRequests,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
/// The address of the server connected to.
server_addr: SocketAddr,
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
Req: Send,
Resp: Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>>,
{
unsafe_pinned!(server_addr: SocketAddr);
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>);
unsafe_pinned!(canceled_requests: CanceledRequests);
unsafe_pinned!(pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>);
unsafe_pinned!(transport: Fuse<C>);
fn pump_read(self: &mut Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<io::Result<()>>> {
Poll::Ready(match ready!(self.transport().poll_next(waker)?) {
Some(response) => {
self.complete(response);
Some(Ok(()))
}
None => {
trace!("[{}] read half closed", self.server_addr());
None
}
})
}
fn pump_write(self: &mut Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<io::Result<()>>> {
enum ReceiverStatus {
NotReady,
Closed,
}
let pending_requests_status = match self.poll_next_request(waker)? {
Poll::Ready(Some(dispatch_request)) => {
self.write_request(dispatch_request)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
let canceled_requests_status = match self.poll_next_cancellation(waker)? {
Poll::Ready(Some((context, request_id))) => {
self.write_cancel(context, request_id)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
match (pending_requests_status, canceled_requests_status) {
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
ready!(self.transport().poll_flush(waker)?);
Poll::Ready(None)
}
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self.transport().poll_flush(waker)?);
// Even if we fully-flush, we return Pending, because we have no more requests
// or cancellations right now.
Poll::Pending
}
}
}
/// Yields the next pending request, if one is ready to be sent.
fn poll_next_request(
self: &mut Pin<&mut Self>,
waker: &LocalWaker,
) -> Poll<Option<io::Result<DispatchRequest<Req, Resp>>>> {
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
info!(
"At in-flight request capacity ({}/{}).",
self.in_flight_requests().len(),
self.config.max_in_flight_requests
);
// No need to schedule a wakeup, because timers and responses are responsible
// for clearing out in-flight requests.
return Poll::Pending;
}
while let Poll::Pending = self.transport().poll_ready(waker)? {
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
ready!(self.transport().poll_flush(waker)?);
}
loop {
match ready!(self.pending_requests().poll_next_unpin(waker)) {
Some(request) => {
if request.response_completion.is_canceled() {
trace!(
"[{}] Request canceled before being sent.",
request.ctx.trace_id()
);
continue;
}
return Poll::Ready(Some(Ok(request)));
}
None => {
trace!("[{}] pending_requests closed", self.server_addr());
return Poll::Ready(None);
}
}
}
}
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
fn poll_next_cancellation(
self: &mut Pin<&mut Self>,
waker: &LocalWaker,
) -> Poll<Option<io::Result<(context::Context, u64)>>> {
while let Poll::Pending = self.transport().poll_ready(waker)? {
ready!(self.transport().poll_flush(waker)?);
}
loop {
match ready!(self.canceled_requests().poll_next_unpin(waker)) {
Some(request_id) => {
if let Some(in_flight_data) = self.in_flight_requests().remove(&request_id) {
self.in_flight_requests().compact(0.1);
debug!(
"[{}/{}] Removed request.",
in_flight_data.ctx.trace_id(),
self.server_addr()
);
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
}
}
None => {
trace!("[{}] canceled_requests closed.", self.server_addr());
return Poll::Ready(None);
}
}
}
}
fn write_request(
self: &mut Pin<&mut Self>,
dispatch_request: DispatchRequest<Req, Resp>,
) -> io::Result<()> {
let request_id = dispatch_request.request_id;
let request = ClientMessage {
trace_context: dispatch_request.ctx.trace_context,
message: ClientMessageKind::Request(Request {
id: request_id,
message: dispatch_request.request,
deadline: dispatch_request.ctx.deadline,
}),
};
self.transport().start_send(request)?;
self.in_flight_requests().insert(
request_id,
InFlightData {
ctx: dispatch_request.ctx,
response_completion: dispatch_request.response_completion,
},
);
Ok(())
}
fn write_cancel(
self: &mut Pin<&mut Self>,
context: context::Context,
request_id: u64,
) -> io::Result<()> {
let trace_id = *context.trace_id();
let cancel = ClientMessage {
trace_context: context.trace_context,
message: ClientMessageKind::Cancel { request_id },
};
self.transport().start_send(cancel)?;
trace!("[{}/{}] Cancel message sent.", trace_id, self.server_addr());
return Ok(());
}
/// Sends a server response to the client task that initiated the associated request.
fn complete(self: &mut Pin<&mut Self>, response: Response<Resp>) -> bool {
if let Some(in_flight_data) = self.in_flight_requests().remove(&response.request_id) {
self.in_flight_requests().compact(0.1);
trace!(
"[{}/{}] Received response.",
in_flight_data.ctx.trace_id(),
self.server_addr()
);
let _ = in_flight_data.response_completion.send(response);
return true;
}
debug!(
"[{}] No in-flight request found for request_id = {}.",
self.server_addr(),
response.request_id
);
// If the response completion was absent, then the request was already canceled.
false
}
}
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
Req: Send,
Resp: Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>>,
{
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<io::Result<()>> {
trace!("[{}] RequestDispatch::poll", self.server_addr());
loop {
match (self.pump_read(waker)?, self.pump_write(waker)?) {
(read, write @ Poll::Ready(None)) => {
if self.in_flight_requests().is_empty() {
info!(
"[{}] Shutdown: write half closed, and no requests in flight.",
self.server_addr()
);
return Poll::Ready(Ok(()));
}
match read {
Poll::Ready(Some(())) => continue,
_ => {
trace!(
"[{}] read: {:?}, write: {:?}, (not ready)",
self.server_addr(),
read,
write,
);
return Poll::Pending;
}
}
}
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
trace!(
"[{}] read: {:?}, write: {:?}",
self.server_addr(),
read,
write,
)
}
(read, write) => {
trace!(
"[{}] read: {:?}, write: {:?} (not ready)",
self.server_addr(),
read,
write,
);
return Poll::Pending;
}
}
}
}
}
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
/// the lifecycle of the request.
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
ctx: context::Context,
request_id: u64,
request: Req,
response_completion: oneshot::Sender<Response<Resp>>,
}
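/// Bookkeeping kept by the dispatch task for a request written to the wire: the request's
/// context and the oneshot used to complete the response.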
struct InFlightData<Resp> {
ctx: context::Context,
response_completion: oneshot::Sender<Response<Resp>>,
}
/// Sends request cancellation signals.
#[derive(Debug, Clone)]
struct RequestCancellation(mpsc::UnboundedSender<u64>);
/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests. Additionally, each request has a clone
// of the sender, so the bounded channel would have the same behavior,
// since it guarantees a slot.
let (tx, rx) = mpsc::unbounded();
(RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
/// Cancels the request with ID `request_id`.
fn cancel(&mut self, request_id: u64) {
let _ = self.0.unbounded_send(request_id);
}
}
impl Stream for CanceledRequests {
type Item = u64;
fn poll_next(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<Option<u64>> {
self.0.poll_next_unpin(waker)
}
}
#[cfg(test)]
mod tests {
use super::{CanceledRequests, Channel, RequestCancellation, RequestDispatch};
use crate::{
client::Config,
context,
transport::{self, channel::UnboundedChannel},
ClientMessage, Response,
};
use fnv::FnvHashMap;
use futures::{Poll, channel::mpsc, prelude::*};
use futures_test::task::{noop_local_waker_ref};
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
pin::Pin,
sync::atomic::AtomicU64,
sync::Arc,
};
#[test]
fn stage_request() {
let (mut dispatch, mut channel, _server_channel) = set_up();
// Test that a staged request reaches the dispatch task with the expected ID and payload.
let _resp = tokio::runtime::current_thread::block_on_all(
channel
.send(context::current(), "hi".to_string())
.boxed()
.compat(),
);
let mut dispatch = Pin::new(&mut dispatch);
let waker = &noop_local_waker_ref();
let req = dispatch.poll_next_request(waker).ready();
assert!(req.is_some());
let req = req.unwrap();
assert_eq!(req.request_id, 0);
assert_eq!(req.request, "hi".to_string());
}
#[test]
fn stage_request_response_future_dropped() {
let (mut dispatch, mut channel, _server_channel) = set_up();
// Test that a request future dropped before it's processed by dispatch will cause the request
// to not be added to the in-flight request map.
let resp = tokio::runtime::current_thread::block_on_all(
channel
.send(context::current(), "hi".into())
.boxed()
.compat(),
).unwrap();
drop(resp);
drop(channel);
let mut dispatch = Pin::new(&mut dispatch);
let waker = &noop_local_waker_ref();
dispatch.poll_next_cancellation(waker).unwrap();
assert!(dispatch.poll_next_request(waker).ready().is_none());
}
#[test]
fn stage_request_response_future_closed() {
let (mut dispatch, mut channel, _server_channel) = set_up();
// Test that a request future that's closed its receiver but not yet canceled its request --
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
// map.
let resp = tokio::runtime::current_thread::block_on_all(
channel
.send(context::current(), "hi".into())
.boxed()
.compat(),
).unwrap();
drop(resp);
drop(channel);
let mut dispatch = Pin::new(&mut dispatch);
let waker = &noop_local_waker_ref();
assert!(dispatch.poll_next_request(waker).ready().is_none());
}
fn set_up() -> (
RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
Channel<String, String>,
UnboundedChannel<ClientMessage<String>, Response<String>>,
) {
let _ = env_logger::try_init();
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancel_tx, canceled_requests) = mpsc::unbounded();
let (client_channel, server_channel) = transport::channel::unbounded();
let dispatch = RequestDispatch::<String, String, _> {
transport: client_channel.fuse(),
pending_requests: pending_requests.fuse(),
canceled_requests: CanceledRequests(canceled_requests),
in_flight_requests: FnvHashMap::default(),
config: Config::default(),
server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
};
let cancellation = RequestCancellation(cancel_tx);
let channel = Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicU64::new(0)),
server_addr: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
};
(dispatch, channel, server_channel)
}
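/// Test-only helpers for unwrapping `Poll<Option<Result<T, E>>>` values.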
trait PollTest {
type T;
fn unwrap(self) -> Poll<Self::T>;
fn ready(self) -> Self::T;
}
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
where
E: ::std::fmt::Display + Send + 'static,
{
type T = Option<T>;
fn unwrap(self) -> Poll<Option<T>> {
match self {
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
Poll::Pending => Poll::Pending,
}
}
fn ready(self) -> Option<T> {
match self {
Poll::Ready(Some(Ok(t))) => Some(t),
Poll::Ready(None) => None,
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
Poll::Pending => panic!("Pending"),
}
}
}
}

rpc/src/client/mod.rs Normal file

@@ -0,0 +1,85 @@
//! Provides a client that connects to a server and sends multiplexed requests.
use crate::{context::Context, ClientMessage, Response, Transport};
use log::warn;
use std::{
io,
net::{Ipv4Addr, SocketAddr},
};
mod dispatch;
/// Sends multiplexed requests to, and receives responses from, a server.
#[derive(Debug)]
pub struct Client<Req, Resp> {
/// Channel to send requests to the dispatch task.
channel: dispatch::Channel<Req, Resp>,
}
impl<Req, Resp> Clone for Client<Req, Resp> {
fn clone(&self) -> Self {
Client {
channel: self.channel.clone(),
}
}
}
/// Settings that control the behavior of the client.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
/// The number of requests that can be in flight at once.
/// `max_in_flight_requests` controls the size of the map used by the client
/// for storing pending requests.
pub max_in_flight_requests: usize,
/// The number of requests that can be buffered client-side before being sent.
/// `pending_requests_buffer` controls the size of the channel clients use
/// to communicate with the request dispatch task.
pub pending_request_buffer: usize,
}
impl Default for Config {
fn default() -> Self {
Config {
max_in_flight_requests: 1_000,
pending_request_buffer: 100,
}
}
}
impl<Req, Resp> Client<Req, Resp>
where
Req: Send,
Resp: Send,
{
/// Creates a new Client by wrapping a [`Transport`] and spawning a dispatch task
/// that manages the lifecycle of requests.
///
/// Must only be called from within the context of an executor.
pub async fn new<T>(config: Config, transport: T) -> io::Result<Self>
where
T: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>> + Send,
{
let server_addr = transport.peer_addr().unwrap_or_else(|e| {
warn!(
"Setting peer to unspecified because peer could not be determined: {}",
e
);
SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0)
});
Ok(Client {
channel: await!(dispatch::spawn(config, transport, server_addr))?,
})
}
/// Initiates a request, sending it to the dispatch task.
///
/// Returns a [`Future`] that resolves to this client and the future response
/// once the request is successfully enqueued.
///
/// [`Future`]: futures::Future
pub async fn call(&mut self, ctx: Context, request: Req) -> io::Result<Resp> {
await!(self.channel.call(ctx, request))
}
}

rpc/src/context.rs Normal file

@@ -0,0 +1,44 @@
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the MIT License, <LICENSE or http://opensource.org/licenses/MIT>.
// This file may not be copied, modified, or distributed except according to those terms.
//! Provides a request context that carries a deadline and trace context. This context is sent from
//! client to server and is used by the server to enforce response deadlines.
use std::time::{Duration, SystemTime};
use trace::{self, TraceId};
/// A request context that carries request-scoped information like deadlines and trace information.
/// It is sent from client to server and is used by the server to enforce response deadlines.
///
/// The context should not be stored directly in a server implementation, because the context will
/// be different for each request in scope.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct Context {
/// When the client expects the request to be complete by. The server should cancel the request
/// if it is not complete by this time.
pub deadline: SystemTime,
/// Uniquely identifies requests originating from the same source.
/// When a service handles a request by making requests itself, those requests should
/// include the same `trace_id` as that included on the original request. This way,
/// users can trace related actions across a distributed system.
pub trace_context: trace::Context,
}
/// Returns the context for the current request, or a default Context if no request is active.
// TODO: populate Context with request-scoped data, with default fallbacks.
pub fn current() -> Context {
Context {
deadline: SystemTime::now() + Duration::from_secs(10),
trace_context: trace::Context::new_root(),
}
}
impl Context {
/// Returns the ID of the request-scoped trace.
pub fn trace_id(&self) -> &TraceId {
&self.trace_context.trace_id
}
}

rpc/src/lib.rs Normal file

@@ -0,0 +1,214 @@
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the MIT License, <LICENSE or http://opensource.org/licenses/MIT>.
// This file may not be copied, modified, or distributed except according to those terms.
#![feature(
const_fn,
non_exhaustive,
integer_atomics,
try_trait,
nll,
futures_api,
pin,
arbitrary_self_types,
await_macro,
async_await,
generators,
optin_builtin_traits,
generator_trait,
gen_future,
decl_macro,
)]
#![deny(missing_docs, missing_debug_implementations)]
//! An RPC framework providing client and server.
//!
//! Features:
//! * RPC deadlines, both client- and server-side.
//! * Cascading cancellation (works with multiple hops).
//! * Configurable limits
//! * In-flight requests, both client and server-side.
//! * Server-side limit is per-connection.
//! * When the server reaches the in-flight request maximum, it returns a throttled error
//! to the client.
//! * When the client reaches the in-flight request max, messages are buffered up to a
//! configurable maximum, beyond which the requests are back-pressured.
//! * Server connections.
//! * Total and per-IP limits.
//! * When an incoming connection is accepted, if already at maximum, the connection is
//! dropped.
//! * Transport agnostic.
pub mod client;
pub mod context;
pub mod server;
pub mod transport;
pub(crate) mod util;
pub use crate::{client::Client, server::Server, transport::Transport};
use futures::{Future, task::{Spawn, SpawnExt, SpawnError}};
use std::{cell::RefCell, io, sync::Once, time::SystemTime};
/// A message from a client to a server.
#[derive(Debug)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub struct ClientMessage<T> {
/// The trace context associates the message with a specific chain of causally-related actions,
/// possibly orchestrated across many distributed systems.
pub trace_context: trace::Context,
/// The message payload.
pub message: ClientMessageKind<T>,
}
/// Different messages that can be sent from a client to a server.
#[derive(Debug)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub enum ClientMessageKind<T> {
/// A request initiated by a user. The server responds to a request by invoking a
/// service-provided request handler. The handler completes with a [`response`](Response), which
/// the server sends back to the client.
Request(Request<T>),
/// A command to cancel an in-flight request, automatically sent by the client when a response
/// future is dropped.
///
/// When received, the server will immediately cancel the main task (top-level future) of the
/// request handler for the associated request. Any tasks spawned by the request handler will
/// not be canceled, because the framework layer does not
/// know about them.
Cancel {
/// The ID of the request to cancel.
request_id: u64,
},
}
/// A request from a client to a server.
#[derive(Debug)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub struct Request<T> {
/// Uniquely identifies the request across all requests sent over a single channel.
pub id: u64,
/// The request body.
pub message: T,
/// When the client expects the request to be complete by. The server will cancel the request
/// if it is not complete by this time.
#[cfg_attr(
feature = "serde",
serde(serialize_with = "util::serde::serialize_epoch_secs")
)]
#[cfg_attr(
feature = "serde",
serde(deserialize_with = "util::serde::deserialize_epoch_secs")
)]
pub deadline: SystemTime,
}
/// A response from a server to a client.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub struct Response<T> {
/// The ID of the request being responded to.
pub request_id: u64,
/// The response body, or an error if the request failed.
pub message: Result<T, ServerError>,
}
/// An error response from a server to a client.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub struct ServerError {
#[cfg_attr(
feature = "serde",
serde(serialize_with = "util::serde::serialize_io_error_kind_as_u32")
)]
#[cfg_attr(
feature = "serde",
serde(deserialize_with = "util::serde::deserialize_io_error_kind_from_u32")
)]
/// The type of error that occurred to fail the request.
pub kind: io::ErrorKind,
/// A message describing more detail about the error that occurred.
pub detail: Option<String>,
}
impl From<ServerError> for io::Error {
fn from(e: ServerError) -> io::Error {
io::Error::new(e.kind, e.detail.unwrap_or_default())
}
}
impl<T> Request<T> {
/// Returns the deadline for this request.
pub fn deadline(&self) -> &SystemTime {
&self.deadline
}
}
static INIT: Once = Once::new();
static mut SEED_SPAWN: Option<Box<dyn CloneSpawn>> = None;
thread_local! {
static SPAWN: RefCell<Box<dyn CloneSpawn>> = {
unsafe {
// INIT must always be called before accessing SPAWN.
// Otherwise, accessing SPAWN can trigger undefined behavior due to race conditions.
INIT.call_once(|| {});
RefCell::new(SEED_SPAWN.clone().expect("init() must be called."))
}
};
}
/// Initializes the RPC library with a mechanism to spawn futures on the user's runtime.
/// Client stubs and servers both use the initialized spawn.
///
/// Init only has an effect the first time it is called. If called previously, successive calls to
/// init are noops.
pub fn init(spawn: impl Spawn + Clone + 'static) {
unsafe {
INIT.call_once(|| {
SEED_SPAWN = Some(Box::new(spawn));
});
}
}
pub(crate) fn spawn(future: impl Future<Output = ()> + Send + 'static) -> Result<(), SpawnError> {
SPAWN.with(|spawn| {
spawn.borrow_mut().spawn(future)
})
}
trait CloneSpawn: Spawn {
fn box_clone(&self) -> Box<dyn CloneSpawn>;
}
impl Clone for Box<dyn CloneSpawn> {
fn clone(&self) -> Self {
self.box_clone()
}
}
impl<S: Spawn + Clone + 'static> CloneSpawn for S {
fn box_clone(&self) -> Box<dyn CloneSpawn> {
Box::new(self.clone())
}
}

rpc/src/server/filter.rs Normal file

@@ -0,0 +1,251 @@
use crate::{
server::{Channel, Config},
util::Compact,
ClientMessage, Response, Transport,
};
use fnv::FnvHashMap;
use futures::{channel::mpsc, prelude::*, ready, stream::Fuse, task::{LocalWaker, Poll}};
use log::{debug, error, info, trace, warn};
use pin_utils::unsafe_pinned;
use std::{
collections::hash_map::Entry,
io,
marker::PhantomData,
net::{IpAddr, SocketAddr},
ops::Try,
option::NoneError,
pin::Pin,
};
/// Drops connections under configurable conditions:
///
/// 1. If the max number of connections is reached.
/// 2. If the max number of connections for a single IP is reached.
#[derive(Debug)]
pub struct ConnectionFilter<S, Req, Resp> {
listener: Fuse<S>,
closed_connections: mpsc::UnboundedSender<SocketAddr>,
closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>,
config: Config,
connections_per_ip: FnvHashMap<IpAddr, usize>,
open_connections: usize,
ghost: PhantomData<(Req, Resp)>,
}
enum NewConnection<Req, Resp, C> {
Filtered,
Accepted(Channel<Req, Resp, C>),
}
impl<Req, Resp, C> Try for NewConnection<Req, Resp, C> {
type Ok = Channel<Req, Resp, C>;
type Error = NoneError;
fn into_result(self) -> Result<Channel<Req, Resp, C>, NoneError> {
match self {
NewConnection::Filtered => Err(NoneError),
NewConnection::Accepted(channel) => Ok(channel),
}
}
fn from_error(_: NoneError) -> Self {
NewConnection::Filtered
}
fn from_ok(channel: Channel<Req, Resp, C>) -> Self {
NewConnection::Accepted(channel)
}
}
impl<S, Req, Resp> ConnectionFilter<S, Req, Resp> {
unsafe_pinned!(open_connections: usize);
unsafe_pinned!(config: Config);
unsafe_pinned!(connections_per_ip: FnvHashMap<IpAddr, usize>);
unsafe_pinned!(closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>);
unsafe_pinned!(listener: Fuse<S>);
/// Sheds new connections to stay under configured limits.
pub fn filter<C>(listener: S, config: Config) -> Self
where
S: Stream<Item = Result<C, io::Error>>,
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
let (closed_connections, closed_connections_rx) = mpsc::unbounded();
ConnectionFilter {
listener: listener.fuse(),
closed_connections,
closed_connections_rx,
config,
connections_per_ip: FnvHashMap::default(),
open_connections: 0,
ghost: PhantomData,
}
}
fn handle_new_connection<C>(self: &mut Pin<&mut Self>, stream: C) -> NewConnection<Req, Resp, C>
where
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
let peer = match stream.peer_addr() {
Ok(peer) => peer,
Err(e) => {
warn!("Could not get peer_addr of new connection: {}", e);
return NewConnection::Filtered;
}
};
let open_connections = *self.open_connections();
if open_connections >= self.config().max_connections {
warn!(
"[{}] Shedding connection because the maximum open connections \
limit is reached ({}/{}).",
peer,
open_connections,
self.config().max_connections
);
return NewConnection::Filtered;
}
let config = self.config.clone();
let open_connections_for_ip = self.increment_connections_for_ip(&peer)?;
*self.open_connections() += 1;
debug!(
"[{}] Opening channel ({}/{} connections for IP, {} total).",
peer,
open_connections_for_ip,
config.max_connections_per_ip,
self.open_connections(),
);
NewConnection::Accepted(Channel {
client_addr: peer,
closed_connections: self.closed_connections.clone(),
transport: stream.fuse(),
config,
ghost: PhantomData,
})
}
fn handle_closed_connection(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
*self.open_connections() -= 1;
debug!(
"[{}] Closing channel. {} open connections remaining.",
addr, self.open_connections
);
self.decrement_connections_for_ip(&addr);
self.connections_per_ip().compact(0.1);
}
fn increment_connections_for_ip(self: &mut Pin<&mut Self>, peer: &SocketAddr) -> Option<usize> {
let max_connections_per_ip = self.config().max_connections_per_ip;
let mut occupied;
let mut connections_per_ip = self.connections_per_ip();
let occupied = match connections_per_ip.entry(peer.ip()) {
Entry::Vacant(vacant) => vacant.insert(0),
Entry::Occupied(o) => {
if *o.get() < max_connections_per_ip {
// Store the reference outside the block to extend the lifetime.
occupied = o;
occupied.get_mut()
} else {
info!(
"[{}] Opened max connections from IP ({}/{}).",
peer,
o.get(),
max_connections_per_ip
);
return None;
}
}
};
*occupied += 1;
Some(*occupied)
}
fn decrement_connections_for_ip(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
let should_compact = match self.connections_per_ip().entry(addr.ip()) {
Entry::Vacant(_) => {
error!("[{}] Got vacant entry when closing connection.", addr);
return;
}
Entry::Occupied(mut occupied) => {
*occupied.get_mut() -= 1;
if *occupied.get() == 0 {
occupied.remove();
true
} else {
false
}
}
};
if should_compact {
self.connections_per_ip().compact(0.1);
}
}
fn poll_listener<C>(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<NewConnection<Req, Resp, C>>>>
where
S: Stream<Item = Result<C, io::Error>>,
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
match ready!(self.listener().poll_next_unpin(cx)?) {
Some(codec) => Poll::Ready(Some(Ok(self.handle_new_connection(codec)))),
None => Poll::Ready(None),
}
}
fn poll_closed_connections(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
match ready!(self.closed_connections_rx().poll_next_unpin(cx)) {
Some(addr) => {
self.handle_closed_connection(&addr);
Poll::Ready(Ok(()))
}
None => unreachable!("Holding a copy of closed_connections and didn't close it."),
}
}
}
impl<S, Req, Resp, T> Stream for ConnectionFilter<S, Req, Resp>
where
S: Stream<Item = Result<T, io::Error>>,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
type Item = io::Result<Channel<Req, Resp, T>>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<Channel<Req, Resp, T>>>> {
loop {
match (self.poll_listener(cx)?, self.poll_closed_connections(cx)?) {
(Poll::Ready(Some(NewConnection::Accepted(channel))), _) => {
return Poll::Ready(Some(Ok(channel)))
}
(Poll::Ready(Some(NewConnection::Filtered)), _) | (_, Poll::Ready(())) => {
trace!("Filtered a connection; {} open.", self.open_connections());
continue;
}
(Poll::Pending, Poll::Pending) => return Poll::Pending,
(Poll::Ready(None), Poll::Pending) => {
if *self.open_connections() > 0 {
trace!(
"Listener closed; {} open connections.",
self.open_connections()
);
return Poll::Pending;
}
trace!("Shutting down listener: all connections closed, and no more coming.");
return Poll::Ready(None);
}
}
}
}
}

rpc/src/server/mod.rs Normal file

@@ -0,0 +1,599 @@
//! Provides a server that concurrently handles many connections sending multiplexed requests.
use crate::{
context::Context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage,
ClientMessageKind, Request, Response, ServerError, Transport,
};
use fnv::FnvHashMap;
use futures::{
channel::mpsc,
future::{abortable, AbortHandle},
prelude::*,
ready,
stream::Fuse,
task::{LocalWaker, Poll},
try_ready,
};
use humantime::format_rfc3339;
use log::{debug, error, info, trace, warn};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{
error::Error as StdError,
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
time::{Instant, SystemTime},
};
use tokio_timer::timeout;
use trace::{self, TraceId};
mod filter;
/// Manages clients, serving multiplexed requests over each connection.
#[derive(Debug)]
pub struct Server<Req, Resp> {
config: Config,
ghost: PhantomData<(Req, Resp)>,
}
/// Settings that control the behavior of the server.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
/// The maximum number of clients that can be connected to the server at once. When at the
/// limit, existing connections are honored and new connections are rejected.
pub max_connections: usize,
/// The maximum number of clients per IP address that can be connected to the server at once.
/// When an IP is at the limit, existing connections are honored and new connections on that IP
/// address are rejected.
pub max_connections_per_ip: usize,
/// The maximum number of requests that can be in flight for each client. When a client is at
/// the in-flight request limit, existing requests are fulfilled and new requests are rejected.
/// Rejected requests are sent a response error.
pub max_in_flight_requests_per_connection: usize,
/// The number of responses per client that can be buffered server-side before being sent.
/// `pending_response_buffer` controls the buffer size of the channel that a server's
/// response tasks use to send responses to the client handler task.
pub pending_response_buffer: usize,
}
impl Default for Config {
fn default() -> Self {
Config {
max_connections: 1_000_000,
max_connections_per_ip: 1_000,
max_in_flight_requests_per_connection: 1_000,
pending_response_buffer: 100,
}
}
}
impl<Req, Resp> Server<Req, Resp> {
/// Returns a new server with the configuration specified by `config`.
pub fn new(config: Config) -> Self {
Server {
config,
ghost: PhantomData,
}
}
/// Returns the config for this server.
pub fn config(&self) -> &Config {
&self.config
}
/// Returns a stream of the incoming connections to the server.
pub fn incoming<S, T>(
self,
listener: S,
) -> impl Stream<Item = io::Result<Channel<Req, Resp, T>>>
where
Req: Send,
Resp: Send,
S: Stream<Item = io::Result<T>>,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
self::filter::ConnectionFilter::filter(listener, self.config.clone())
}
}
/// The future driving the server.
#[derive(Debug)]
pub struct Running<S, F> {
incoming: S,
request_handler: F,
}
impl<S, F> Running<S, F> {
unsafe_pinned!(incoming: S);
unsafe_unpinned!(request_handler: F);
}
impl<S, T, Req, Resp, F, Fut> Future for Running<S, F>
where
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send + 'static,
F: FnMut(Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<()> {
while let Some(channel) = ready!(self.incoming().poll_next(cx)) {
match channel {
Ok(channel) => {
let peer = channel.client_addr;
if let Err(e) = crate::spawn(channel.respond_with(self.request_handler().clone()))
{
warn!("[{}] Failed to spawn connection handler: {:?}", peer, e);
}
}
Err(e) => {
warn!("Incoming connection error: {}", e);
}
}
}
info!("Server shutting down.");
return Poll::Ready(());
}
}
/// A utility trait enabling a stream to fluently chain a request handler.
pub trait Handler<T, Req, Resp>
where
Self: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send,
Resp: Send,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
/// Responds to all requests with `request_handler`.
fn respond_with<F, Fut>(self, request_handler: F) -> Running<Self, F>
where
F: FnMut(Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
Running {
incoming: self,
request_handler,
}
}
}
impl<T, Req, Resp, S> Handler<T, Req, Resp> for S
where
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send,
Resp: Send,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{}
/// The server end of an open connection with a client.
#[derive(Debug)]
pub struct Channel<Req, Resp, T> {
/// Writes responses to the wire and reads requests off the wire.
transport: Fuse<T>,
/// Signals the connection is closed when `Channel` is dropped.
closed_connections: mpsc::UnboundedSender<SocketAddr>,
/// Channel limits to prevent unlimited resource usage.
config: Config,
/// The address of the client connected to the channel.
client_addr: SocketAddr,
/// Types the request and response.
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp, T> Drop for Channel<Req, Resp, T> {
fn drop(&mut self) {
trace!("[{}] Closing channel.", self.client_addr);
// Even in a bounded channel, each connection would have a guaranteed slot, so using
// an unbounded sender is actually no different. And, the bound is on the maximum number
// of open connections.
if self
.closed_connections
.unbounded_send(self.client_addr)
.is_err()
{
warn!(
"[{}] Failed to send closed connection message.",
self.client_addr
);
}
}
}
impl<Req, Resp, T> Channel<Req, Resp, T> {
unsafe_pinned!(transport: Fuse<T>);
}
impl<Req, Resp, T> Channel<Req, Resp, T>
where
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
Req: Send,
Resp: Send,
{
pub(crate) fn start_send(self: &mut Pin<&mut Self>, response: Response<Resp>) -> io::Result<()> {
self.transport().start_send(response)
}
pub(crate) fn poll_ready(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
self.transport().poll_ready(cx)
}
pub(crate) fn poll_flush(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
self.transport().poll_flush(cx)
}
pub(crate) fn poll_next(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<ClientMessage<Req>>>> {
self.transport().poll_next(cx)
}
/// Returns the address of the client connected to the channel.
pub fn client_addr(&self) -> &SocketAddr {
&self.client_addr
}
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
/// responses and resolves when the connection is closed.
pub fn respond_with<F, Fut>(self, f: F) -> impl Future<Output = ()>
where
F: FnMut(Context, Req) -> Fut + Send + 'static,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
Req: 'static,
Resp: 'static,
{
let (responses_tx, responses) = mpsc::channel(self.config.pending_response_buffer);
let responses = responses.fuse();
let peer = self.client_addr;
ClientHandler {
channel: self,
f,
pending_responses: responses,
responses_tx,
in_flight_requests: FnvHashMap::default(),
}.unwrap_or_else(move |e| {
info!("[{}] ClientHandler errored out: {}", peer, e);
})
}
}
#[derive(Debug)]
struct ClientHandler<Req, Resp, T, F> {
channel: Channel<Req, Resp, T>,
/// Responses waiting to be written to the wire.
pending_responses: Fuse<mpsc::Receiver<(Context, Response<Resp>)>>,
/// Handed out to request handlers to fan in responses.
responses_tx: mpsc::Sender<(Context, Response<Resp>)>,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
/// Request handler.
f: F,
}
impl<Req, Resp, T, F> ClientHandler<Req, Resp, T, F> {
unsafe_pinned!(channel: Channel<Req, Resp, T>);
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(Context, Response<Resp>)>>);
unsafe_pinned!(responses_tx: mpsc::Sender<(Context, Response<Resp>)>);
// For this to be safe, field f must be private, and code in this module must never
// construct PinMut<F>.
unsafe_unpinned!(f: F);
}
impl<Req, Resp, T, F, Fut> ClientHandler<Req, Resp, T, F>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
F: FnMut(Context, Req) -> Fut + Send + 'static,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
/// If at max in-flight requests, check that there's room to immediately write a throttled
/// response.
fn poll_ready_if_throttling(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<io::Result<()>> {
if self.in_flight_requests.len()
>= self.channel.config.max_in_flight_requests_per_connection
{
let peer = self.channel().client_addr;
while let Poll::Pending = self.channel().poll_ready(cx)? {
info!(
"[{}] In-flight requests at max ({}), and transport is not ready.",
peer,
self.in_flight_requests().len(),
);
try_ready!(self.channel().poll_flush(cx));
}
}
Poll::Ready(Ok(()))
}
fn pump_read(self: &mut Pin<&mut Self>, cx: &LocalWaker) -> Poll<Option<io::Result<()>>> {
ready!(self.poll_ready_if_throttling(cx)?);
Poll::Ready(match ready!(self.channel().poll_next(cx)?) {
Some(message) => {
match message.message {
ClientMessageKind::Request(request) => {
self.handle_request(message.trace_context, request)?;
}
ClientMessageKind::Cancel { request_id } => {
self.cancel_request(&message.trace_context, request_id);
}
}
Some(Ok(()))
}
None => {
trace!("[{}] Read half closed", self.channel.client_addr);
None
}
})
}
fn pump_write(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
read_half_closed: bool,
) -> Poll<Option<io::Result<()>>> {
match self.poll_next_response(cx)? {
Poll::Ready(Some((_, response))) => {
self.channel().start_send(response)?;
Poll::Ready(Some(Ok(())))
}
Poll::Ready(None) => {
// Shutdown can't be done before we finish pumping out remaining responses.
ready!(self.channel().poll_flush(cx)?);
Poll::Ready(None)
}
Poll::Pending => {
// No more requests to process, so flush any requests buffered in the transport.
ready!(self.channel().poll_flush(cx)?);
// Being here means there are no staged requests and all written responses are
// fully flushed. So, if the read half is closed and there are no in-flight
// requests, then we can close the write half.
if read_half_closed && self.in_flight_requests().is_empty() {
Poll::Ready(None)
} else {
Poll::Pending
}
}
}
}
fn poll_next_response(
self: &mut Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Option<io::Result<(Context, Response<Resp>)>>> {
// Ensure there's room to write a response.
while let Poll::Pending = self.channel().poll_ready(cx)? {
ready!(self.channel().poll_flush(cx)?);
}
let peer = self.channel().client_addr;
match ready!(self.pending_responses().poll_next(cx)) {
Some((ctx, response)) => {
if let Some(_) = self.in_flight_requests().remove(&response.request_id) {
self.in_flight_requests().compact(0.1);
}
trace!(
"[{}/{}] Staging response. In-flight requests = {}.",
ctx.trace_id(),
peer,
self.in_flight_requests().len(),
);
return Poll::Ready(Some(Ok((ctx, response))));
}
None => {
// This branch likely won't happen, since the ClientHandler is holding a Sender.
trace!("[{}] No new responses.", peer);
Poll::Ready(None)
}
}
}
fn handle_request(
self: &mut Pin<&mut Self>,
trace_context: trace::Context,
request: Request<Req>,
) -> io::Result<()> {
let request_id = request.id;
let peer = self.channel().client_addr;
let ctx = Context {
deadline: request.deadline,
trace_context,
};
let request = request.message;
if self.in_flight_requests().len()
>= self.channel().config.max_in_flight_requests_per_connection
{
debug!(
"[{}/{}] Client has reached in-flight request limit ({}/{}).",
ctx.trace_id(),
peer,
self.in_flight_requests().len(),
self.channel().config.max_in_flight_requests_per_connection
);
self.channel().start_send(Response {
request_id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
}),
})?;
return Ok(());
}
let deadline = ctx.deadline;
let timeout = deadline.as_duration();
trace!(
"[{}/{}] Received request with deadline {} (timeout {:?}).",
ctx.trace_id(),
peer,
format_rfc3339(deadline),
timeout,
);
let mut response_tx = self.responses_tx().clone();
let trace_id = *ctx.trace_id();
let response = self.f()(ctx.clone(), request);
let response = deadline_compat::Deadline::new(response, Instant::now() + timeout).then(
async move |result| {
let response = Response {
request_id,
message: match result {
Ok(message) => Ok(message),
Err(e) => Err(make_server_error(e, trace_id, peer, deadline)),
},
};
trace!("[{}/{}] Sending response.", trace_id, peer);
await!(response_tx.send((ctx, response)).unwrap_or_else(|_| ()));
},
);
let (abortable_response, abort_handle) = abortable(response);
crate::spawn(abortable_response.map(|_| ()))
.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
"Could not spawn response task. Is shutdown: {}",
e.is_shutdown()
),
)
})?;
self.in_flight_requests().insert(request_id, abort_handle);
Ok(())
}
fn cancel_request(self: &mut Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self.in_flight_requests().remove(&request_id) {
self.in_flight_requests().compact(0.1);
cancel_handle.abort();
let remaining = self.in_flight_requests().len();
trace!(
"[{}/{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
self.channel.client_addr,
remaining,
);
} else {
trace!(
"[{}/{}] Received cancellation, but response handler \
is already complete.",
trace_context.trace_id,
self.channel.client_addr
);
}
}
}
impl<Req, Resp, T, F, Fut> Future for ClientHandler<Req, Resp, T, F>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
F: FnMut(Context, Req) -> Fut + Send + 'static,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
{
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<io::Result<()>> {
trace!("[{}] ClientHandler::poll", self.channel.client_addr);
loop {
let read = self.pump_read(cx)?;
match (read, self.pump_write(cx, read == Poll::Ready(None))?) {
(Poll::Ready(None), Poll::Ready(None)) => {
info!("[{}] Client disconnected.", self.channel.client_addr);
return Poll::Ready(Ok(()));
}
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
trace!(
"[{}] read: {:?}, write: {:?}.",
self.channel.client_addr,
read,
write
)
}
(read, write) => {
trace!(
"[{}] read: {:?}, write: {:?} (not ready).",
self.channel.client_addr,
read,
write,
);
return Poll::Pending;
}
}
}
}
}
fn make_server_error(
e: timeout::Error<io::Error>,
trace_id: TraceId,
peer: SocketAddr,
deadline: SystemTime,
) -> ServerError {
if e.is_elapsed() {
debug!(
"[{}/{}] Response did not complete before deadline of {}s.",
trace_id,
peer,
format_rfc3339(deadline)
);
// No point in responding, since the client will have dropped the request.
ServerError {
kind: io::ErrorKind::TimedOut,
detail: Some(format!(
"Response did not complete before deadline of {}s.",
format_rfc3339(deadline)
)),
}
} else if e.is_timer() {
error!(
"[{}/{}] Response failed because of an issue with a timer: {}",
trace_id, peer, e
);
ServerError {
kind: io::ErrorKind::Other,
detail: Some(format!("{}", e)),
}
} else if e.is_inner() {
let e = e.into_inner().unwrap();
ServerError {
kind: e.kind(),
detail: Some(e.description().into()),
}
} else {
error!("[{}/{}] Unexpected response failure: {}", trace_id, peer, e);
ServerError {
kind: io::ErrorKind::Other,
detail: Some(format!("Server unexpectedly failed to respond: {}", e)),
}
}
}


@@ -0,0 +1,151 @@
//! Transports backed by in-memory channels.
use crate::Transport;
use futures::{channel::mpsc, task::{LocalWaker}, Poll, Sink, Stream};
use pin_utils::unsafe_pinned;
use std::pin::Pin;
use std::{
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
};
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
/// [`Sink`].
pub fn unbounded<SinkItem, Item>() -> (
UnboundedChannel<SinkItem, Item>,
UnboundedChannel<Item, SinkItem>,
) {
let (tx1, rx2) = mpsc::unbounded();
let (tx2, rx1) = mpsc::unbounded();
(
UnboundedChannel { tx: tx1, rx: rx1 },
UnboundedChannel { tx: tx2, rx: rx2 },
)
}
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
#[derive(Debug)]
pub struct UnboundedChannel<Item, SinkItem> {
rx: mpsc::UnboundedReceiver<Item>,
tx: mpsc::UnboundedSender<SinkItem>,
}
impl<Item, SinkItem> UnboundedChannel<Item, SinkItem> {
unsafe_pinned!(rx: mpsc::UnboundedReceiver<Item>);
unsafe_pinned!(tx: mpsc::UnboundedSender<SinkItem>);
}
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
type Item = Result<Item, io::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<Option<io::Result<Item>>> {
self.rx().poll_next(cx).map(|option| option.map(Ok))
}
}
impl<Item, SinkItem> Sink for UnboundedChannel<Item, SinkItem> {
type SinkItem = SinkItem;
type SinkError = io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<io::Result<()>> {
self.tx()
.poll_ready(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn start_send(mut self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.tx()
.start_send(item)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &LocalWaker,
) -> Poll<Result<(), Self::SinkError>> {
self.tx()
.poll_flush(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<io::Result<()>> {
self.tx()
.poll_close(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
}
impl<Item, SinkItem> Transport for UnboundedChannel<Item, SinkItem> {
type Item = Item;
type SinkItem = SinkItem;
fn peer_addr(&self) -> io::Result<SocketAddr> {
Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
}
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
}
}
#[cfg(test)]
mod tests {
use crate::{client::{self, Client}, context, server::{self, Handler, Server}, transport};
use futures::{prelude::*, stream, compat::TokioDefaultSpawner};
use log::trace;
use std::io;
#[test]
fn integration() {
let _ = env_logger::try_init();
crate::init(TokioDefaultSpawner);
let (client_channel, server_channel) = transport::channel::unbounded();
let server = Server::<String, u64>::new(server::Config::default())
.incoming(stream::once(future::ready(Ok(server_channel))))
.respond_with(|_ctx, request| {
future::ready(request.parse::<u64>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{:?} is not an int", request),
)
}))
});
let responses = async {
let mut client = await!(Client::new(client::Config::default(), client_channel))?;
let response1 = await!(client.call(context::current(), "123".into()));
let response2 = await!(client.call(context::current(), "abc".into()));
Ok::<_, io::Error>((response1, response2))
};
let (response1, response2) =
run_future(server.join(responses.unwrap_or_else(|e| panic!(e)))).1;
trace!("response1: {:?}, response2: {:?}", response1, response2);
assert!(response1.is_ok());
assert_eq!(response1.ok().unwrap(), 123);
assert!(response2.is_err());
assert_eq!(response2.err().unwrap().kind(), io::ErrorKind::InvalidInput);
}
fn run_future<F>(f: F) -> F::Output
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = futures::channel::oneshot::channel();
tokio::run(
f.map(|result| tx.send(result).unwrap_or_else(|_| unreachable!()))
.boxed()
.unit_error()
.compat(),
);
futures::executor::block_on(rx).unwrap()
}
}
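The two `UnboundedChannel` halves returned by `unbounded` are mirror images: whatever one half sinks, the other half streams. A minimal sketch of that round trip, assuming the same nightly `futures-preview`/`await!` setup as the test above and with `unbounded` in scope (the function name `ping_pong` is illustrative, not part of the crate):

use futures::prelude::*;
use std::io;

async fn ping_pong() -> io::Result<()> {
    // Both peers sink and stream `String`s here to keep the two type parameters symmetric.
    let (mut left, mut right) = unbounded::<String, String>();
    await!(left.send("ping".to_string()))?;
    // The item sent into `left`'s Sink comes out of `right`'s Stream.
    let echoed = await!(right.next()).expect("channel still open")?;
    assert_eq!(echoed, "ping");
    Ok(())
}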

26
rpc/src/transport/mod.rs Normal file
View File

@@ -0,0 +1,26 @@
//! Provides a [`Transport`] trait as well as implementations.
//!
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`]
//! can be plugged in, using whatever protocol it wants.
use futures::prelude::*;
use std::{io, net::SocketAddr};
pub mod channel;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport
where
Self: Stream<Item = io::Result<<Self as Transport>::Item>>,
Self: Sink<SinkItem = <Self as Transport>::SinkItem, SinkError = io::Error>,
{
/// The type read off the transport.
type Item;
/// The type written to the transport.
type SinkItem;
/// The address of the remote peer this transport is in communication with.
fn peer_addr(&self) -> io::Result<SocketAddr>;
/// The address of the local half of this transport.
fn local_addr(&self) -> io::Result<SocketAddr>;
}
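Because the trait only demands `Stream` + `Sink` plus the two address accessors, downstream code can stay transport-agnostic by bounding on `Transport` rather than on a concrete channel or socket type. A hedged sketch, assuming `Transport` is in scope (the helper name `log_peer` is illustrative, not part of the crate):

use std::io;

fn log_peer<T: Transport>(transport: &T) -> io::Result<()> {
    // Works identically for an in-memory channel, a TCP-backed transport, etc.
    println!(
        "transport connects {} -> {}",
        transport.local_addr()?,
        transport.peer_addr()?
    );
    Ok(())
}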

63
rpc/src/util/deadline_compat.rs Normal file
View File

@@ -0,0 +1,63 @@
use futures::{
compat::{Compat01As03, Future01CompatExt},
prelude::*,
ready, task::{Poll, LocalWaker},
};
use pin_utils::unsafe_pinned;
use std::pin::Pin;
use std::time::Instant;
use tokio_timer::{timeout, Delay};
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct Deadline<T> {
future: T,
delay: Compat01As03<Delay>,
}
impl<T> Deadline<T> {
unsafe_pinned!(future: T);
unsafe_pinned!(delay: Compat01As03<Delay>);
/// Create a new `Deadline` that completes when `future` completes or when
/// `deadline` is reached.
pub fn new(future: T, deadline: Instant) -> Deadline<T> {
Deadline::new_with_delay(future, Delay::new(deadline))
}
pub(crate) fn new_with_delay(future: T, delay: Delay) -> Deadline<T> {
Deadline {
future,
delay: delay.compat(),
}
}
/// Gets a mutable reference to the underlying future in this deadline.
pub fn get_mut(&mut self) -> &mut T {
&mut self.future
}
}
impl<T> Future for Deadline<T>
where
T: TryFuture,
{
type Output = Result<T::Ok, timeout::Error<T::Error>>;
fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll<Self::Output> {
// First, try polling the future
match self.future().try_poll(waker) {
Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)),
Poll::Pending => {}
Poll::Ready(Err(e)) => return Poll::Ready(Err(timeout::Error::inner(e))),
}
let delay = self.delay().poll_unpin(waker);
// Now check the timer
match ready!(delay) {
Ok(_) => Poll::Ready(Err(timeout::Error::elapsed())),
Err(e) => Poll::Ready(Err(timeout::Error::timer(e))),
}
}
}
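`Deadline` races the wrapped `TryFuture` against a `tokio_timer::Delay`, so the elapsed/timer/inner variants of `timeout::Error` handled by `make_server_error` earlier all originate here. A minimal sketch of wrapping a future with a one-second deadline, assuming a tokio 0.1 timer is running as in the crate's tests (the helper name `within_one_second` is illustrative):

use std::time::{Duration, Instant};

async fn within_one_second<F>(f: F) -> Result<F::Ok, tokio_timer::timeout::Error<F::Error>>
where
    F: TryFuture,
{
    // Resolves with Ok(_) if `f` finishes first, Err(elapsed) if the timer fires first,
    // and Err(timer) if the timer itself fails.
    await!(Deadline::new(f, Instant::now() + Duration::from_secs(1)))
}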

40
rpc/src/util/mod.rs Normal file
View File

@@ -0,0 +1,40 @@
use std::{
collections::HashMap,
hash::{BuildHasher, Hash},
time::{Duration, SystemTime},
};
pub mod deadline_compat;
#[cfg(feature = "serde")]
pub mod serde;
/// Types that can be represented by a [`Duration`].
pub trait AsDuration {
fn as_duration(&self) -> Duration;
}
impl AsDuration for SystemTime {
/// Returns a [`Duration`] of 0 if `self` is earlier than [`SystemTime::now`].
fn as_duration(&self) -> Duration {
self.duration_since(SystemTime::now()).unwrap_or_default()
}
}
/// Collection compaction; configurable `shrink_to_fit`.
pub trait Compact {
/// Compacts space if the ratio of length : capacity is less than `usage_ratio_threshold`.
fn compact(&mut self, usage_ratio_threshold: f64);
}
impl<K, V, H> Compact for HashMap<K, V, H>
where
K: Eq + Hash,
H: BuildHasher,
{
fn compact(&mut self, usage_ratio_threshold: f64) {
let usage_ratio = self.len() as f64 / self.capacity() as f64;
if usage_ratio < usage_ratio_threshold {
self.shrink_to_fit();
}
}
}
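These two helpers back the deadline and in-flight-request bookkeeping elsewhere in the crate: `AsDuration` turns an absolute `SystemTime` deadline into time remaining from now, and `Compact` releases a map's excess capacity once most entries have drained. A hedged sketch with illustrative values, assuming both traits are in scope:

use std::{collections::HashMap, time::{Duration, SystemTime}};

fn demo() {
    // Time remaining until a deadline 10s out; a deadline already in the past yields 0.
    let deadline = SystemTime::now() + Duration::from_secs(10);
    let remaining: Duration = deadline.as_duration();
    assert!(remaining <= Duration::from_secs(10));

    // Shrink the map's allocation once fewer than 10% of its slots are in use.
    let mut in_flight: HashMap<u64, &str> = HashMap::with_capacity(1024);
    in_flight.insert(1, "request");
    in_flight.compact(0.1); // len / capacity < 0.1 here, so this calls shrink_to_fit()
}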

88
rpc/src/util/serde.rs Normal file
View File

@@ -0,0 +1,88 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::{
io,
time::{Duration, SystemTime},
};
/// Serializes `system_time` as a `u64` equal to the number of seconds since the epoch.
pub fn serialize_epoch_secs<S>(system_time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
system_time
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_secs() // Only care about second precision
.serialize(serializer)
}
/// Deserializes [`SystemTime`] from a `u64` equal to the number of seconds since the epoch.
pub fn deserialize_epoch_secs<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
where
D: Deserializer<'de>,
{
Ok(SystemTime::UNIX_EPOCH + Duration::from_secs(u64::deserialize(deserializer)?))
}
/// Serializes [`io::ErrorKind`] as a `u32`.
pub fn serialize_io_error_kind_as_u32<S>(
kind: &io::ErrorKind,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use std::io::ErrorKind::*;
match *kind {
NotFound => 0,
PermissionDenied => 1,
ConnectionRefused => 2,
ConnectionReset => 3,
ConnectionAborted => 4,
NotConnected => 5,
AddrInUse => 6,
AddrNotAvailable => 7,
BrokenPipe => 8,
AlreadyExists => 9,
WouldBlock => 10,
InvalidInput => 11,
InvalidData => 12,
TimedOut => 13,
WriteZero => 14,
Interrupted => 15,
Other => 16,
UnexpectedEof => 17,
_ => 16,
}.serialize(serializer)
}
/// Deserializes [`io::ErrorKind`] from a `u32`.
pub fn deserialize_io_error_kind_from_u32<'de, D>(
deserializer: D,
) -> Result<io::ErrorKind, D::Error>
where
D: Deserializer<'de>,
{
use std::io::ErrorKind::*;
Ok(match u32::deserialize(deserializer)? {
0 => NotFound,
1 => PermissionDenied,
2 => ConnectionRefused,
3 => ConnectionReset,
4 => ConnectionAborted,
5 => NotConnected,
6 => AddrInUse,
7 => AddrNotAvailable,
8 => BrokenPipe,
9 => AlreadyExists,
10 => WouldBlock,
11 => InvalidInput,
12 => InvalidData,
13 => TimedOut,
14 => WriteZero,
15 => Interrupted,
16 => Other,
17 => UnexpectedEof,
_ => Other,
})
}
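These free functions are meant to be wired up through serde's `serialize_with`/`deserialize_with` field attributes, so the wire format only carries a `u64` of epoch seconds and a `u32` error-kind tag. A hedged sketch of how a message type might use them, assuming the helpers above are in scope of the module defining the struct (`ResponseMeta` is an illustrative name, not a type from this crate):

use serde::{Deserialize, Serialize};
use std::{io, time::SystemTime};

#[derive(Serialize, Deserialize)]
struct ResponseMeta {
    // Serialized as whole seconds since the Unix epoch.
    #[serde(
        serialize_with = "serialize_epoch_secs",
        deserialize_with = "deserialize_epoch_secs"
    )]
    deadline: SystemTime,
    // Serialized as the u32 tag defined by the match above.
    #[serde(
        serialize_with = "serialize_io_error_kind_as_u32",
        deserialize_with = "deserialize_io_error_kind_from_u32"
    )]
    error_kind: io::ErrorKind,
}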