Don't bake in Send + 'static.

Send + 'static was baked in to make it possible to spawn futures onto
the default executor. We can accomplish the same thing by offering
helper fns that do the spawning while not requiring it for the rest of
the functionality.

Fixes https://github.com/google/tarpc/issues/212
This commit is contained in:
Tim Kuehn
2019-07-22 13:13:08 -07:00
committed by Tim
parent 13cb14a119
commit 50879d2acb
15 changed files with 428 additions and 210 deletions

View File

@@ -182,7 +182,7 @@ async fn main() -> io::Result<()> {
// WorldClient is generated by the macro. It has a constructor `new` that takes a config and
// any Transport as input
let mut client = WorldClient::new(client::Config::default(), client_transport).await?;
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context

View File

@@ -46,7 +46,7 @@ async fn main() -> io::Result<()> {
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
// config and any Transport as input.
let mut client = service::WorldClient::new(client::Config::default(), transport).await?;
let mut client = service::WorldClient::new(client::Config::default(), transport).spawn()?;
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context

View File

@@ -75,11 +75,11 @@ async fn main() -> io::Result<()> {
// the generated World trait.
.map(|channel| {
let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
channel.respond_with(server.serve())
channel.respond_with(server.serve()).execute()
})
// Max 10 channels.
.buffer_unordered(10)
.for_each(|_| futures::future::ready(()))
.for_each(|_| async {})
.await;
Ok(())

View File

@@ -21,7 +21,7 @@ use syn::{
punctuated::Punctuated,
spanned::Spanned,
token::Comma,
ArgCaptured, Attribute, FnArg, Ident, Pat, ReturnType, Token, Visibility,
ArgCaptured, Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, ReturnType, Token, Visibility,
};
struct Service {
@@ -126,6 +126,40 @@ impl Parse for RpcMethod {
}
}
// If `derive_serde` meta item is not present, defaults to cfg!(feature = "serde1").
// `derive_serde` can only be true when serde1 is enabled.
struct DeriveSerde(bool);
impl Parse for DeriveSerde {
fn parse(input: ParseStream) -> syn::Result<Self> {
if input.is_empty() {
return Ok(DeriveSerde(cfg!(feature = "serde1")))
}
match input.parse::<MetaNameValue>()? {
MetaNameValue { ref ident, ref lit, .. } if ident == "derive_serde" => {
match lit {
Lit::Bool(LitBool{value: true, ..}) if cfg!(feature = "serde1") => Ok(DeriveSerde(true)),
Lit::Bool(LitBool{value: true, ..}) => Err(syn::Error::new(
lit.span(),
"To enable serde, first enable the `serde1` feature of tarpc",
)),
Lit::Bool(LitBool{value: false, ..}) => Ok(DeriveSerde(false)),
lit => Err(syn::Error::new(
lit.span(),
"`derive_serde` expects a value of type `bool`",
)),
}
}
MetaNameValue { ident, .. } => {
Err(syn::Error::new(
ident.span(),
"tarpc::service only supports one meta item, `derive_serde = {bool}`",
))
}
}
}
}
/// Generates:
/// - service trait
/// - serve fn
@@ -135,13 +169,7 @@ impl Parse for RpcMethod {
/// - ResponseFut Future
#[proc_macro_attribute]
pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
struct EmptyArgs;
impl Parse for EmptyArgs {
fn parse(_: ParseStream) -> syn::Result<Self> {
Ok(EmptyArgs)
}
}
parse_macro_input!(attr as EmptyArgs);
let derive_serde = parse_macro_input!(attr as DeriveSerde);
let Service {
attrs,
@@ -223,14 +251,15 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
let response_fut_ident_repeated2 = response_fut_ident_repeated.clone();
let server_ident = Ident::new(&format!("Serve{}", ident), ident.span());
#[cfg(feature = "serde1")]
let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]);
#[cfg(not(feature = "serde1"))]
let derive_serialize = quote!();
let derive_serialize = if derive_serde.0 {
quote!(#[derive(serde::Serialize, serde::Deserialize)])
} else {
quote!()
};
let tokens = quote! {
#( #attrs )*
#vis trait #ident: Clone + Send + 'static {
#vis trait #ident: Clone {
#( #types_and_fns )*
/// Returns a serving function to use with tarpc::server::Server.
@@ -322,12 +351,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
impl #client_ident {
/// Returns a new client stub that sends requests over the given transport.
#vis async fn new<T>(config: tarpc::client::Config, transport: T)
-> std::io::Result<Self>
#vis fn new<T>(config: tarpc::client::Config, transport: T)
-> tarpc::client::NewClient<
Self,
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>>
where
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>> + Send + 'static
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
{
Ok(#client_ident(tarpc::client::new(config, transport).await?))
let new_client = tarpc::client::new(config, transport);
tarpc::client::NewClient {
client: #client_ident(new_client.client),
dispatch: new_client.dispatch,
}
}
}

View File

@@ -19,11 +19,11 @@ use futures::{
Poll,
};
use humantime::format_rfc3339;
use log::{debug, error, info, trace};
use log::{debug, info, trace};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{
io,
marker::{self, Unpin},
marker::Unpin,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
@@ -33,7 +33,7 @@ use std::{
};
use trace::SpanId;
use super::Config;
use super::{Config, NewClient};
/// Handles communication from the client to request dispatch.
#[derive(Debug)]
@@ -246,48 +246,39 @@ impl<Resp> Drop for DispatchResponse<Resp> {
}
}
/// Spawns a dispatch task on the default executor that manages the lifecycle of requests initiated
/// by the returned [`Channel`].
pub async fn spawn<Req, Resp, C>(config: Config, transport: C) -> io::Result<Channel<Req, Resp>>
/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
/// channel.
pub fn new<Req, Resp, C>(
config: Config,
transport: C,
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
where
Req: marker::Send + 'static,
Resp: marker::Send + 'static,
C: Transport<ClientMessage<Req>, Response<Resp>> + marker::Send + 'static,
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
let canceled_requests = canceled_requests.fuse();
crate::spawn(
RequestDispatch {
NewClient {
client: Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicU64::new(0)),
},
dispatch: RequestDispatch {
config,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
pending_requests: pending_requests.fuse(),
}
.unwrap_or_else(move |e| error!("Connection broken: {}", e)),
)
.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
"Could not spawn client dispatch task. Is shutdown: {}",
e.is_shutdown()
),
)
})?;
Ok(Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicU64::new(0)),
})
},
}
}
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
struct RequestDispatch<Req, Resp, C> {
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
/// Writes requests to the wire and reads responses off the wire.
transport: Fuse<C>,
/// Requests waiting to be written to the wire.
@@ -302,8 +293,6 @@ struct RequestDispatch<Req, Resp, C> {
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
Req: marker::Send,
Resp: marker::Send,
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>);
@@ -492,8 +481,6 @@ where
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
Req: marker::Send,
Resp: marker::Send,
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
type Output = io::Result<()>;
@@ -532,6 +519,7 @@ struct DispatchRequest<Req, Resp> {
response_completion: oneshot::Sender<Response<Resp>>,
}
#[derive(Debug)]
struct InFlightData<Resp> {
ctx: context::Context,
response_completion: oneshot::Sender<Response<Resp>>,
@@ -776,7 +764,7 @@ mod tests {
};
use futures_test::task::noop_waker_ref;
use std::time::Duration;
use std::{marker, pin::Pin, sync::atomic::AtomicU64, sync::Arc, time::Instant};
use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc, time::Instant};
#[test]
fn dispatch_response_cancels_on_timeout() {
@@ -955,7 +943,7 @@ mod tests {
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
where
E: ::std::fmt::Display + marker::Send + 'static,
E: ::std::fmt::Display,
{
type T = Option<T>;

View File

@@ -6,13 +6,14 @@
//! Provides a client that connects to a server and sends multiplexed requests.
use crate::{context, ClientMessage, Response, Transport};
use crate::context;
use futures::prelude::*;
use log::error;
use std::io;
/// Provides a [`Client`] backed by a transport.
pub mod channel;
pub use self::channel::Channel;
pub use channel::{new, Channel};
/// Sends multiplexed requests to, and receives responses from, a server.
pub trait Client<'a, Req> {
@@ -125,15 +126,34 @@ impl Default for Config {
}
}
/// Creates a new Client by wrapping a [`Transport`] and spawning a dispatch task
/// that manages the lifecycle of requests.
///
/// Must only be called from on an executor.
pub async fn new<Req, Resp, T>(config: Config, transport: T) -> io::Result<Channel<Req, Resp>>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<ClientMessage<Req>, Response<Resp>> + Send + 'static,
{
Ok(channel::spawn(config, transport).await?)
/// A channel and dispatch pair. The dispatch drives the sending and receiving of requests
/// and must be polled continuously or spawned.
#[derive(Debug)]
pub struct NewClient<C, D> {
/// The new client.
pub client: C,
/// The client's dispatch.
pub dispatch: D,
}
impl<C, D> NewClient<C, D>
where
D: Future<Output = io::Result<()>> + Send + 'static,
{
/// Helper method to spawn the dispatch on the default executor.
pub fn spawn(self) -> io::Result<C> {
let dispatch = self
.dispatch
.unwrap_or_else(move |e| error!("Connection broken: {}", e));
crate::spawn(dispatch).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
"Could not spawn client dispatch task. Is shutdown: {}",
e.is_shutdown()
),
)
})?;
Ok(self.client)
}
}

