From 15080b28893ec5f38c362c59551d9b9ab0a78d4a Mon Sep 17 00:00:00 2001 From: Tim Date: Thu, 23 Mar 2017 00:08:08 -0700 Subject: [PATCH] Set a default max packet size. (#128) The default max packet size of 2 << 20 --- src/future/client.rs | 36 ++++++++++++++++++++++++---- src/future/server.rs | 39 ++++++++++++++++++++++++++---- src/protocol.rs | 57 +++++++++++++++++++++++++++++++++++++------- src/sync/client.rs | 26 +++++++++++++++++++- src/sync/server.rs | 6 +++++ 5 files changed, 147 insertions(+), 17 deletions(-) diff --git a/src/future/client.rs b/src/future/client.rs index 1f5e791..0914e2c 100644 --- a/src/future/client.rs +++ b/src/future/client.rs @@ -27,14 +27,40 @@ cfg_if! { } /// Additional options to configure how the client connects and operates. -#[derive(Default)] pub struct Options { + /// Max packet size in bytes. + max_payload_size: u64, reactor: Option, #[cfg(feature = "tls")] tls_ctx: Option, } +impl Default for Options { + #[cfg(feature = "tls")] + fn default() -> Self { + Options { + max_payload_size: 2 << 20, + reactor: None, + tls_ctx: None, + } + } + + #[cfg(not(feature = "tls"))] + fn default() -> Self { + Options { + max_payload_size: 2 << 20, + reactor: None, + } + } +} + impl Options { + /// Set the max payload size in bytes. The default is 2,000,000 (2 MB). + pub fn max_payload_size(mut self, bytes: u64) -> Self { + self.max_payload_size = bytes; + self + } + /// Drive using the given reactor handle. Only used by `FutureClient`s. pub fn handle(mut self, handle: reactor::Handle) -> Self { self.reactor = Some(Reactor::Handle(handle)); @@ -106,12 +132,12 @@ impl Client Resp: Deserialize + 'static, E: Deserialize + 'static { - fn bind(handle: &reactor::Handle, tcp: StreamType) -> Self + fn bind(handle: &reactor::Handle, tcp: StreamType, max_payload_size: u64) -> Self where Req: Serialize + Sync + Send + 'static, Resp: Deserialize + Sync + Send + 'static, E: Deserialize + Sync + Send + 'static { - let inner = Proto::new().bind_client(&handle, tcp); + let inner = Proto::new(max_payload_size).bind_client(&handle, tcp); Client { inner } } @@ -161,6 +187,8 @@ impl ClientExt for Client #[cfg(feature = "tls")] let tls_ctx = options.tls_ctx.take(); + let max_payload_size = options.max_payload_size; + let connect = move |handle: &reactor::Handle| { let handle2 = handle.clone(); TcpStream::connect(&addr, handle) @@ -180,7 +208,7 @@ impl ClientExt for Client #[cfg(not(feature = "tls"))] future::ok(StreamType::Tcp(socket)) }) - .map(move |tcp| Client::bind(&handle2, tcp)) + .map(move |tcp| Client::bind(&handle2, tcp, max_payload_size)) }; let (tx, rx) = futures::oneshot(); let setup = move |handle: &reactor::Handle| { diff --git a/src/future/server.rs b/src/future/server.rs index c07269e..476e415 100644 --- a/src/future/server.rs +++ b/src/future/server.rs @@ -51,7 +51,10 @@ impl Handle { E: Serialize + 'static { let (addr, shutdown, server) = - listen_with(new_service, addr, handle, Acceptor::from(options))?; + listen_with(new_service, + addr, handle, + options.max_payload_size, + Acceptor::from(options))?; Ok((Handle { addr: addr, shutdown: shutdown, @@ -145,14 +148,38 @@ impl Fn<((TcpStream, SocketAddr),)> for Acceptor { } /// Additional options to configure how the server operates. -#[derive(Default)] pub struct Options { + /// Max packet size in bytes. + max_payload_size: u64, #[cfg(feature = "tls")] tls_acceptor: Option, } +impl Default for Options { + #[cfg(not(feature = "tls"))] + fn default() -> Self { + Options { + max_payload_size: 2 << 20, + } + } + + #[cfg(feature = "tls")] + fn default() -> Self { + Options { + max_payload_size: 2 << 20, + tls_acceptor: None, + } + } +} + impl Options { - /// Set the `TlsAcceptor` + /// Set the max payload size in bytes. The default is 2,000,000 (2 MB). + pub fn max_payload_size(mut self, bytes: u64) -> Self { + self.max_payload_size = bytes; + self + } + + /// Sets the `TlsAcceptor` #[cfg(feature = "tls")] pub fn tls(mut self, tls_acceptor: TlsAcceptor) -> Self { self.tls_acceptor = Some(tls_acceptor); @@ -497,6 +524,7 @@ impl Future for Listen fn listen_with(new_service: S, addr: SocketAddr, handle: &reactor::Handle, + max_payload_size: u64, acceptor: Acceptor) -> io::Result<(SocketAddr, Shutdown, Listen)> where S: NewService, @@ -516,6 +544,7 @@ fn listen_with(new_service: S, let server = listener.incoming() .and_then(acceptor) .for_each(Bind { + max_payload_size: max_payload_size, handle: handle, new_service: ConnectionTrackingNewService { connection_tracker: connection_tracker, @@ -533,6 +562,7 @@ fn log_err(e: io::Error) { } struct Bind { + max_payload_size: u64, handle: reactor::Handle, new_service: S, } @@ -548,7 +578,8 @@ impl Bind fn bind(&self, socket: I) -> io::Result<()> where I: Io + 'static { - Proto::new().bind_server(&self.handle, socket, self.new_service.new_service()?); + Proto::new(self.max_payload_size) + .bind_server(&self.handle, socket, self.new_service.new_service()?); Ok(()) } } diff --git a/src/protocol.rs b/src/protocol.rs index 3af2abf..b002e31 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -15,6 +15,7 @@ use tokio_proto::streaming::multiplex::RequestId; // `Encode` is the type that `Codec` encodes. `Decode` is the type it decodes. pub struct Codec { + max_payload_size: u64, state: CodecState, _phantom_data: PhantomData<(Encode, Decode)>, } @@ -26,14 +27,23 @@ enum CodecState { } impl Codec { - fn new() -> Self { + fn new(max_payload_size: u64) -> Self { Codec { + max_payload_size: max_payload_size, state: CodecState::Id, _phantom_data: PhantomData, } } } +fn too_big(payload_size: u64, max_payload_size: u64) -> io::Error { + warn!("Not sending too-big packet of size {} (max is {})", + payload_size, max_payload_size); + io::Error::new(io::ErrorKind::InvalidData, + format!("Maximum payload size is {} bytes but got a payload of {}", + max_payload_size, payload_size)) +} + impl tokio_core::io::Codec for Codec where Encode: serde::Serialize, Decode: serde::Deserialize @@ -44,7 +54,11 @@ impl tokio_core::io::Codec for Codec fn encode(&mut self, (id, message): Self::Out, buf: &mut Vec) -> io::Result<()> { buf.write_u64::(id).unwrap(); trace!("Encoded request id = {} as {:?}", id, buf); - buf.write_u64::(bincode::serialized_size(&message)).unwrap(); + let payload_size = bincode::serialized_size(&message); + if payload_size > self.max_payload_size { + return Err(too_big(payload_size, self.max_payload_size)); + } + buf.write_u64::(payload_size).unwrap(); bincode::serialize_into(buf, &message, Infinite) @@ -80,6 +94,9 @@ impl tokio_core::io::Codec for Codec trace!("--> Parsed payload length = {}, remaining buffer length = {}", len, buf.len()); + if len > self.max_payload_size { + return Err(too_big(len, self.max_payload_size)); + } self.state = Payload { id: id, len: len }; } Payload { len, .. } if buf.len() < len as usize => { @@ -104,12 +121,18 @@ impl tokio_core::io::Codec for Codec } /// Implements the `multiplex::ServerProto` trait. -pub struct Proto(PhantomData<(Encode, Decode)>); +pub struct Proto { + max_payload_size: u64, + _phantom_data: PhantomData<(Encode, Decode)>, +} impl Proto { /// Returns a new `Proto`. - pub fn new() -> Self { - Proto(PhantomData) + pub fn new(max_payload_size: u64) -> Self { + Proto { + max_payload_size: max_payload_size, + _phantom_data: PhantomData + } } } @@ -124,7 +147,7 @@ impl ServerProto for Proto type BindTransport = Result; fn bind_transport(&self, io: T) -> Self::BindTransport { - Ok(io.framed(Codec::new())) + Ok(io.framed(Codec::new(self.max_payload_size))) } } @@ -139,7 +162,7 @@ impl ClientProto for Proto type BindTransport = Result; fn bind_transport(&self, io: T) -> Self::BindTransport { - Ok(io.framed(Codec::new())) + Ok(io.framed(Codec::new(self.max_payload_size))) } } @@ -153,7 +176,7 @@ fn serialize() { // Serialize twice to check for idempotence. for _ in 0..2 { - let mut codec: Codec<(char, char, char), (char, char, char)> = Codec::new(); + let mut codec: Codec<(char, char, char), (char, char, char)> = Codec::new(2_000_000); codec.encode(MSG, &mut vec).unwrap(); buf.get_mut().append(&mut vec); let actual: Result)>, io::Error> = @@ -169,3 +192,21 @@ fn serialize() { *buf.get_mut()); } } + +#[test] +fn deserialize_big() { + use tokio_core::io::Codec as TokioCodec; + let mut codec: Codec, Vec> = Codec::new(24); + + let mut vec = Vec::new(); + assert_eq!(codec.encode((0, vec![0; 24]), &mut vec).err().unwrap().kind(), + io::ErrorKind::InvalidData); + + let mut buf = EasyBuf::new(); + // Header + buf.get_mut().append(&mut vec![0; 8]); + // Len + buf.get_mut().append(&mut vec![0, 0, 0, 0, 0, 0, 0, 25]); + assert_eq!(codec.decode(&mut buf).err().unwrap().kind(), + io::ErrorKind::InvalidData); +} diff --git a/src/sync/client.rs b/src/sync/client.rs index 6aa142c..7031c97 100644 --- a/src/sync/client.rs +++ b/src/sync/client.rs @@ -46,13 +46,37 @@ impl Client } /// Additional options to configure how the client connects and operates. -#[derive(Default)] pub struct Options { + /// Max packet size in bytes. + max_payload_size: u64, #[cfg(feature = "tls")] tls_ctx: Option, } +impl Default for Options { + #[cfg(not(feature = "tls"))] + fn default() -> Self { + Options { + max_payload_size: 2_000_000, + } + } + + #[cfg(feature = "tls")] + fn default() -> Self { + Options { + max_payload_size: 2_000_000, + tls_ctx: None, + } + } +} + impl Options { + /// Set the max payload size in bytes. The default is 2,000,000 (2 MB). + pub fn max_payload_size(mut self, bytes: u64) -> Self { + self.max_payload_size = bytes; + self + } + /// Connect using the given `Context` #[cfg(feature = "tls")] pub fn tls(mut self, ctx: Context) -> Self { diff --git a/src/sync/server.rs b/src/sync/server.rs index 270c54f..8a726c9 100644 --- a/src/sync/server.rs +++ b/src/sync/server.rs @@ -16,6 +16,12 @@ pub struct Options { } impl Options { + /// Set the max payload size in bytes. The default is 2,000,000 (2 MB). + pub fn max_payload_size(mut self, bytes: u64) -> Self { + self.opts = self.opts.max_payload_size(bytes); + self + } + /// Set the `TlsAcceptor` #[cfg(feature = "tls")] pub fn tls(mut self, tls_acceptor: TlsAcceptor) -> Self {