Remove remaining feature flags

This commit is contained in:
Artem Vorotnikov
2019-10-01 16:33:14 +03:00
committed by Tim
parent 46bcc0f559
commit e91005855c
12 changed files with 114 additions and 56 deletions

View File

@@ -1,6 +1,6 @@
language: rust
rust:
- nightly
- beta
sudo: false
cache: cargo

View File

@@ -23,6 +23,7 @@ futures-preview = { version = "0.3.0-alpha.18" }
humantime = "1.0"
log = "0.4"
pin-utils = "0.1.0-alpha.4"
raii-counter = "0.2"
rand = "0.7"
tokio-timer = "0.3.0-alpha.4"
trace = { package = "tarpc-trace", version = "0.2", path = "../trace" }

View File

@@ -396,7 +396,9 @@ where
context: context::Context {
deadline: dispatch_request.ctx.deadline,
trace_context: dispatch_request.ctx.trace_context,
_non_exhaustive: (),
},
_non_exhaustive: (),
});
self.as_mut().transport().start_send(request)?;
self.as_mut().in_flight_requests().insert(
@@ -798,6 +800,7 @@ mod tests {
Response {
request_id: 0,
message: Ok("hello".into()),
_non_exhaustive: (),
},
);
block_on(dispatch).unwrap();

View File

@@ -103,7 +103,6 @@ where
}
/// Settings that control the behavior of the client.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
/// The number of requests that can be in flight at once.
@@ -114,6 +113,8 @@ pub struct Config {
/// `pending_requests_buffer` controls the size of the channel clients use
/// to communicate with the request dispatch task.
pub pending_request_buffer: usize,
#[doc(hidden)]
_non_exhaustive: (),
}
impl Default for Config {
@@ -121,6 +122,7 @@ impl Default for Config {
Config {
max_in_flight_requests: 1_000,
pending_request_buffer: 100,
_non_exhaustive: (),
}
}
}

View File

@@ -17,7 +17,6 @@ use trace::{self, TraceId};
/// be different for each request in scope.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Context {
/// When the client expects the request to be complete by. The server should cancel the request
/// if it is not complete by this time.
@@ -36,6 +35,8 @@ pub struct Context {
/// include the same `trace_id` as that included on the original request. This way,
/// users can trace related actions across a distributed system.
pub trace_context: trace::Context,
#[doc(hidden)]
pub(crate) _non_exhaustive: (),
}
#[cfg(feature = "serde1")]
@@ -49,6 +50,7 @@ pub fn current() -> Context {
Context {
deadline: SystemTime::now() + Duration::from_secs(10),
trace_context: trace::Context::new_root(),
_non_exhaustive: (),
}
}

View File