View File

@@ -78,7 +78,7 @@ impl Config {
/// Returns a channel backed by `transport` and configured with `self`.
pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
BaseChannel::new(self, transport)
}
@@ -101,49 +101,13 @@ impl<Req, Resp> Server<Req, Resp> {
/// Returns a stream of server channels.
pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
where
Req: Send,
Resp: Send,
S: Stream<Item = T>,
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
listener.map(move |t| BaseChannel::new(self.config.clone(), t))
}
}
/// The future driving the server.
#[derive(Debug)]
pub struct Running<St, Se> {
incoming: St,
server: Se,
}
impl<St, Se> Running<St, Se> {
unsafe_pinned!(incoming: St);
unsafe_unpinned!(server: Se);
}
impl<St, C, Se> Future for Running<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static,
Se::Fut: Send + 'static
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) {
if let Err(e) =
crate::spawn(channel.respond_with(self.as_mut().server().clone()))
{
warn!("Failed to spawn channel handler: {:?}", e);
}
}
info!("Server shutting down.");
Poll::Ready(())
}
}
/// Basically a Fn(Req) -> impl Future<Output = Resp>;
pub trait Serve<Req>: Sized + Clone {
/// Type of response.
@@ -191,8 +155,7 @@ where
/// Responds to all requests with `server`.
fn respond_with<S>(self, server: S) -> Running<Self, S>
where
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
S::Fut: Send + 'static,
S: Serve<C::Req, Resp = C::Resp>,
{
Running {
incoming: self,
@@ -226,7 +189,7 @@ impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
/// Creates a new channel backed by `transport` and configured with `config`.
pub fn new(config: Config, transport: T) -> Self {
@@ -288,10 +251,10 @@ where
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
{
/// Type of request item.
type Req: Send + 'static;
type Req;
/// Type of response sink item.
type Resp: Send + 'static;
type Resp;
/// Configuration of the channel.
fn config(&self) -> &Config;
@@ -314,16 +277,15 @@ where
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
/// responses and resolves when the connection is closed.
fn respond_with<S>(self, server: S) -> ResponseHandler<Self, S>
fn respond_with<S>(self, server: S) -> ClientHandler<Self, S>
where
S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
S::Fut: Send + 'static,
S: Serve<Self::Req, Resp = Self::Resp>,
Self: Sized,
{
let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
let responses = responses.fuse();
ResponseHandler {
ClientHandler {
channel: self,
server,
pending_responses: responses,
@@ -334,9 +296,7 @@ where
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
type Item = io::Result<Request<Req>>;
@@ -362,9 +322,7 @@ where
impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
type Error = io::Error;
@@ -402,9 +360,7 @@ impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
type Req = Req;
type Resp = Resp;
@@ -429,7 +385,7 @@ where
/// A running handler serving all requests coming over a channel.
#[derive(Debug)]
pub struct ResponseHandler<C, S>
pub struct ClientHandler<C, S>
where
C: Channel,
{
@@ -438,11 +394,11 @@ where
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
/// Handed out to request handlers to fan in responses.
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
/// Request handler.
/// Server
server: S,
}
impl<C, S> ResponseHandler<C, S>
impl<C, S> ClientHandler<C, S>
where
C: Channel,
{
@@ -450,22 +406,21 @@ where
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>);
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>);
// For this to be safe, field f must be private, and code in this module must never
// construct PinMut<F>.
// construct PinMut<S>.
unsafe_unpinned!(server: S);
}
impl<C, S> ResponseHandler<C, S>
impl<C, S> ClientHandler<C, S>
where
C: Channel,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
S::Fut: Send + 'static,
S: Serve<C::Req, Resp = C::Resp>,
{
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
fn pump_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
match ready!(self.as_mut().channel().poll_next(cx)?) {
Some(request) => {
self.handle_request(request)?;
Poll::Ready(Some(Ok(())))
}
Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
None => Poll::Ready(None),
}
}
@@ -518,13 +473,16 @@ where
match ready!(self.as_mut().pending_responses().poll_next(cx)) {
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
None => {
// This branch likely won't happen, since the ResponseHandler is holding a Sender.
// This branch likely won't happen, since the ClientHandler is holding a Sender.
Poll::Ready(None)
}
}
}
fn handle_request(mut self: Pin<&mut Self>, request: Request<C::Req>) -> io::Result<()> {
fn handle_request(
mut self: Pin<&mut Self>,
request: Request<C::Req>,
) -> RequestHandler<S::Fut, C::Resp> {
let request_id = request.id;
let deadline = request.context.deadline;
let timeout = deadline.as_duration();
@@ -536,70 +494,144 @@ where
);
let ctx = request.context;
let request = request.message;
let mut response_tx = self.as_mut().responses_tx().clone();
let trace_id = *ctx.trace_id();
let response = self.as_mut().server().clone().serve(ctx, request);
let response = deadline_compat::Deadline::new(response, Instant::now() + timeout).then(
move |result| {
async move {
let response = Response {
request_id,
message: match result {
Ok(message) => Ok(message),
Err(e) => Err(make_server_error(e, trace_id, deadline)),
},
};
trace!("[{}] Sending response.", trace_id);
response_tx
.send((ctx, response))
.unwrap_or_else(|_| ())
.await;
}
},
);
let response = Resp {
state: RespState::PollResp,
request_id,
ctx,
deadline,
f: deadline_compat::Deadline::new(response, Instant::now() + timeout),
response: None,
response_tx: self.as_mut().responses_tx().clone(),
};
let abort_registration = self.as_mut().channel().start_request(request_id);
let response = Abortable::new(response, abort_registration);
crate::spawn(response.map(|_| ())).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
"Could not spawn response task. Is shutdown: {}",
e.is_shutdown()
),
)
})?;
Ok(())
RequestHandler {
resp: Abortable::new(response, abort_registration),
}
}
}
impl<C, S> Future for ResponseHandler<C, S>
/// A future fulfilling a single client request.
#[derive(Debug)]
pub struct RequestHandler<F, R> {
resp: Abortable<Resp<F, R>>,
}
impl<F, R> RequestHandler<F, R> {
unsafe_pinned!(resp: Abortable<Resp<F, R>>);
}
impl<F, R> Future for RequestHandler<F, R>
where
C: Channel,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
S::Fut: Send + 'static,
F: Future<Output = R>,
{
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let _ = ready!(self.resp().poll(cx));
Poll::Ready(())
}
}
#[derive(Debug)]
struct Resp<F, R> {
state: RespState,
request_id: u64,
ctx: context::Context,
deadline: SystemTime,
f: deadline_compat::Deadline<F>,
response: Option<Response<R>>,
response_tx: mpsc::Sender<(context::Context, Response<R>)>,
}
#[derive(Debug)]
enum RespState {
PollResp,
PollReady,
PollFlush,
}
impl<F, R> Resp<F, R> {
unsafe_pinned!(f: deadline_compat::Deadline<F>);
unsafe_pinned!(response_tx: mpsc::Sender<(context::Context, Response<R>)>);
unsafe_unpinned!(response: Option<Response<R>>);
unsafe_unpinned!(state: RespState);
}
impl<F, R> Future for Resp<F, R>
where
F: Future<Output = R>,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
move || -> Poll<io::Result<()>> {
loop {
let read = self.as_mut().pump_read(cx)?;
match (
read,
self.as_mut().pump_write(cx, read == Poll::Ready(None))?,
) {
(Poll::Ready(None), Poll::Ready(None)) => {
return Poll::Ready(Ok(()));
loop {
match self.as_mut().state() {
RespState::PollResp => {
let result = ready!(self.as_mut().f().poll(cx));
*self.as_mut().response() = Some(Response {
request_id: self.request_id,
message: match result {
Ok(message) => Ok(message),
Err(e) => {
Err(make_server_error(e, *self.ctx.trace_id(), self.deadline))
}
},
});
*self.as_mut().state() = RespState::PollReady;
}
RespState::PollReady => {
let ready = ready!(self.as_mut().response_tx().poll_ready(cx));
if ready.is_err() {
return Poll::Ready(());
}
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
_ => {
return Poll::Pending;
let resp = (self.ctx, self.as_mut().response().take().unwrap());
if self.as_mut().response_tx().start_send(resp).is_err() {
return Poll::Ready(());
}
*self.as_mut().state() = RespState::PollFlush;
}
RespState::PollFlush => {
let ready = ready!(self.as_mut().response_tx().poll_flush(cx));
if ready.is_err() {
return Poll::Ready(());
}
return Poll::Ready(());
}
}
}()
.map(|r| r.unwrap_or_else(|e| info!("ResponseHandler errored out: {}", e)))
}
}
}
impl<C, S> Stream for ClientHandler<C, S>
where
C: Channel,
S: Serve<C::Req, Resp = C::Resp>,
{
type Item = io::Result<RequestHandler<S::Fut, C::Resp>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
let read = self.as_mut().pump_read(cx)?;
let read_closed = if let Poll::Ready(None) = read {
true
} else {
false
};
match (read, self.as_mut().pump_write(cx, read_closed)?) {
(Poll::Ready(None), Poll::Ready(None)) => {
return Poll::Ready(None);
}
(Poll::Ready(Some(request_handler)), _) => {
return Poll::Ready(Some(Ok(request_handler)));
}
(_, Poll::Ready(Some(()))) => {}
_ => {
return Poll::Pending;
}
}
}
}
}
@@ -641,3 +673,72 @@ fn make_server_error(
}
}
}
// Send + 'static execution helper methods.
impl<C, S> ClientHandler<C, S>
where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
S::Fut: Send + 'static,
{
/// Runs the client handler until completion by spawning each
/// request handler onto the default executor.
pub fn execute(self) -> impl Future<Output = ()> {
self.try_for_each(|request_handler| {
async {
crate::spawn(request_handler).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
"Could not spawn response task. Is shutdown: {}",
e.is_shutdown()
),
)
})
}
})
.unwrap_or_else(|e| info!("ClientHandler errored out: {}", e))
}
}
/// A future that drives the server by spawning channels and request handlers on the default
/// executor.
#[derive(Debug)]
pub struct Running<St, Se> {
incoming: St,
server: Se,
}
impl<St, Se> Running<St, Se> {
unsafe_pinned!(incoming: St);
unsafe_unpinned!(server: Se);
}
impl<St, C, Se> Future for Running<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send + 'static,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) {
if let Err(e) = crate::spawn(
channel
.respond_with(self.as_mut().server().clone())
.execute(),
) {
warn!("Failed to spawn channel handler: {:?}", e);
}
}
info!("Server shutting down.");
Poll::Ready(())
}
}

