// Copyright 2018 Google LLC // // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. //! Provides a server that concurrently handles many connections sending multiplexed requests. use crate::{ context::Context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage, ClientMessageKind, Request, Response, ServerError, Transport, }; use fnv::FnvHashMap; use futures::{ channel::mpsc, future::{abortable, AbortHandle}, prelude::*, ready, stream::Fuse, task::{LocalWaker, Poll}, try_ready, }; use humantime::format_rfc3339; use log::{debug, error, info, trace, warn}; use pin_utils::{unsafe_pinned, unsafe_unpinned}; use std::{ error::Error as StdError, io, marker::PhantomData, net::SocketAddr, pin::Pin, time::{Instant, SystemTime}, }; use tokio_timer::timeout; use trace::{self, TraceId}; mod filter; /// Manages clients, serving multiplexed requests over each connection. #[derive(Debug)] pub struct Server { config: Config, ghost: PhantomData<(Req, Resp)>, } /// Settings that control the behavior of the server. #[non_exhaustive] #[derive(Clone, Debug)] pub struct Config { /// The maximum number of clients that can be connected to the server at once. When at the /// limit, existing connections are honored and new connections are rejected. pub max_connections: usize, /// The maximum number of clients per IP address that can be connected to the server at once. /// When an IP is at the limit, existing connections are honored and new connections on that IP /// address are rejected. pub max_connections_per_ip: usize, /// The maximum number of requests that can be in flight for each client. When a client is at /// the in-flight request limit, existing requests are fulfilled and new requests are rejected. /// Rejected requests are sent a response error. pub max_in_flight_requests_per_connection: usize, /// The number of responses per client that can be buffered server-side before being sent. /// `pending_response_buffer` controls the buffer size of the channel that a server's /// response tasks use to send responses to the client handler task. pub pending_response_buffer: usize, } impl Default for Config { fn default() -> Self { Config { max_connections: 1_000_000, max_connections_per_ip: 1_000, max_in_flight_requests_per_connection: 1_000, pending_response_buffer: 100, } } } impl Server { /// Returns a new server with configuration specified `config`. pub fn new(config: Config) -> Self { Server { config, ghost: PhantomData, } } /// Returns the config for this server. pub fn config(&self) -> &Config { &self.config } /// Returns a stream of the incoming connections to the server. pub fn incoming( self, listener: S, ) -> impl Stream>> where Req: Send, Resp: Send, S: Stream>, T: Transport, SinkItem = Response> + Send, { self::filter::ConnectionFilter::filter(listener, self.config.clone()) } } /// The future driving the server. #[derive(Debug)] pub struct Running { incoming: S, request_handler: F, } impl Running { unsafe_pinned!(incoming: S); unsafe_unpinned!(request_handler: F); } impl Future for Running where S: Sized + Stream>>, Req: Send + 'static, Resp: Send + 'static, T: Transport, SinkItem = Response> + Send + 'static, F: FnMut(Context, Req) -> Fut + Send + 'static + Clone, Fut: Future> + Send + 'static, { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<()> { while let Some(channel) = ready!(self.incoming().poll_next(cx)) { match channel { Ok(channel) => { let peer = channel.client_addr; if let Err(e) = crate::spawn(channel.respond_with(self.request_handler().clone())) { warn!("[{}] Failed to spawn connection handler: {:?}", peer, e); } } Err(e) => { warn!("Incoming connection error: {}", e); } } } info!("Server shutting down."); return Poll::Ready(()); } } /// A utility trait enabling a stream to fluently chain a request handler. pub trait Handler where Self: Sized + Stream>>, Req: Send, Resp: Send, T: Transport, SinkItem = Response> + Send, { /// Responds to all requests with `request_handler`. fn respond_with(self, request_handler: F) -> Running where F: FnMut(Context, Req) -> Fut + Send + 'static + Clone, Fut: Future> + Send + 'static, { Running { incoming: self, request_handler, } } } impl Handler for S where S: Sized + Stream>>, Req: Send, Resp: Send, T: Transport, SinkItem = Response> + Send, {} /// Responds to all requests with `request_handler`. /// The server end of an open connection with a client. #[derive(Debug)] pub struct Channel { /// Writes responses to the wire and reads requests off the wire. transport: Fuse, /// Signals the connection is closed when `Channel` is dropped. closed_connections: mpsc::UnboundedSender, /// Channel limits to prevent unlimited resource usage. config: Config, /// The address of the server connected to. client_addr: SocketAddr, /// Types the request and response. ghost: PhantomData<(Req, Resp)>, } impl Drop for Channel { fn drop(&mut self) { trace!("[{}] Closing channel.", self.client_addr); // Even in a bounded channel, each connection would have a guaranteed slot, so using // an unbounded sender is actually no different. And, the bound is on the maximum number // of open connections. if self .closed_connections .unbounded_send(self.client_addr) .is_err() { warn!( "[{}] Failed to send closed connection message.", self.client_addr ); } } } impl Channel { unsafe_pinned!(transport: Fuse); } impl Channel where T: Transport, SinkItem = Response> + Send, Req: Send, Resp: Send, { pub(crate) fn start_send(self: &mut Pin<&mut Self>, response: Response) -> io::Result<()> { self.transport().start_send(response) } pub(crate) fn poll_ready( self: &mut Pin<&mut Self>, cx: &LocalWaker, ) -> Poll> { self.transport().poll_ready(cx) } pub(crate) fn poll_flush( self: &mut Pin<&mut Self>, cx: &LocalWaker, ) -> Poll> { self.transport().poll_flush(cx) } pub(crate) fn poll_next( self: &mut Pin<&mut Self>, cx: &LocalWaker, ) -> Poll>>> { self.transport().poll_next(cx) } /// Returns the address of the client connected to the channel. pub fn client_addr(&self) -> &SocketAddr { &self.client_addr } /// Respond to requests coming over the channel with `f`. Returns a future that drives the /// responses and resolves when the connection is closed. pub fn respond_with(self, f: F) -> impl Future where F: FnMut(Context, Req) -> Fut + Send + 'static, Fut: Future> + Send + 'static, Req: 'static, Resp: 'static, { let (responses_tx, responses) = mpsc::channel(self.config.pending_response_buffer); let responses = responses.fuse(); let peer = self.client_addr; ClientHandler { channel: self, f, pending_responses: responses, responses_tx, in_flight_requests: FnvHashMap::default(), }.unwrap_or_else(move |e| { info!("[{}] ClientHandler errored out: {}", peer, e); }) } } #[derive(Debug)] struct ClientHandler { channel: Channel, /// Responses waiting to be written to the wire. pending_responses: Fuse)>>, /// Handed out to request handlers to fan in responses. responses_tx: mpsc::Sender<(Context, Response)>, /// Number of requests currently being responded to. in_flight_requests: FnvHashMap, /// Request handler. f: F, } impl ClientHandler { unsafe_pinned!(channel: Channel); unsafe_pinned!(in_flight_requests: FnvHashMap); unsafe_pinned!(pending_responses: Fuse)>>); unsafe_pinned!(responses_tx: mpsc::Sender<(Context, Response)>); // For this to be safe, field f must be private, and code in this module must never // construct PinMut. unsafe_unpinned!(f: F); } impl ClientHandler where Req: Send + 'static, Resp: Send + 'static, T: Transport, SinkItem = Response> + Send, F: FnMut(Context, Req) -> Fut + Send + 'static, Fut: Future> + Send + 'static, { /// If at max in-flight requests, check that there's room to immediately write a throttled /// response. fn poll_ready_if_throttling( self: &mut Pin<&mut Self>, cx: &LocalWaker, ) -> Poll> { if self.in_flight_requests.len() >= self.channel.config.max_in_flight_requests_per_connection { let peer = self.channel().client_addr; while let Poll::Pending = self.channel().poll_ready(cx)? { info!( "[{}] In-flight requests at max ({}), and transport is not ready.", peer, self.in_flight_requests().len(), ); try_ready!(self.channel().poll_flush(cx)); } } Poll::Ready(Ok(())) } fn pump_read(self: &mut Pin<&mut Self>, cx: &LocalWaker) -> Poll>> { ready!(self.poll_ready_if_throttling(cx)?); Poll::Ready(match ready!(self.channel().poll_next(cx)?) { Some(message) => { match message.message { ClientMessageKind::Request(request) => { self.handle_request(message.trace_context, request)?; } ClientMessageKind::Cancel { request_id } => { self.cancel_request(&message.trace_context, request_id); } } Some(Ok(())) } None => { trace!("[{}] Read half closed", self.channel.client_addr); None } }) } fn pump_write( self: &mut Pin<&mut Self>, cx: &LocalWaker, read_half_closed: bool, ) -> Poll>> { match self.poll_next_response(cx)? { Poll::Ready(Some((_, response))) => { self.channel().start_send(response)?; Poll::Ready(Some(Ok(()))) } Poll::Ready(None) => { // Shutdown can't be done before we finish pumping out remaining responses. ready!(self.channel().poll_flush(cx)?); Poll::Ready(None) } Poll::Pending => { // No more requests to process, so flush any requests buffered in the transport. ready!(self.channel().poll_flush(cx)?); // Being here means there are no staged requests and all written responses are // fully flushed. So, if the read half is closed and there are no in-flight // requests, then we can close the write half. if read_half_closed && self.in_flight_requests().is_empty() { Poll::Ready(None) } else { Poll::Pending } } } } fn poll_next_response( self: &mut Pin<&mut Self>, cx: &LocalWaker, ) -> Poll)>>> { // Ensure there's room to write a response. while let Poll::Pending = self.channel().poll_ready(cx)? { ready!(self.channel().poll_flush(cx)?); } let peer = self.channel().client_addr; match ready!(self.pending_responses().poll_next(cx)) { Some((ctx, response)) => { if let Some(_) = self.in_flight_requests().remove(&response.request_id) { self.in_flight_requests().compact(0.1); } trace!( "[{}/{}] Staging response. In-flight requests = {}.", ctx.trace_id(), peer, self.in_flight_requests().len(), ); return Poll::Ready(Some(Ok((ctx, response)))); } None => { // This branch likely won't happen, since the ClientHandler is holding a Sender. trace!("[{}] No new responses.", peer); Poll::Ready(None) } } } fn handle_request( self: &mut Pin<&mut Self>, trace_context: trace::Context, request: Request, ) -> io::Result<()> { let request_id = request.id; let peer = self.channel().client_addr; let ctx = Context { deadline: request.deadline, trace_context, }; let request = request.message; if self.in_flight_requests().len() >= self.channel().config.max_in_flight_requests_per_connection { debug!( "[{}/{}] Client has reached in-flight request limit ({}/{}).", ctx.trace_id(), peer, self.in_flight_requests().len(), self.channel().config.max_in_flight_requests_per_connection ); self.channel().start_send(Response { request_id, message: Err(ServerError { kind: io::ErrorKind::WouldBlock, detail: Some("Server throttled the request.".into()), }), })?; return Ok(()); } let deadline = ctx.deadline; let timeout = deadline.as_duration(); trace!( "[{}/{}] Received request with deadline {} (timeout {:?}).", ctx.trace_id(), peer, format_rfc3339(deadline), timeout, ); let mut response_tx = self.responses_tx().clone(); let trace_id = *ctx.trace_id(); let response = self.f()(ctx.clone(), request); let response = deadline_compat::Deadline::new(response, Instant::now() + timeout).then( async move |result| { let response = Response { request_id, message: match result { Ok(message) => Ok(message), Err(e) => Err(make_server_error(e, trace_id, peer, deadline)), }, }; trace!("[{}/{}] Sending response.", trace_id, peer); await!(response_tx.send((ctx, response)).unwrap_or_else(|_| ())); }, ); let (abortable_response, abort_handle) = abortable(response); crate::spawn(abortable_response.map(|_| ())) .map_err(|e| { io::Error::new( io::ErrorKind::Other, format!( "Could not spawn response task. Is shutdown: {}", e.is_shutdown() ), ) })?; self.in_flight_requests().insert(request_id, abort_handle); Ok(()) } fn cancel_request(self: &mut Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) { // It's possible the request was already completed, so it's fine // if this is None. if let Some(cancel_handle) = self.in_flight_requests().remove(&request_id) { self.in_flight_requests().compact(0.1); cancel_handle.abort(); let remaining = self.in_flight_requests().len(); trace!( "[{}/{}] Request canceled. In-flight requests = {}", trace_context.trace_id, self.channel.client_addr, remaining, ); } else { trace!( "[{}/{}] Received cancellation, but response handler \ is already complete.", trace_context.trace_id, self.channel.client_addr ); } } } impl Future for ClientHandler where Req: Send + 'static, Resp: Send + 'static, T: Transport, SinkItem = Response> + Send, F: FnMut(Context, Req) -> Fut + Send + 'static, Fut: Future> + Send + 'static, { type Output = io::Result<()>; fn poll(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll> { trace!("[{}] ClientHandler::poll", self.channel.client_addr); loop { let read = self.pump_read(cx)?; match (read, self.pump_write(cx, read == Poll::Ready(None))?) { (Poll::Ready(None), Poll::Ready(None)) => { info!("[{}] Client disconnected.", self.channel.client_addr); return Poll::Ready(Ok(())); } (read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => { trace!( "[{}] read: {:?}, write: {:?}.", self.channel.client_addr, read, write ) } (read, write) => { trace!( "[{}] read: {:?}, write: {:?} (not ready).", self.channel.client_addr, read, write, ); return Poll::Pending; } } } } } fn make_server_error( e: timeout::Error, trace_id: TraceId, peer: SocketAddr, deadline: SystemTime, ) -> ServerError { if e.is_elapsed() { debug!( "[{}/{}] Response did not complete before deadline of {}s.", trace_id, peer, format_rfc3339(deadline) ); // No point in responding, since the client will have dropped the request. ServerError { kind: io::ErrorKind::TimedOut, detail: Some(format!( "Response did not complete before deadline of {}s.", format_rfc3339(deadline) )), } } else if e.is_timer() { error!( "[{}/{}] Response failed because of an issue with a timer: {}", trace_id, peer, e ); ServerError { kind: io::ErrorKind::Other, detail: Some(format!("{}", e)), } } else if e.is_inner() { let e = e.into_inner().unwrap(); ServerError { kind: e.kind(), detail: Some(e.description().into()), } } else { error!("[{}/{}] Unexpected response failure: {}", trace_id, peer, e); ServerError { kind: io::ErrorKind::Other, detail: Some(format!("Server unexpectedly failed to respond: {}", e)), } } }