Compression example.

Follow-up work: some extension points would be useful allow enabling compression on a per-request basis.

Fixes https://github.com/google/tarpc/issues/200
This commit is contained in:
Tim Kuehn
2020-08-01 13:45:05 -07:00
parent ff55080193
commit 4987094483
2 changed files with 134 additions and 1 deletions

View File

@@ -41,14 +41,17 @@ tokio-serde = { optional = true, version = "0.6" }
[dev-dependencies]
assert_matches = "1.0"
bincode = "1.3"
bytes = { version = "0.5", features = ["serde"] }
env_logger = "0.6"
flate2 = "1.0.16"
futures = "0.3"
humantime = "1.0"
log = "0.4"
pin-utils = "0.1.0-alpha"
serde_bytes = "0.11"
tokio = { version = "0.2", features = ["full"] }
tokio-serde = { version = "0.6", features = ["json"] }
tokio-serde = { version = "0.6", features = ["json", "bincode"] }
trybuild = "1.0"
[package.metadata.docs.rs]

View File

@@ -0,0 +1,130 @@
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use serde_bytes::ByteBuf;
use std::{io, io::Read, io::Write};
use tarpc::{
client, context,
serde_transport::tcp,
server::{BaseChannel, Channel},
};
use tokio_serde::formats::Bincode;
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
pub enum CompressionAlgorithm {
Deflate,
}
#[derive(Debug, Deserialize, Serialize)]
pub enum CompressedMessage<T> {
Uncompressed(T),
Compressed {
algorithm: CompressionAlgorithm,
payload: ByteBuf,
},
}
#[derive(Deserialize, Serialize)]
enum CompressionType {
Uncompressed,
Compressed,
}
async fn compress<T>(message: T) -> io::Result<CompressedMessage<T>>
where
T: Serialize,
{
let message = serialize(message)?;
let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
encoder.write_all(&message).unwrap();
let compressed = encoder.finish()?;
Ok(CompressedMessage::Compressed {
algorithm: CompressionAlgorithm::Deflate,
payload: ByteBuf::from(compressed),
})
}
async fn decompress<T>(message: CompressedMessage<T>) -> io::Result<T>
where
for<'a> T: Deserialize<'a>,
{
match message {
CompressedMessage::Compressed { algorithm, payload } => {
if algorithm != CompressionAlgorithm::Deflate {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Compression algorithm {:?} not supported", algorithm),
));
}
let mut deflater = DeflateDecoder::new(payload.as_slice());
let mut payload = ByteBuf::new();
deflater.read_to_end(&mut payload)?;
let message = deserialize(payload)?;
Ok(message)
}
CompressedMessage::Uncompressed(message) => Ok(message),
}
}
fn serialize<T: Serialize>(t: T) -> io::Result<ByteBuf> {
bincode::serialize(&t)
.map(ByteBuf::from)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn deserialize<D>(message: ByteBuf) -> io::Result<D>
where
for<'a> D: Deserialize<'a>,
{
bincode::deserialize(message.as_ref()).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn add_compression<In, Out>(
transport: impl Stream<Item = io::Result<CompressedMessage<In>>>
+ Sink<CompressedMessage<Out>, Error = io::Error>,
) -> impl Stream<Item = io::Result<In>> + Sink<Out, Error = io::Error>
where
Out: Serialize,
for<'a> In: Deserialize<'a>,
{
transport.with(compress).and_then(decompress)
}
#[tarpc::service]
pub trait World {
async fn hello(name: String) -> String;
}
#[derive(Clone, Debug)]
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hey, {}!", name)
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
let addr = incoming.local_addr();
tokio::spawn(async move {
let transport = incoming.next().await.unwrap().unwrap();
BaseChannel::with_defaults(add_compression(transport))
.respond_with(HelloServer.serve())
.execute()
.await;
});
let transport = tcp::connect(addr, Bincode::default()).await?;
let mut client =
WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?;
println!(
"{}",
client.hello(context::current(), "friend".into()).await?
);
Ok(())
}