View File

@@ -60,8 +60,7 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
where
Req: Unpin + Send + 'static,
Resp: Send + 'static,
Req: Unpin,
{
type Req = Req;
type Resp = Resp;

View File

@@ -286,11 +286,7 @@ fn throttler_poll_next_throttled_sink_not_ready() {
Poll::Pending
}
}
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>>
where
Req: Send + 'static,
Resp: Send + 'static,
{
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {

View File

@@ -106,7 +106,7 @@ mod tests {
)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let mut client = client::new(client::Config::default(), client_channel).await?;
let mut client = client::new(client::Config::default(), client_channel).spawn()?;
let response1 = client.call(context::current(), "123".into()).await?;
let response2 = client.call(context::current(), "abc".into()).await?;

View File

@@ -31,10 +31,12 @@ bytes = { version = "0.4", features = ["serde"] }
env_logger = "0.6"
futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] }
humantime = "1.0"
log = "0.4"
runtime = "0.3.0-alpha.6"
runtime-tokio = "0.3.0-alpha.5"
tokio-tcp = "0.1"
pin-utils = "0.1.0-alpha.4"
tokio = "0.1"
[[example]]
name = "server_calling_server"

View File

@@ -116,7 +116,7 @@ impl publisher::Publisher for Publisher {
) -> io::Result<()> {
let conn = bincode_transport::connect(&addr).await?;
let subscriber =
subscriber::SubscriberClient::new(client::Config::default(), conn).await?;
subscriber::SubscriberClient::new(client::Config::default(), conn).spawn()?;
eprintln!("Subscribing {}.", id);
clients.lock().unwrap().insert(id, subscriber);
Ok(())
@@ -160,7 +160,7 @@ async fn main() -> io::Result<()> {
let publisher_conn = bincode_transport::connect(&publisher_addr);
let publisher_conn = publisher_conn.await?;
let mut publisher =
publisher::PublisherClient::new(client::Config::default(), publisher_conn).await?;
publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?;
if let Err(e) = publisher
.subscribe(context::current(), 0, subscriber1)

View File

@@ -58,6 +58,7 @@ async fn main() -> io::Result<()> {
// serve_world is generated by the tarpc::service attribute. It takes as input any type
// implementing the generated World trait.
.respond_with(HelloServer.serve())
.execute()
.await;
};
let _ = runtime::spawn(server);
@@ -66,7 +67,7 @@ async fn main() -> io::Result<()> {
// WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that
// takes a config and any Transport as input.
let mut client = WorldClient::new(client::Config::default(), transport).await?;
let mut client = WorldClient::new(client::Config::default(), transport).spawn()?;
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context

View File

@@ -78,7 +78,7 @@ async fn main() -> io::Result<()> {
let _ = runtime::spawn(add_server);
let to_add_server = bincode_transport::connect(&addr).await?;
let add_client = add::AddClient::new(client::Config::default(), to_add_server).await?;
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?;
let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
@@ -91,7 +91,7 @@ async fn main() -> io::Result<()> {
let to_double_server = bincode_transport::connect(&addr).await?;
let mut double_client =
double::DoubleClient::new(client::Config::default(), to_double_server).await?;
double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;
for i in 1..=5 {
eprintln!("{:?}", double_client.double(context::current(), i).await?);

View File

@@ -1,21 +1,38 @@
#![feature(async_await)]
#[cfg(not(feature = "serde1"))]
use std::rc::Rc;
use assert_matches::assert_matches;
use futures::{
future::{ready, Ready},
prelude::*,
};
use std::io;
use std::{rc::Rc, io};
use tarpc::{
client, context,
client::{self, NewClient}, context,
server::{self, BaseChannel, Channel, Handler},
transport::channel,
};
trait RuntimeExt {
fn exec_bg(&mut self, future: impl Future<Output = ()> + 'static);
fn exec<F, T, E>(&mut self, future: F) -> Result<T, E>
where
F: Future<Output = Result<T, E>>;
}
impl RuntimeExt for tokio::runtime::current_thread::Runtime {
fn exec_bg(&mut self, future: impl Future<Output = ()> + 'static) {
self.spawn(Box::pin(future.unit_error()).compat());
}
fn exec<F, T, E>(&mut self, future: F) -> Result<T, E>
where
F: Future<Output = Result<T, E>>,
{
self.block_on(futures::compat::Compat::new(Box::pin(future)))
}
}
#[tarpc_plugins::service]
trait Service {
async fn add(x: i32, y: i32) -> i32;
@@ -48,9 +65,10 @@ async fn sequential() -> io::Result<()> {
let _ = runtime::spawn(
BaseChannel::new(server::Config::default(), rx)
.respond_with(Server.serve())
.execute()
);
let mut client = ServiceClient::new(client::Config::default(), tx).await?;
let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?;
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
@@ -74,7 +92,7 @@ async fn serde() -> io::Result<()> {
);
let transport = bincode_transport::connect(&addr).await?;
let mut client = ServiceClient::new(client::Config::default(), transport).await?;
let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?;
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
@@ -96,7 +114,7 @@ async fn concurrent() -> io::Result<()> {
.respond_with(Server.serve()),
);
let client = ServiceClient::new(client::Config::default(), tx).await?;
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
let mut c = client.clone();
let req1 = c.add(context::current(), 1, 2);
@@ -113,3 +131,61 @@ async fn concurrent() -> io::Result<()> {
Ok(())
}
#[tarpc::service(derive_serde = false)]
trait InMemory {
async fn strong_count(rc: Rc<()>) -> usize;
async fn weak_count(rc: Rc<()>) -> usize;
}
impl InMemory for () {
type StrongCountFut = Ready<usize>;
fn strong_count(self, _: context::Context, rc: Rc<()>) -> Self::StrongCountFut {
ready(Rc::strong_count(&rc))
}
type WeakCountFut = Ready<usize>;
fn weak_count(self, _: context::Context, rc: Rc<()>) -> Self::WeakCountFut {
ready(Rc::weak_count(&rc))
}
}
#[test]
fn in_memory_single_threaded() -> io::Result<()> {
use log::warn;
let _ = env_logger::try_init();
let mut runtime = tokio::runtime::current_thread::Runtime::new()?;
let (tx, rx) = channel::unbounded();
let server = BaseChannel::new(server::Config::default(), rx)
.respond_with(().serve())
.try_for_each(|r| async move { Ok(r.await) });
runtime.exec_bg(async {
if let Err(e) = server.await {
warn!("Error while running server: {}", e);
}
});
let NewClient{mut client, dispatch} = InMemoryClient::new(client::Config::default(), tx);
runtime.exec_bg(async move {
if let Err(e) = dispatch.await {
warn!("Error while running client dispatch: {}", e)
}
});
let rc = Rc::new(());
assert_matches!(
runtime.exec(client.strong_count(context::current(), rc.clone())),
Ok(2)
);
let _weak = Rc::downgrade(&rc);
assert_matches!(
runtime.exec(client.weak_count(context::current(), rc)),
Ok(1)
);
Ok(())
}