mirror of
https://github.com/OMGeeky/tarpc.git
synced 2025-12-26 17:02:32 +01:00
Refactor server module to be easier to understand.
1. Renames
Some of the items in this module were renamed to be less generic:
- Handler => Incoming
- ClientHandler => Requests
- ResponseHandler => InFlightRequest
- Channel::{respond_with => requests}
In the case of Handler: handler of *what*? Now it's a bit clearer that
this is a stream of Channels (aka *incoming* connections).
Similarly, ClientHandler was a stream of requests over a single
connection. Hopefully Requests better reflects that.
ResponseHandler was renamed InFlightRequest because it no longer
contains the serving function. Instead, it is just the request, plus
the response channel and an abort hook. As a result of this,
Channel::respond_with underwent a big change: it used to take the
serving function and return a ClientHandler; now it has been renamed
Channel::requests and does not take any args.
2. Execute methods
All methods thats actually result in responses being generated
have been consolidated into methods named `execute`:
- InFlightRequest::execute returns a future that completes when a
response has been generated and sent to the server Channel.
- Requests::execute automatically spawns response handlers for all
requests over a single channel.
- Channel::execute is a convenience for `channel.requests().execute()`.
- Incoming::execute automatically spawns response handlers for all
requests over all channels.
3. Removal of Server.
server::Server was removed, as it provided no value over the Incoming/Channel
abstractions. Additionally, server::new was removed, since it just
returned a Server.
This commit is contained in:
15
README.md
15
README.md
@@ -91,7 +91,7 @@ use futures::{
|
||||
};
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Handler},
|
||||
server::{self, Incoming},
|
||||
};
|
||||
use std::io;
|
||||
|
||||
@@ -135,16 +135,11 @@ available behind the `tcp` feature.
|
||||
async fn main() -> io::Result<()> {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = server::new(server::Config::default())
|
||||
// incoming() takes a stream of transports such as would be returned by
|
||||
// TcpListener::incoming (but a stream instead of an iterator).
|
||||
.incoming(stream::once(future::ready(server_transport)))
|
||||
.respond_with(HelloServer.serve());
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
tokio::spawn(server.execute(HelloServer.serve()));
|
||||
|
||||
tokio::spawn(server);
|
||||
|
||||
// WorldClient is generated by the macro. It has a constructor `new` that takes a config and
|
||||
// any Transport as input
|
||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
// that takes a config and any Transport as input.
|
||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
|
||||
@@ -13,7 +13,7 @@ use std::{
|
||||
};
|
||||
use tarpc::{
|
||||
context,
|
||||
server::{self, Channel, Handler},
|
||||
server::{self, Channel, Incoming},
|
||||
tokio_serde::formats::Json,
|
||||
};
|
||||
|
||||
@@ -69,7 +69,7 @@ async fn main() -> io::Result<()> {
|
||||
// the generated World trait.
|
||||
.map(|channel| {
|
||||
let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
|
||||
channel.respond_with(server.serve()).execute()
|
||||
channel.requests().execute(server.serve())
|
||||
})
|
||||
// Max 10 channels.
|
||||
.buffer_unordered(10)
|
||||
|
||||
@@ -93,7 +93,7 @@ diff=""
|
||||
for file in $(git diff --name-only --cached);
|
||||
do
|
||||
if [ ${file: -3} == ".rs" ]; then
|
||||
diff="$diff$(cargo fmt -- --unstable-features --skip-children --check $file)"
|
||||
diff="$diff$(cargo fmt -- --check $file)"
|
||||
fi
|
||||
done
|
||||
if grep --quiet "^[-+]" <<< "$diff"; then
|
||||
|
||||
@@ -485,7 +485,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
#vis trait #service_ident: Clone {
|
||||
#( #types_and_fns )*
|
||||
|
||||
/// Returns a serving function to use with [tarpc::server::Channel::respond_with].
|
||||
/// Returns a serving function to use with [tarpc::server::InFlightRequest::execute].
|
||||
fn serve(self) -> #server_ident<Self> {
|
||||
#server_ident { service: self }
|
||||
}
|
||||
@@ -499,7 +499,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
/// A serving function to use with [tarpc::server::Channel::respond_with].
|
||||
/// A serving function to use with [tarpc::server::InFlightRequest::execute].
|
||||
#[derive(Clone)]
|
||||
#vis struct #server_ident<S> {
|
||||
service: S,
|
||||
@@ -662,7 +662,8 @@ impl<'a> ServiceGenerator<'a> {
|
||||
quote! {
|
||||
#[allow(unused)]
|
||||
#[derive(Clone, Debug)]
|
||||
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
|
||||
/// The client stub that makes RPC calls to the server. ALl request methods return
|
||||
/// [Futures](std::future::Future).
|
||||
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,8 +113,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
tokio::spawn(async move {
|
||||
let transport = incoming.next().await.unwrap().unwrap();
|
||||
BaseChannel::with_defaults(add_compression(transport))
|
||||
.respond_with(HelloServer.serve())
|
||||
.execute()
|
||||
.execute(HelloServer.serve())
|
||||
.await;
|
||||
});
|
||||
|
||||
|
||||
@@ -36,9 +36,7 @@ async fn main() -> std::io::Result<()> {
|
||||
let framed = codec_builder.new_framed(conn);
|
||||
let transport = transport::new(framed, Bincode::default());
|
||||
|
||||
let fut = BaseChannel::with_defaults(transport)
|
||||
.respond_with(Service.serve())
|
||||
.execute();
|
||||
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -105,11 +105,11 @@ impl Subscriber {
|
||||
) -> anyhow::Result<SubscriberHandle> {
|
||||
let publisher = tcp::connect(publisher_addr, Json::default).await?;
|
||||
let local_addr = publisher.local_addr()?;
|
||||
let mut handler = server::BaseChannel::with_defaults(publisher)
|
||||
.respond_with(Subscriber { local_addr, topics }.serve());
|
||||
// The first request is for the topics being subscriibed to.
|
||||
let mut handler = server::BaseChannel::with_defaults(publisher).requests();
|
||||
let subscriber = Subscriber { local_addr, topics };
|
||||
// The first request is for the topics being subscribed to.
|
||||
match handler.next().await {
|
||||
Some(init_topics) => init_topics?.await,
|
||||
Some(init_topics) => init_topics?.execute(subscriber.clone().serve()).await,
|
||||
None => {
|
||||
return Err(anyhow!(
|
||||
"[{}] Server never initialized the subscriber.",
|
||||
@@ -117,7 +117,7 @@ impl Subscriber {
|
||||
))
|
||||
}
|
||||
};
|
||||
let (handler, abort_handle) = future::abortable(handler.execute());
|
||||
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
|
||||
tokio::spawn(async move {
|
||||
match handler.await {
|
||||
Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr),
|
||||
@@ -162,8 +162,7 @@ impl Publisher {
|
||||
info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
|
||||
|
||||
server::BaseChannel::with_defaults(publisher)
|
||||
.respond_with(self.serve())
|
||||
.execute()
|
||||
.execute(self.serve())
|
||||
.await
|
||||
});
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ use futures::{
|
||||
use std::io;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{BaseChannel, Channel},
|
||||
server::{self, Channel},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
@@ -40,40 +40,21 @@ impl World for HelloServer {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
// tarpc_json_transport is provided by the associated crate json_transport. It makes it
|
||||
// easy to start up a serde-powered JSON serialization strategy over TCP.
|
||||
let mut transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = async move {
|
||||
// For this example, we're just going to wait for one connection.
|
||||
let client = transport.next().await.unwrap().unwrap();
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
tokio::spawn(server.execute(HelloServer.serve()));
|
||||
|
||||
// `Channel` is a trait representing a server-side connection. It is a trait to allow
|
||||
// for some channels to be instrumented: for example, to track the number of open connections.
|
||||
// BaseChannel is the most basic channel, simply wrapping a transport with no added
|
||||
// functionality.
|
||||
BaseChannel::with_defaults(client)
|
||||
// serve_world is generated by the tarpc::service attribute. It takes as input any type
|
||||
// implementing the generated World trait.
|
||||
.respond_with(HelloServer.serve())
|
||||
.execute()
|
||||
.await;
|
||||
};
|
||||
tokio::spawn(server);
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
|
||||
// WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that
|
||||
// takes a config and any Transport as input.
|
||||
let mut client = WorldClient::new(client::Config::default(), transport).spawn()?;
|
||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
// that takes a config and any Transport as input.
|
||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
// specifies a deadline and trace information which can be helpful in debugging requests.
|
||||
let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
||||
|
||||
eprintln!("{}", hello);
|
||||
println!("{}", hello);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use futures::{future, prelude::*};
|
||||
use std::io;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{Handler, Server},
|
||||
server::{BaseChannel, Incoming},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
@@ -62,10 +62,10 @@ async fn main() -> io::Result<()> {
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let addr = add_listener.get_ref().local_addr();
|
||||
let add_server = Server::default()
|
||||
.incoming(add_listener)
|
||||
let add_server = add_listener
|
||||
.map(BaseChannel::with_defaults)
|
||||
.take(1)
|
||||
.respond_with(AddServer.serve());
|
||||
.execute(AddServer.serve());
|
||||
tokio::spawn(add_server);
|
||||
|
||||
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
@@ -75,10 +75,10 @@ async fn main() -> io::Result<()> {
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let addr = double_listener.get_ref().local_addr();
|
||||
let double_server = tarpc::Server::default()
|
||||
.incoming(double_listener)
|
||||
let double_server = double_listener
|
||||
.map(BaseChannel::with_defaults)
|
||||
.take(1)
|
||||
.respond_with(DoubleServer { add_client }.serve());
|
||||
.execute(DoubleServer { add_client }.serve());
|
||||
tokio::spawn(double_server);
|
||||
|
||||
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
|
||||
@@ -79,7 +79,7 @@
|
||||
//! };
|
||||
//! use tarpc::{
|
||||
//! client, context,
|
||||
//! server::{self, Handler},
|
||||
//! server::{self, Incoming},
|
||||
//! };
|
||||
//! use std::io;
|
||||
//!
|
||||
@@ -103,7 +103,7 @@
|
||||
//! # };
|
||||
//! # use tarpc::{
|
||||
//! # client, context,
|
||||
//! # server::{self, Handler},
|
||||
//! # server::{self, Incoming},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
@@ -143,7 +143,7 @@
|
||||
//! # };
|
||||
//! # use tarpc::{
|
||||
//! # client, context,
|
||||
//! # server::{self, Handler},
|
||||
//! # server::{self, Channel},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
@@ -172,16 +172,11 @@
|
||||
//! async fn main() -> io::Result<()> {
|
||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
//!
|
||||
//! let server = server::new(server::Config::default())
|
||||
//! // incoming() takes a stream of transports such as would be returned by
|
||||
//! // TcpListener::incoming (but a stream instead of an iterator).
|
||||
//! .incoming(stream::once(future::ready(server_transport)))
|
||||
//! .respond_with(HelloServer.serve());
|
||||
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||
//! tokio::spawn(server.execute(HelloServer.serve()));
|
||||
//!
|
||||
//! tokio::spawn(server);
|
||||
//!
|
||||
//! // WorldClient is generated by the macro. It has a constructor `new` that takes a config and
|
||||
//! // any Transport as input
|
||||
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
//! // that takes a config and any Transport as input.
|
||||
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
//!
|
||||
//! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
@@ -304,7 +299,7 @@ pub mod server;
|
||||
pub mod transport;
|
||||
pub(crate) mod util;
|
||||
|
||||
pub use crate::{client::Client, server::Server, transport::sealed::Transport};
|
||||
pub use crate::{client::Client, transport::sealed::Transport};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::task::*;
|
||||
|
||||
@@ -22,8 +22,7 @@ use futures::{
|
||||
use humantime::format_rfc3339;
|
||||
use log::{debug, trace};
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
|
||||
use tokio::time::Timeout;
|
||||
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin};
|
||||
|
||||
mod filter;
|
||||
#[cfg(test)]
|
||||
@@ -35,30 +34,12 @@ pub use self::{
|
||||
throttle::{Throttler, ThrottlerStream},
|
||||
};
|
||||
|
||||
/// Manages clients, serving multiplexed requests over each connection.
|
||||
pub struct Server<Req, Resp> {
|
||||
config: Config,
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
}
|
||||
|
||||
impl<Req, Resp> Default for Server<Req, Resp> {
|
||||
fn default() -> Self {
|
||||
new(Config::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> fmt::Debug for Server<Req, Resp> {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "Server")
|
||||
}
|
||||
}
|
||||
|
||||
/// Settings that control the behavior of the server.
|
||||
/// Settings that control the behavior of [channels](Channel).
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Config {
|
||||
/// The number of responses per client that can be buffered server-side before being sent.
|
||||
/// `pending_response_buffer` controls the buffer size of the channel that a server's
|
||||
/// response tasks use to send responses to the client handler task.
|
||||
/// Controls the buffer size of the in-process channel over which a server's handlers send
|
||||
/// responses to the [`Channel`]. In other words, this is the number of responses that can sit
|
||||
/// in the outbound queue before request handlers begin blocking.
|
||||
pub pending_response_buffer: usize,
|
||||
}
|
||||
|
||||
@@ -80,32 +61,8 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new server with configuration specified `config`.
|
||||
pub fn new<Req, Resp>(config: Config) -> Server<Req, Resp> {
|
||||
Server {
|
||||
config,
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Server<Req, Resp> {
|
||||
/// Returns the config for this server.
|
||||
pub fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a stream of server channels.
|
||||
pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
|
||||
where
|
||||
S: Stream<Item = T>,
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
listener.map(move |t| BaseChannel::new(self.config.clone(), t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Basically a Fn(Req) -> impl Future<Output = Resp>;
|
||||
pub trait Serve<Req>: Sized + Clone {
|
||||
/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
|
||||
pub trait Serve<Req> {
|
||||
/// Type of response.
|
||||
type Resp;
|
||||
|
||||
@@ -129,8 +86,8 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// A utility trait enabling a stream to fluently chain a request handler.
|
||||
pub trait Handler<C>
|
||||
/// An extension trait for [streams](Stream) of [`Channels`](Channel).
|
||||
pub trait Incoming<C>
|
||||
where
|
||||
Self: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
@@ -149,28 +106,34 @@ where
|
||||
ThrottlerStream::new(self, n)
|
||||
}
|
||||
|
||||
/// Responds to all requests with [`server::serve`](Serve).
|
||||
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
||||
/// concurrently by spawning on tokio's default executor, and each request will be also
|
||||
/// be spawned on tokio's default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn respond_with<S>(self, server: S) -> Running<Self, S>
|
||||
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
Running {
|
||||
incoming: self,
|
||||
server,
|
||||
}
|
||||
TokioServerExecutor { inner: self, serve }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> Handler<C> for S
|
||||
impl<S, C> Incoming<C> for S
|
||||
where
|
||||
S: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
}
|
||||
|
||||
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
|
||||
/// BaseChannel is a [Transport] that keeps track of in-flight requests. It converts a
|
||||
/// [`Transport`](Transport) of [`ClientMessages`](ClientMessage) into a stream of
|
||||
/// [requests](ClientMessage::Request).
|
||||
///
|
||||
/// Besides requests, the other type of client message is [cancellation
|
||||
/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation
|
||||
/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
|
||||
/// the corresponding in-flight requests and aborting their handlers).
|
||||
#[pin_project(PinnedDrop)]
|
||||
pub struct BaseChannel<Req, Resp, T> {
|
||||
config: Config,
|
||||
@@ -251,10 +214,25 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
|
||||
/// The server end of an open connection with a client, streaming in requests from, and sinking
|
||||
/// responses to, the client.
|
||||
///
|
||||
/// Channels are free to somewhat rely on the assumption that all in-flight requests are eventually
|
||||
/// either [cancelled](BaseChannel::cancel_request) or [responded to](Sink::start_send). Safety cannot
|
||||
/// rely on this assumption, but it is best for `Channel` users to always account for all outstanding
|
||||
/// requests.
|
||||
///
|
||||
/// The ways to use a Channel, in order of simplest to most complex, is:
|
||||
/// 1. [Channel::execute] - Requires the `tokio1` feature. This method is best for those who
|
||||
/// do not have specific scheduling needs and whose services are `Send + 'static`.
|
||||
/// 2. [Channel::requests] - This method is best for those who need direct access to individual
|
||||
/// requests, or are not using `tokio`, or want control over [futures](Future) scheduling.
|
||||
/// 3. [Raw stream](<Channel as Stream>) - A user is free to manually handle requests produced by
|
||||
/// Channel. If they do so, they should uphold the service contract:
|
||||
/// 1. All work being done as part of processing request `request_id` is aborted when
|
||||
/// either of the following occurs:
|
||||
/// - The channel receives a [cancellation message](ClientMessage::Cancel) for request
|
||||
/// `request_id`.
|
||||
/// - The [deadline](crate::context::Context::deadline) of request `request_id` is reached.
|
||||
/// 2. When a server completes a response for request `request_id`, it is
|
||||
/// [sent](Sink::start_send) into the Channel. Because there is no guarantee that a
|
||||
/// cancellation message will ever be received for a request, services should strive to clean
|
||||
/// up Channel resources by sending a response for every request. For example, [`BaseChannel`]
|
||||
/// has a map of requests to [abort handles][AbortHandle] whose entries are only removed
|
||||
/// upon either request cancellation or response completion.
|
||||
pub trait Channel
|
||||
where
|
||||
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
|
||||
@@ -269,14 +247,14 @@ where
|
||||
fn config(&self) -> &Config;
|
||||
|
||||
/// Returns the number of in-flight requests over this channel.
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize;
|
||||
fn in_flight_requests(&self) -> usize;
|
||||
|
||||
/// Caps the number of concurrent requests.
|
||||
fn max_concurrent_requests(self, n: usize) -> Throttler<Self>
|
||||
/// Caps the number of concurrent requests to `limit`.
|
||||
fn max_concurrent_requests(self, limit: usize) -> Throttler<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Throttler::new(self, n)
|
||||
Throttler::new(self, limit)
|
||||
}
|
||||
|
||||
/// Tells the Channel that request with ID `request_id` is being handled.
|
||||
@@ -284,23 +262,37 @@ where
|
||||
/// to the Channel.
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
|
||||
|
||||
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
|
||||
/// responses and resolves when the connection is closed.
|
||||
fn respond_with<S>(self, server: S) -> ClientHandler<Self, S>
|
||||
/// Returns a stream of requests that automatically handle request cancellation and response
|
||||
/// routing.
|
||||
fn requests(self) -> Requests<Self>
|
||||
where
|
||||
S: Serve<Self::Req, Resp = Self::Resp>,
|
||||
Self: Sized,
|
||||
{
|
||||
let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
|
||||
let responses = responses.fuse();
|
||||
|
||||
ClientHandler {
|
||||
Requests {
|
||||
channel: self,
|
||||
server,
|
||||
pending_responses: responses,
|
||||
responses_tx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the channel until completion by executing all requests using the given service
|
||||
/// function. Request handlers are run concurrently by [spawning](tokio::spawn) on tokio's
|
||||
/// default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn execute<S>(self, serve: S) -> TokioChannelExecutor<Requests<Self>, S>
|
||||
where
|
||||
Self: Sized,
|
||||
S: Serve<Self::Req, Resp = Self::Resp> + Send + Sync + 'static,
|
||||
S::Fut: Send,
|
||||
Self::Req: Send + 'static,
|
||||
Self::Resp: Send + 'static,
|
||||
{
|
||||
self.requests().execute(serve)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
|
||||
@@ -390,8 +382,8 @@ where
|
||||
&self.config
|
||||
}
|
||||
|
||||
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
|
||||
self.as_mut().project().in_flight_requests.len()
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
@@ -405,9 +397,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// A running handler serving all requests coming over a channel.
|
||||
/// A stream of requests coming over a channel.
|
||||
#[pin_project]
|
||||
pub struct ClientHandler<C, S>
|
||||
pub struct Requests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
@@ -419,26 +411,30 @@ where
|
||||
/// Handed out to request handlers to fan in responses.
|
||||
#[pin]
|
||||
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
|
||||
/// Server
|
||||
server: S,
|
||||
}
|
||||
|
||||
impl<C, S> ClientHandler<C, S>
|
||||
impl<C> Requests<C>
|
||||
where
|
||||
C: Channel,
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
/// Returns the inner channel over which messages are sent and received.
|
||||
pub fn get_pin_channel(self: Pin<&mut Self>) -> Pin<&mut C> {
|
||||
self.project().channel
|
||||
pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
|
||||
self.as_mut().project().channel
|
||||
}
|
||||
|
||||
fn pump_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
|
||||
) -> PollIo<InFlightRequest<C::Req, C::Resp>> {
|
||||
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
|
||||
Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
|
||||
Some(request) => {
|
||||
let abort_registration = self.as_mut().project().channel.start_request(request.id);
|
||||
Poll::Ready(Some(Ok(InFlightRequest {
|
||||
request,
|
||||
response_tx: self.responses_tx.clone(),
|
||||
abort_registration,
|
||||
})))
|
||||
}
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
@@ -449,28 +445,28 @@ where
|
||||
read_half_closed: bool,
|
||||
) -> PollIo<()> {
|
||||
match self.as_mut().poll_next_response(cx)? {
|
||||
Poll::Ready(Some((ctx, response))) => {
|
||||
Poll::Ready(Some((context, response))) => {
|
||||
trace!(
|
||||
"[{}] Staging response. In-flight requests = {}.",
|
||||
ctx.trace_id(),
|
||||
self.as_mut().project().channel.in_flight_requests(),
|
||||
context.trace_id(),
|
||||
self.channel.in_flight_requests(),
|
||||
);
|
||||
self.as_mut().project().channel.start_send(response)?;
|
||||
self.channel_pin_mut().start_send(response)?;
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
// Shutdown can't be done before we finish pumping out remaining responses.
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
ready!(self.channel_pin_mut().poll_flush(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => {
|
||||
// No more requests to process, so flush any requests buffered in the transport.
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
ready!(self.channel_pin_mut().poll_flush(cx)?);
|
||||
|
||||
// Being here means there are no staged requests and all written responses are
|
||||
// fully flushed. So, if the read half is closed and there are no in-flight
|
||||
// requests, then we can close the write half.
|
||||
if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 {
|
||||
if read_half_closed && self.channel.in_flight_requests() == 0 {
|
||||
Poll::Ready(None)
|
||||
} else {
|
||||
Poll::Pending
|
||||
@@ -484,183 +480,116 @@ where
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<(context::Context, Response<C::Resp>)> {
|
||||
// Ensure there's room to write a response.
|
||||
while self.as_mut().project().channel.poll_ready(cx)?.is_pending() {
|
||||
while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
}
|
||||
|
||||
match ready!(self.as_mut().project().pending_responses.poll_next(cx)) {
|
||||
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
|
||||
Some(response) => Poll::Ready(Some(Ok(response))),
|
||||
None => {
|
||||
// This branch likely won't happen, since the ClientHandler is holding a Sender.
|
||||
// This branch likely won't happen, since the Requests stream is holding a Sender.
|
||||
Poll::Ready(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
request: Request<C::Req>,
|
||||
) -> RequestHandler<S::Fut, C::Resp> {
|
||||
let request_id = request.id;
|
||||
let deadline = request.context.deadline;
|
||||
let timeout = deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Received request with deadline {} (timeout {:?}).",
|
||||
request.context.trace_id(),
|
||||
format_rfc3339(deadline),
|
||||
timeout,
|
||||
);
|
||||
let ctx = request.context;
|
||||
let request = request.message;
|
||||
|
||||
let response = self.as_mut().project().server.clone().serve(ctx, request);
|
||||
let response = Resp {
|
||||
state: RespState::PollResp,
|
||||
request_id,
|
||||
ctx,
|
||||
deadline,
|
||||
f: tokio::time::timeout(timeout, response),
|
||||
response: None,
|
||||
response_tx: self.as_mut().project().responses_tx.clone(),
|
||||
};
|
||||
let abort_registration = self.as_mut().project().channel.start_request(request_id);
|
||||
RequestHandler {
|
||||
resp: Abortable::new(response, abort_registration),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, S> fmt::Debug for ClientHandler<C, S>
|
||||
impl<C> fmt::Debug for Requests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "ClientHandler")
|
||||
write!(fmt, "Requests")
|
||||
}
|
||||
}
|
||||
|
||||
/// A future fulfilling a single client request.
|
||||
#[pin_project]
|
||||
pub struct RequestHandler<F, R> {
|
||||
#[pin]
|
||||
resp: Abortable<Resp<F, R>>,
|
||||
}
|
||||
|
||||
impl<F, R> Future for RequestHandler<F, R>
|
||||
where
|
||||
F: Future<Output = R>,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
let _ = ready!(self.project().resp.poll(cx));
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, R> fmt::Debug for RequestHandler<F, R> {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "RequestHandler")
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
struct Resp<F, R> {
|
||||
state: RespState,
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
deadline: SystemTime,
|
||||
#[pin]
|
||||
f: Timeout<F>,
|
||||
response: Option<Response<R>>,
|
||||
#[pin]
|
||||
response_tx: mpsc::Sender<(context::Context, Response<R>)>,
|
||||
}
|
||||
|
||||
/// A request produced by [Channel::requests].
|
||||
#[derive(Debug)]
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
enum RespState {
|
||||
PollResp,
|
||||
PollReady,
|
||||
PollFlush,
|
||||
pub struct InFlightRequest<Req, Res> {
|
||||
request: Request<Req>,
|
||||
response_tx: mpsc::Sender<(context::Context, Response<Res>)>,
|
||||
abort_registration: AbortRegistration,
|
||||
}
|
||||
|
||||
impl<F, R> Future for Resp<F, R>
|
||||
where
|
||||
F: Future<Output = R>,
|
||||
{
|
||||
type Output = ();
|
||||
impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
/// Returns a reference to the request.
|
||||
pub fn get(&self) -> &Request<Req> {
|
||||
&self.request
|
||||
}
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
loop {
|
||||
match self.as_mut().project().state {
|
||||
RespState::PollResp => {
|
||||
let result = ready!(self.as_mut().project().f.poll(cx));
|
||||
*self.as_mut().project().response = Some(Response {
|
||||
request_id: self.request_id,
|
||||
message: match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(tokio::time::error::Elapsed { .. }) => {
|
||||
debug!(
|
||||
"[{}] Response did not complete before deadline of {}s.",
|
||||
self.ctx.trace_id(),
|
||||
format_rfc3339(self.deadline)
|
||||
);
|
||||
// No point in responding, since the client will have dropped the
|
||||
// request.
|
||||
Err(ServerError {
|
||||
kind: io::ErrorKind::TimedOut,
|
||||
detail: Some(format!(
|
||||
"Response did not complete before deadline of {}s.",
|
||||
format_rfc3339(self.deadline)
|
||||
)),
|
||||
})
|
||||
}
|
||||
},
|
||||
});
|
||||
*self.as_mut().project().state = RespState::PollReady;
|
||||
}
|
||||
RespState::PollReady => {
|
||||
let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx));
|
||||
if ready.is_err() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
let resp = (self.ctx, self.as_mut().project().response.take().unwrap());
|
||||
if self
|
||||
.as_mut()
|
||||
.project()
|
||||
.response_tx
|
||||
.start_send(resp)
|
||||
.is_err()
|
||||
{
|
||||
return Poll::Ready(());
|
||||
}
|
||||
*self.as_mut().project().state = RespState::PollFlush;
|
||||
}
|
||||
RespState::PollFlush => {
|
||||
let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx));
|
||||
if ready.is_err() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Returns a [future](Future) that executes the request using the given [service
|
||||
/// function](Serve). The service function's output is automatically sent back to the
|
||||
/// [Channel] that yielded this request.
|
||||
///
|
||||
/// The returned future will stop executing when the first of the following conditions is met:
|
||||
///
|
||||
/// 1. The channel that yielded this request receives a [cancellation
|
||||
/// message](ClientMessage::Cancel) for this request.
|
||||
/// 2. The request [deadline](crate::context::Context::deadline) is reached.
|
||||
/// 3. The service function completes.
|
||||
pub fn execute<S>(self, serve: S) -> impl Future<Output = ()>
|
||||
where
|
||||
S: Serve<Req, Resp = Res>,
|
||||
{
|
||||
let Self {
|
||||
abort_registration,
|
||||
request,
|
||||
mut response_tx,
|
||||
} = self;
|
||||
Abortable::new(
|
||||
async move {
|
||||
let Request {
|
||||
context,
|
||||
message,
|
||||
id: request_id,
|
||||
} = request;
|
||||
let trace_id = *request.context.trace_id();
|
||||
let deadline = request.context.deadline;
|
||||
let timeout = deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Handling request with deadline {} (timeout {:?}).",
|
||||
trace_id,
|
||||
format_rfc3339(deadline),
|
||||
timeout,
|
||||
);
|
||||
let result =
|
||||
tokio::time::timeout(timeout, async { serve.serve(context, message).await })
|
||||
.await;
|
||||
let response = Response {
|
||||
request_id,
|
||||
message: match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(tokio::time::error::Elapsed { .. }) => {
|
||||
debug!(
|
||||
"[{}] Response did not complete before deadline of {}s.",
|
||||
trace_id,
|
||||
format_rfc3339(deadline)
|
||||
);
|
||||
// No point in responding, since the client will have dropped the
|
||||
// request.
|
||||
Err(ServerError {
|
||||
kind: io::ErrorKind::TimedOut,
|
||||
detail: Some(format!(
|
||||
"Response did not complete before deadline of {}s.",
|
||||
format_rfc3339(deadline)
|
||||
)),
|
||||
})
|
||||
}
|
||||
},
|
||||
};
|
||||
let _ = response_tx.send((context, response)).await;
|
||||
},
|
||||
abort_registration,
|
||||
)
|
||||
.unwrap_or_else(|_| {})
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, R> fmt::Debug for Resp<F, R> {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "Resp")
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, S> Stream for ClientHandler<C, S>
|
||||
impl<C> Stream for Requests<C>
|
||||
where
|
||||
C: Channel,
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
type Item = io::Result<RequestHandler<S::Fut, C::Resp>>;
|
||||
type Item = io::Result<InFlightRequest<C::Req, C::Resp>>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
@@ -684,77 +613,111 @@ where
|
||||
|
||||
// Send + 'static execution helper methods.
|
||||
|
||||
impl<C, S> ClientHandler<C, S>
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
impl<C> Requests<C>
|
||||
where
|
||||
C: Channel + 'static,
|
||||
C: Channel,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||
S::Fut: Send + 'static,
|
||||
{
|
||||
/// Runs the client handler until completion by [spawning](tokio::spawn) each
|
||||
/// request handler onto the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn execute(self) -> impl Future<Output = ()> {
|
||||
self.try_for_each(|request_handler| async {
|
||||
tokio::spawn(request_handler);
|
||||
Ok(())
|
||||
})
|
||||
.map_ok(|()| log::info!("ClientHandler finished."))
|
||||
.unwrap_or_else(|e| log::info!("ClientHandler errored out: {}", e))
|
||||
/// Executes all requests using the given service function. Requests are handled concurrently
|
||||
/// by [spawning](tokio::spawn) each handler on tokio's default executor.
|
||||
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + Sync + 'static,
|
||||
{
|
||||
TokioChannelExecutor { inner: self, serve }
|
||||
}
|
||||
}
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) channels and request handlers on the default
|
||||
/// executor.
|
||||
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
||||
/// for each new channel.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub struct Running<St, Se> {
|
||||
pub struct TokioServerExecutor<T, S> {
|
||||
#[pin]
|
||||
incoming: St,
|
||||
server: Se,
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) each [response handler](ResponseHandler)
|
||||
/// on tokio's default executor.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub struct TokioChannelExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<St, C, Se> Future for Running<St, Se>
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
impl<T, S> TokioServerExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
impl<T, S> TokioChannelExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
||||
where
|
||||
St: Sized + Stream<Item = C>,
|
||||
C: Channel + Send + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
Se::Fut: Send + 'static,
|
||||
Se: Serve<C::Req, Resp = C::Resp> + Send + Sync + 'static + Clone,
|
||||
Se::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) {
|
||||
tokio::spawn(
|
||||
channel
|
||||
.respond_with(self.as_mut().project().server.clone())
|
||||
.execute(),
|
||||
);
|
||||
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
tokio::spawn(channel.execute(self.serve.clone()));
|
||||
}
|
||||
log::info!("Server shutting down.");
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_in_flight_requests_on_channel_drop() {
|
||||
use assert_matches::assert_matches;
|
||||
use futures::future::Aborted;
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
||||
where
|
||||
C: Channel + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + Sync + 'static + Clone,
|
||||
S::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
let (_, server_transport) =
|
||||
super::transport::channel::unbounded::<Response<()>, ClientMessage<()>>();
|
||||
let channel = BaseChannel::with_defaults(server_transport);
|
||||
let mut channel = Box::pin(channel);
|
||||
|
||||
let abort_registration = channel.as_mut().start_request(1);
|
||||
let future = Abortable::new(async { () }, abort_registration);
|
||||
drop(channel);
|
||||
assert_matches!(future.await, Err(Aborted));
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
match response_handler {
|
||||
Ok(resp) => {
|
||||
let server = self.serve.clone();
|
||||
tokio::spawn(async move {
|
||||
resp.execute(server).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
log::info!("Requests stream errored out: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,8 +108,8 @@ where
|
||||
self.inner.config()
|
||||
}
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.project().inner.in_flight_requests()
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
|
||||
@@ -77,7 +77,7 @@ where
|
||||
&self.config
|
||||
}
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.in_flight_requests.len()
|
||||
}
|
||||
|
||||
|
||||
@@ -113,8 +113,8 @@ where
|
||||
type Req = <C as Channel>::Req;
|
||||
type Resp = <C as Channel>::Resp;
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.project().inner.in_flight_requests()
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
@@ -292,7 +292,7 @@ fn throttler_poll_next_throttled_sink_not_ready() {
|
||||
fn config(&self) -> &Config {
|
||||
unimplemented!()
|
||||
}
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
0
|
||||
}
|
||||
fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
|
||||
|
||||
@@ -82,7 +82,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||
mod tests {
|
||||
use crate::{
|
||||
client, context,
|
||||
server::{Handler, Server},
|
||||
server::{BaseChannel, Incoming},
|
||||
transport,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
@@ -96,9 +96,9 @@ mod tests {
|
||||
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
tokio::spawn(
|
||||
Server::default()
|
||||
.incoming(stream::once(future::ready(server_channel)))
|
||||
.respond_with(|_ctx, request: String| {
|
||||
stream::once(future::ready(server_channel))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(|_ctx, request: String| {
|
||||
future::ready(request.parse::<u64>().map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
use futures::prelude::*;
|
||||
use std::io;
|
||||
use tarpc::serde_transport;
|
||||
use tarpc::{client, context, server::Handler};
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{BaseChannel, Incoming},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
#[tarpc::derive_serde]
|
||||
@@ -34,9 +37,11 @@ async fn test_call() -> io::Result<()> {
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(transport.take(1).filter_map(|r| async { r.ok() }))
|
||||
.respond_with(ColorServer.serve()),
|
||||
transport
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(ColorServer.serve()),
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
|
||||
@@ -5,13 +5,12 @@ use futures::{
|
||||
};
|
||||
use std::{
|
||||
io,
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tarpc::{
|
||||
client::{self},
|
||||
context,
|
||||
server::{self, BaseChannel, Channel, Handler},
|
||||
server::{self, BaseChannel, Channel, Incoming},
|
||||
transport::channel,
|
||||
};
|
||||
use tokio::join;
|
||||
@@ -47,8 +46,8 @@ async fn sequential() -> io::Result<()> {
|
||||
|
||||
tokio::spawn(
|
||||
BaseChannel::new(server::Config::default(), rx)
|
||||
.respond_with(Server.serve())
|
||||
.execute(),
|
||||
.requests()
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
@@ -68,19 +67,14 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
|
||||
async fn r#loop();
|
||||
}
|
||||
|
||||
struct LoopServer(tokio::sync::mpsc::UnboundedSender<AllHandlersComplete>);
|
||||
#[derive(Clone)]
|
||||
struct LoopServer;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AllHandlersComplete;
|
||||
|
||||
impl Drop for LoopServer {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.0.send(AllHandlersComplete);
|
||||
}
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl Loop for Arc<LoopServer> {
|
||||
impl Loop for LoopServer {
|
||||
async fn r#loop(self, _: context::Context) {
|
||||
loop {
|
||||
futures::pending!();
|
||||
@@ -91,7 +85,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
let (rpc_finished_tx, mut rpc_finished) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
// Set up a client that initiates a long-lived request.
|
||||
// The request will complete in error when the server drops the connection.
|
||||
@@ -105,18 +98,16 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
|
||||
let _ = client.r#loop(ctx).await;
|
||||
});
|
||||
|
||||
let mut server =
|
||||
BaseChannel::with_defaults(rx).respond_with(Arc::new(LoopServer(rpc_finished_tx)).serve());
|
||||
let first_handler = server.next().await.unwrap()?;
|
||||
let mut requests = BaseChannel::with_defaults(rx).requests();
|
||||
// Reading a request should trigger the request being registered with BaseChannel.
|
||||
let first_request = requests.next().await.unwrap()?;
|
||||
// Dropping the channel should trigger cleanup of outstanding requests.
|
||||
drop(requests);
|
||||
// In-flight requests should be aborted by channel cleanup.
|
||||
// The first and only request sent by the client is `loop`, which is an infinite loop
|
||||
// on the server side, so if cleanup was not triggered, this line should hang indefinitely.
|
||||
first_request.execute(LoopServer.serve()).await;
|
||||
|
||||
drop(server);
|
||||
first_handler.await;
|
||||
|
||||
// At this point, a single RPC has been sent and a single response initiated.
|
||||
// The request handler will loop for a long time unless aborted.
|
||||
// Now, we assert that the act of disconnecting a client is sufficient to abort all
|
||||
// handlers initiated by the connection's RPCs.
|
||||
assert_matches!(rpc_finished.recv().await, Some(AllHandlersComplete));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -131,9 +122,11 @@ async fn serde() -> io::Result<()> {
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(transport.take(1).filter_map(|r| async { r.ok() }))
|
||||
.respond_with(Server.serve()),
|
||||
transport
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
@@ -154,9 +147,9 @@ async fn concurrent() -> io::Result<()> {
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
@@ -183,9 +176,9 @@ async fn concurrent_join() -> io::Result<()> {
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
@@ -213,9 +206,9 @@ async fn concurrent_join_all() -> io::Result<()> {
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
|
||||
Reference in New Issue
Block a user