@@ -4,7 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(weak_counts, non_exhaustive, trait_alias)]
#![deny(missing_docs, missing_debug_implementations)]
//! An RPC framework providing client and server.
@@ -31,7 +30,7 @@ pub mod server;
pub mod transport;
pub(crate) mod util;
pub use crate::{client::Client, server::Server, transport::Transport};
pub use crate::{client::Client, server::Server, transport::sealed::Transport};
use futures::task::Poll;
use std::{io, time::SystemTime};
@@ -39,7 +38,6 @@ use std::{io, time::SystemTime};
/// A message from a client to a server.
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ClientMessage<T> {
/// A request initiated by a user. The server responds to a request by invoking a
/// service-provided request handler. The handler completes with a [`response`](Response), which
@@ -60,12 +58,13 @@ pub enum ClientMessage<T> {
/// The ID of the request to cancel.
request_id: u64,
},
#[doc(hidden)]
_NonExhaustive,
}
/// A request from a client to a server.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Request<T> {
/// Trace context, deadline, and other cross-cutting concerns.
pub context: context::Context,
@@ -73,23 +72,25 @@ pub struct Request<T> {
pub id: u64,
/// The request body.
pub message: T,
#[doc(hidden)]
_non_exhaustive: (),
}
/// A response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Response<T> {
/// The ID of the request being responded to.
pub request_id: u64,
/// The response body, or an error if the request failed.
pub message: Result<T, ServerError>,
#[doc(hidden)]
_non_exhaustive: (),
}
/// An error response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct ServerError {
#[cfg_attr(
feature = "serde1",
@@ -103,6 +104,8 @@ pub struct ServerError {
pub kind: io::ErrorKind,
/// A message describing more detail about the error that occurred.
pub detail: Option<String>,
#[doc(hidden)]
_non_exhaustive: (),
}
impl From<ServerError> for io::Error {

View File

@@ -19,6 +19,7 @@ use futures::{
};
use log::{debug, info, trace};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use raii_counter::{Counter, WeakCounter};
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
@@ -34,7 +35,7 @@ where
channels_per_key: u32,
dropped_keys: mpsc::UnboundedReceiver<K>,
dropped_keys_tx: mpsc::UnboundedSender<K>,
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
key_counts: FnvHashMap<K, TrackerPrototype<K>>,
keymaker: F,
}
@@ -42,26 +43,41 @@ where
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
inner: C,
tracker: Arc<Tracker<K>>,
tracker: Tracker<K>,
}
impl<C, K> TrackedChannel<C, K> {
unsafe_pinned!(inner: C);
}
#[derive(Debug)]
#[derive(Clone, Debug)]
struct Tracker<K> {
key: Option<K>,
key: Option<Arc<K>>,
counter: Counter,
dropped_keys: mpsc::UnboundedSender<K>,
}
impl<K> Drop for Tracker<K> {
fn drop(&mut self) {
// Don't care if the listener is dropped.
let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
if self.counter.count() <= 1 {
// Don't care if the listener is dropped.
match Arc::try_unwrap(self.key.take().unwrap()) {
Ok(key) => {
let _ = self.dropped_keys.unbounded_send(key);
}
_ => unreachable!(),
}
}
}
}
#[derive(Clone, Debug)]
struct TrackerPrototype<K> {
key: Weak<K>,
counter: WeakCounter,
dropped_keys: mpsc::UnboundedSender<K>,
}
impl<C, K> Stream for TrackedChannel<C, K>
where
C: Stream,
@@ -141,7 +157,7 @@ where
unsafe_pinned!(listener: Fuse<S>);
unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver<K>);
unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender<K>);
unsafe_unpinned!(key_counts: FnvHashMap<K, Weak<Tracker<K>>>);
unsafe_unpinned!(key_counts: FnvHashMap<K, TrackerPrototype<K>>);
unsafe_unpinned!(channels_per_key: u32);
unsafe_unpinned!(keymaker: F);
}
@@ -182,7 +198,7 @@ where
trace!(
"[{}] Opening channel ({}/{}) channels for key.",
key,
Arc::strong_count(&tracker),
tracker.counter.count(),
self.as_mut().channels_per_key()
);
@@ -192,22 +208,28 @@ where
})
}
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Tracker<K>, K> {
let channels_per_key = self.channels_per_key;
let dropped_keys = self.dropped_keys_tx.clone();
let key_counts = &mut self.as_mut().key_counts();
match key_counts.entry(key.clone()) {
Entry::Vacant(vacant) => {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
});
let key = Arc::new(key);
let counter = WeakCounter::new();
vacant.insert(Arc::downgrade(&tracker));
Ok(tracker)
vacant.insert(TrackerPrototype {
key: Arc::downgrade(&key),
counter: counter.clone(),
dropped_keys: dropped_keys.clone(),
});
Ok(Tracker {
key: Some(key),
counter: counter.upgrade(),
dropped_keys,
})
}
Entry::Occupied(mut o) => {
let count = o.get().strong_count();
Entry::Occupied(o) => {
let count = o.get().counter.count();
if count >= channels_per_key.try_into().unwrap() {
info!(
"[{}] Opened max channels from key ({}/{}).",
@@ -215,15 +237,16 @@ where
);
Err(key)
} else {
Ok(o.get().upgrade().unwrap_or_else(|| {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
});
*o.get_mut() = Arc::downgrade(&tracker);
tracker
}))
let TrackerPrototype {
key,
counter,
dropped_keys,
} = o.get().clone();
Ok(Tracker {
counter: counter.upgrade(),
key: Some(key.upgrade().unwrap()),
dropped_keys,
})
}
}
}
@@ -296,10 +319,12 @@ fn ctx() -> Context<'static> {
#[test]
fn tracker_drop() {
use assert_matches::assert_matches;
use raii_counter::Counter;
let (tx, mut rx) = mpsc::unbounded();
Tracker {
key: Some(1),
key: Some(Arc::new(1)),
counter: Counter::new(),
dropped_keys: tx,
};
assert_matches!(rx.try_next(), Ok(Some(1)));
@@ -309,15 +334,17 @@ fn tracker_drop() {
fn tracked_channel_stream() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
use raii_counter::Counter;
let (chan_tx, chan) = mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded();
let channel = TrackedChannel {
inner: chan,
tracker: Arc::new(Tracker {
key: Some(1),
tracker: Tracker {
key: Some(Arc::new(1)),
counter: Counter::new(),
dropped_keys,
}),
},
};
chan_tx.unbounded_send("test").unwrap();
@@ -329,15 +356,17 @@ fn tracked_channel_stream() {
fn tracked_channel_sink() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
use raii_counter::Counter;
let (chan, mut chan_rx) = mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded();
let channel = TrackedChannel {
inner: chan,
tracker: Arc::new(Tracker {
key: Some(1),
tracker: Tracker {
key: Some(Arc::new(1)),
counter: Counter::new(),
dropped_keys,
}),
},
};
pin_mut!(channel);
@@ -359,12 +388,12 @@ fn channel_filter_increment_channels_for_key() {
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 1);
assert_eq!(tracker1.counter.count(), 1);
let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 2);
assert_eq!(tracker1.counter.count(), 2);
assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
drop(tracker2);
assert_eq!(Arc::strong_count(&tracker1), 1);
assert_eq!(tracker1.counter.count(), 1);
}
#[test]
@@ -383,20 +412,20 @@ fn channel_filter_handle_new_channel() {
.as_mut()
.handle_new_channel(TestChannel { key: "key" })
.unwrap();
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
assert_eq!(channel1.tracker.counter.count(), 1);
let channel2 = filter
.as_mut()
.handle_new_channel(TestChannel { key: "key" })
.unwrap();
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
assert_eq!(channel1.tracker.counter.count(), 2);
assert_matches!(
filter.handle_new_channel(TestChannel { key: "key" }),
Err("key")
);
drop(channel2);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
assert_eq!(channel1.tracker.counter.count(), 1);
}
#[test]
@@ -417,14 +446,14 @@ fn channel_filter_poll_listener() {
.unwrap();
let channel1 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
assert_eq!(channel1.tracker.counter.count(), 1);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let _channel2 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
assert_eq!(channel1.tracker.counter.count(), 2);
new_channels
.unbounded_send(TestChannel { key: "key" })
@@ -432,7 +461,7 @@ fn channel_filter_poll_listener() {
let key =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
assert_eq!(key, "key");
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
assert_eq!(channel1.tracker.counter.count(), 2);
}
#[test]

View File

@@ -49,7 +49,6 @@ impl<Req, Resp> Default for Server<Req, Resp> {
}
/// Settings that control the behavior of the server.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
/// The number of responses per client that can be buffered server-side before being sent.
@@ -307,6 +306,7 @@ where
} => {
self.as_mut().cancel_request(&trace_context, request_id);
}
ClientMessage::_NonExhaustive => unreachable!(),
},
None => return Poll::Ready(None),
}
@@ -582,9 +582,11 @@ where
"Response did not complete before deadline of {}s.",
format_rfc3339(self.deadline)
)),
_non_exhaustive: (),
})
}
},
_non_exhaustive: (),
});
*self.as_mut().state() = RespState::PollReady;
}

