mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Set a default max packet size. (#128)
The default max packet size of 2 << 20
This commit is contained in:
@@ -15,6 +15,7 @@ use tokio_proto::streaming::multiplex::RequestId;
|
||||
|
||||
// `Encode` is the type that `Codec` encodes. `Decode` is the type it decodes.
|
||||
pub struct Codec<Encode, Decode> {
|
||||
max_payload_size: u64,
|
||||
state: CodecState,
|
||||
_phantom_data: PhantomData<(Encode, Decode)>,
|
||||
}
|
||||
@@ -26,14 +27,23 @@ enum CodecState {
|
||||
}
|
||||
|
||||
impl<Encode, Decode> Codec<Encode, Decode> {
|
||||
fn new() -> Self {
|
||||
fn new(max_payload_size: u64) -> Self {
|
||||
Codec {
|
||||
max_payload_size: max_payload_size,
|
||||
state: CodecState::Id,
|
||||
_phantom_data: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn too_big(payload_size: u64, max_payload_size: u64) -> io::Error {
|
||||
warn!("Not sending too-big packet of size {} (max is {})",
|
||||
payload_size, max_payload_size);
|
||||
io::Error::new(io::ErrorKind::InvalidData,
|
||||
format!("Maximum payload size is {} bytes but got a payload of {}",
|
||||
max_payload_size, payload_size))
|
||||
}
|
||||
|
||||
impl<Encode, Decode> tokio_core::io::Codec for Codec<Encode, Decode>
|
||||
where Encode: serde::Serialize,
|
||||
Decode: serde::Deserialize
|
||||
@@ -44,7 +54,11 @@ impl<Encode, Decode> tokio_core::io::Codec for Codec<Encode, Decode>
|
||||
fn encode(&mut self, (id, message): Self::Out, buf: &mut Vec<u8>) -> io::Result<()> {
|
||||
buf.write_u64::<BigEndian>(id).unwrap();
|
||||
trace!("Encoded request id = {} as {:?}", id, buf);
|
||||
buf.write_u64::<BigEndian>(bincode::serialized_size(&message)).unwrap();
|
||||
let payload_size = bincode::serialized_size(&message);
|
||||
if payload_size > self.max_payload_size {
|
||||
return Err(too_big(payload_size, self.max_payload_size));
|
||||
}
|
||||
buf.write_u64::<BigEndian>(payload_size).unwrap();
|
||||
bincode::serialize_into(buf,
|
||||
&message,
|
||||
Infinite)
|
||||
@@ -80,6 +94,9 @@ impl<Encode, Decode> tokio_core::io::Codec for Codec<Encode, Decode>
|
||||
trace!("--> Parsed payload length = {}, remaining buffer length = {}",
|
||||
len,
|
||||
buf.len());
|
||||
if len > self.max_payload_size {
|
||||
return Err(too_big(len, self.max_payload_size));
|
||||
}
|
||||
self.state = Payload { id: id, len: len };
|
||||
}
|
||||
Payload { len, .. } if buf.len() < len as usize => {
|
||||
@@ -104,12 +121,18 @@ impl<Encode, Decode> tokio_core::io::Codec for Codec<Encode, Decode>
|
||||
}
|
||||
|
||||
/// Implements the `multiplex::ServerProto` trait.
|
||||
pub struct Proto<Encode, Decode>(PhantomData<(Encode, Decode)>);
|
||||
pub struct Proto<Encode, Decode> {
|
||||
max_payload_size: u64,
|
||||
_phantom_data: PhantomData<(Encode, Decode)>,
|
||||
}
|
||||
|
||||
impl<Encode, Decode> Proto<Encode, Decode> {
|
||||
/// Returns a new `Proto`.
|
||||
pub fn new() -> Self {
|
||||
Proto(PhantomData)
|
||||
pub fn new(max_payload_size: u64) -> Self {
|
||||
Proto {
|
||||
max_payload_size: max_payload_size,
|
||||
_phantom_data: PhantomData
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,7 +147,7 @@ impl<T, Encode, Decode> ServerProto<T> for Proto<Encode, Decode>
|
||||
type BindTransport = Result<Self::Transport, io::Error>;
|
||||
|
||||
fn bind_transport(&self, io: T) -> Self::BindTransport {
|
||||
Ok(io.framed(Codec::new()))
|
||||
Ok(io.framed(Codec::new(self.max_payload_size)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,7 +162,7 @@ impl<T, Encode, Decode> ClientProto<T> for Proto<Encode, Decode>
|
||||
type BindTransport = Result<Self::Transport, io::Error>;
|
||||
|
||||
fn bind_transport(&self, io: T) -> Self::BindTransport {
|
||||
Ok(io.framed(Codec::new()))
|
||||
Ok(io.framed(Codec::new(self.max_payload_size)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,7 +176,7 @@ fn serialize() {
|
||||
|
||||
// Serialize twice to check for idempotence.
|
||||
for _ in 0..2 {
|
||||
let mut codec: Codec<(char, char, char), (char, char, char)> = Codec::new();
|
||||
let mut codec: Codec<(char, char, char), (char, char, char)> = Codec::new(2_000_000);
|
||||
codec.encode(MSG, &mut vec).unwrap();
|
||||
buf.get_mut().append(&mut vec);
|
||||
let actual: Result<Option<(u64, Result<(char, char, char), bincode::Error>)>, io::Error> =
|
||||
@@ -169,3 +192,21 @@ fn serialize() {
|
||||
*buf.get_mut());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_big() {
|
||||
use tokio_core::io::Codec as TokioCodec;
|
||||
let mut codec: Codec<Vec<u8>, Vec<u8>> = Codec::new(24);
|
||||
|
||||
let mut vec = Vec::new();
|
||||
assert_eq!(codec.encode((0, vec![0; 24]), &mut vec).err().unwrap().kind(),
|
||||
io::ErrorKind::InvalidData);
|
||||
|
||||
let mut buf = EasyBuf::new();
|
||||
// Header
|
||||
buf.get_mut().append(&mut vec![0; 8]);
|
||||
// Len
|
||||
buf.get_mut().append(&mut vec![0, 0, 0, 0, 0, 0, 0, 25]);
|
||||
assert_eq!(codec.decode(&mut buf).err().unwrap().kind(),
|
||||
io::ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user