From 4987094483568fe2fcee31ca1d9a65bd2595adc6 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Sat, 1 Aug 2020 13:45:05 -0700 Subject: [PATCH] Compression example. Follow-up work: some extension points would be useful allow enabling compression on a per-request basis. Fixes https://github.com/google/tarpc/issues/200 --- tarpc/Cargo.toml | 5 +- tarpc/examples/compression.rs | 130 ++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 tarpc/examples/compression.rs diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 3d65976..c3621c4 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -41,14 +41,17 @@ tokio-serde = { optional = true, version = "0.6" } [dev-dependencies] assert_matches = "1.0" +bincode = "1.3" bytes = { version = "0.5", features = ["serde"] } env_logger = "0.6" +flate2 = "1.0.16" futures = "0.3" humantime = "1.0" log = "0.4" pin-utils = "0.1.0-alpha" +serde_bytes = "0.11" tokio = { version = "0.2", features = ["full"] } -tokio-serde = { version = "0.6", features = ["json"] } +tokio-serde = { version = "0.6", features = ["json", "bincode"] } trybuild = "1.0" [package.metadata.docs.rs] diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs new file mode 100644 index 0000000..bc958d3 --- /dev/null +++ b/tarpc/examples/compression.rs @@ -0,0 +1,130 @@ +use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression}; +use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt}; +use serde::{Deserialize, Serialize}; +use serde_bytes::ByteBuf; +use std::{io, io::Read, io::Write}; +use tarpc::{ + client, context, + serde_transport::tcp, + server::{BaseChannel, Channel}, +}; +use tokio_serde::formats::Bincode; + +/// Type of compression that should be enabled on the request. The transport is free to ignore this. +#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] +pub enum CompressionAlgorithm { + Deflate, +} + +#[derive(Debug, Deserialize, Serialize)] +pub enum CompressedMessage { + Uncompressed(T), + Compressed { + algorithm: CompressionAlgorithm, + payload: ByteBuf, + }, +} + +#[derive(Deserialize, Serialize)] +enum CompressionType { + Uncompressed, + Compressed, +} + +async fn compress(message: T) -> io::Result> +where + T: Serialize, +{ + let message = serialize(message)?; + let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&message).unwrap(); + let compressed = encoder.finish()?; + Ok(CompressedMessage::Compressed { + algorithm: CompressionAlgorithm::Deflate, + payload: ByteBuf::from(compressed), + }) +} + +async fn decompress(message: CompressedMessage) -> io::Result +where + for<'a> T: Deserialize<'a>, +{ + match message { + CompressedMessage::Compressed { algorithm, payload } => { + if algorithm != CompressionAlgorithm::Deflate { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Compression algorithm {:?} not supported", algorithm), + )); + } + let mut deflater = DeflateDecoder::new(payload.as_slice()); + let mut payload = ByteBuf::new(); + deflater.read_to_end(&mut payload)?; + let message = deserialize(payload)?; + Ok(message) + } + CompressedMessage::Uncompressed(message) => Ok(message), + } +} + +fn serialize(t: T) -> io::Result { + bincode::serialize(&t) + .map(ByteBuf::from) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) +} + +fn deserialize(message: ByteBuf) -> io::Result +where + for<'a> D: Deserialize<'a>, +{ + bincode::deserialize(message.as_ref()).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) +} + +fn add_compression( + transport: impl Stream>> + + Sink, Error = io::Error>, +) -> impl Stream> + Sink +where + Out: Serialize, + for<'a> In: Deserialize<'a>, +{ + transport.with(compress).and_then(decompress) +} + +#[tarpc::service] +pub trait World { + async fn hello(name: String) -> String; +} + +#[derive(Clone, Debug)] +struct HelloServer; + +#[tarpc::server] +impl World for HelloServer { + async fn hello(self, _: context::Context, name: String) -> String { + format!("Hey, {}!", name) + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; + let addr = incoming.local_addr(); + tokio::spawn(async move { + let transport = incoming.next().await.unwrap().unwrap(); + BaseChannel::with_defaults(add_compression(transport)) + .respond_with(HelloServer.serve()) + .execute() + .await; + }); + + let transport = tcp::connect(addr, Bincode::default()).await?; + let mut client = + WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?; + + println!( + "{}", + client.hello(context::current(), "friend".into()).await? + ); + Ok(()) +}