View File

@@ -85,9 +85,11 @@ impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
_non_exhaustive: (),
},
id,
message,
_non_exhaustive: (),
}));
}
}

View File

@@ -66,7 +66,9 @@ where
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
_non_exhaustive: (),
}),
_non_exhaustive: (),
})?;
}
None => return Poll::Ready(None),
@@ -315,6 +317,7 @@ fn throttler_start_send() {
.start_send(Response {
request_id: 0,
message: Ok(1),
_non_exhaustive: (),
})
.unwrap();
assert!(throttler.inner.in_flight_requests.is_empty());
@@ -323,6 +326,7 @@ fn throttler_start_send() {
Some(&Response {
request_id: 0,
message: Ok(1),
_non_exhaustive: ()
})
);
}

View File

@@ -14,6 +14,17 @@ use std::io;
pub mod channel;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport<SinkItem, Item> =
Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>;
pub(crate) mod sealed {
use super::*;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport<SinkItem, Item>:
Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>
{
}
impl<T, SinkItem, Item> Transport<SinkItem, Item> for T where
T: Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error> + ?Sized
{
}
}

View File

@@ -203,7 +203,6 @@
//! items expanded by a `service!` invocation.
#![deny(missing_docs, missing_debug_implementations)]
#![feature(external_doc)]
pub use rpc::*;
/// The main macro that creates RPC services.