From b771854e788a8a5030a4a8dd7d1abcfcc227930e Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Sat, 6 Feb 2016 14:53:40 -0800 Subject: [PATCH] Factor out serialization code into a Serialize and Deserialize trait --- tarpc/Cargo.toml | 8 +++--- tarpc/src/lib.rs | 6 ++--- tarpc/src/protocol/client.rs | 52 +++++++++++++++--------------------- tarpc/src/protocol/mod.rs | 25 +++++++++++++++-- tarpc/src/protocol/server.rs | 22 +++++---------- 5 files changed, 57 insertions(+), 56 deletions(-) diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 37ed583..3f1561a 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -6,10 +6,10 @@ license = "MIT/Apache-2.0" description = "tarpc is an RPC framework for rust with a focus on ease of use." [dependencies] -serde = "*" bincode = "*" -serde_macros = "*" -log = "*" env_logger = "*" -scoped-pool = "*" lazy_static = "*" +log = "*" +scoped-pool = "*" +serde = "*" +serde_macros = "*" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index a8e6486..bcde9f3 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -56,13 +56,13 @@ extern crate serde; extern crate bincode; +#[cfg(test)] +#[macro_use] +extern crate lazy_static; #[macro_use] extern crate log; extern crate scoped_pool; extern crate test; -#[cfg(test)] -#[macro_use] -extern crate lazy_static; /// Provides the tarpc client and server, which implements the tarpc protocol. /// The protocol is defined by the implementation. diff --git a/tarpc/src/protocol/client.rs b/tarpc/src/protocol/client.rs index ee895e9..a397a59 100644 --- a/tarpc/src/protocol/client.rs +++ b/tarpc/src/protocol/client.rs @@ -1,15 +1,15 @@ -use bincode; use serde; use std::fmt; -use std::io::{self, BufReader, BufWriter, Read, Write}; +use std::io::{self, BufReader, BufWriter, Read}; use std::collections::HashMap; use std::mem; use std::net::{TcpStream, ToSocketAddrs}; use std::sync::{Arc, Mutex}; use std::sync::mpsc::{Receiver, Sender, channel}; -use std::time::Duration; use std::thread; -use super::{Error, Packet, Result}; +use std::time::Duration; + +use super::{Serialize, Deserialize, Error, Packet, Result}; /// A client stub that connects to a server to run rpcs. pub struct Client @@ -150,18 +150,18 @@ impl RpcFutures { } } - fn complete_reply(&mut self, id: u64, reply: Reply) { - if let Some(tx) = self.0.as_mut().unwrap().remove(&id) { - if let Err(e) = tx.send(Ok(reply)) { + fn complete_reply(&mut self, packet: Packet) { + if let Some(tx) = self.0.as_mut().unwrap().remove(&packet.rpc_id) { + if let Err(e) = tx.send(Ok(packet.message)) { info!("Reader: could not complete reply: {:?}", e); } } else { - warn!("RpcFutures: expected sender for id {} but got None!", id); + warn!("RpcFutures: expected sender for id {} but got None!", packet.rpc_id); } } - fn set_error(&mut self, err: bincode::serde::DeserializeError) { - let _ = mem::replace(&mut self.0, Err(err.into())); + fn set_error(&mut self, err: Error) { + let _ = mem::replace(&mut self.0, Err(err)); } fn get_error(&self) -> Error { @@ -198,9 +198,7 @@ fn write(outbound: Receiver<(Request, Sender>)>, message: request, }; debug!("Writer: calling rpc({:?})", id); - if let Err(e) = bincode::serde::serialize_into(&mut stream, - &packet, - bincode::SizeLimit::Infinite) { + if let Err(e) = stream.serialize(&packet) { report_error(&tx, e.into()); // Typically we'd want to notify the client of any Err returned by remove_tx, but in // this case the client already hit an Err, and doesn't need to know about this one, as @@ -208,9 +206,6 @@ fn write(outbound: Receiver<(Request, Sender>)>, let _ = requests.lock().unwrap().remove_tx(id); continue; } - if let Err(e) = stream.flush() { - report_error(&tx, e.into()); - } } fn report_error(tx: &Sender>, e: Error) @@ -232,22 +227,17 @@ fn read(requests: Arc>>, stream: TcpStream) { let mut stream = BufReader::new(stream); loop { - let packet: bincode::serde::DeserializeResult> = - bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); - match packet { - Ok(Packet { - rpc_id: id, - message: reply - }) => { - debug!("Client: received message, id={}", id); - requests.lock().unwrap().complete_reply(id, reply); + match stream.deserialize::>() { + Ok(packet) => { + debug!("Client: received message, id={}", packet.rpc_id); + requests.lock().unwrap().complete_reply(packet); } - Err(err) => { - warn!("Client: reader thread encountered an unexpected error while parsing; \ - returning now. Error: {:?}", - err); - requests.lock().unwrap().set_error(err); - break; + Err(e) => { + warn!("Client: reader thread encountered an unexpected error; returning now. \ + Error: {:?}", + e); + requests.lock().unwrap().set_error(e); + return; } } } diff --git a/tarpc/src/protocol/mod.rs b/tarpc/src/protocol/mod.rs index d2137c6..f663402 100644 --- a/tarpc/src/protocol/mod.rs +++ b/tarpc/src/protocol/mod.rs @@ -6,8 +6,10 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use bincode; -use std::io; +use bincode::{self, SizeLimit}; +use bincode::serde::{deserialize_from, serialize_into}; +use serde; +use std::io::{self, Read, Write}; use std::convert; use std::sync::Arc; @@ -60,6 +62,25 @@ struct Packet { message: T, } +trait Deserialize: Read + Sized { + fn deserialize(&mut self) -> Result { + deserialize_from(self, SizeLimit::Infinite) + .map_err(Error::from) + } +} + +impl Deserialize for R {} + +trait Serialize: Write + Sized { + fn serialize(&mut self, value: &T) -> Result<()> { + try!(serialize_into(self, value, SizeLimit::Infinite)); + try!(self.flush()); + Ok(()) + } +} + +impl Serialize for W {} + #[cfg(test)] mod test { extern crate env_logger; diff --git a/tarpc/src/protocol/server.rs b/tarpc/src/protocol/server.rs index 1b65faa..ebe95dd 100644 --- a/tarpc/src/protocol/server.rs +++ b/tarpc/src/protocol/server.rs @@ -1,15 +1,14 @@ -use bincode; use serde; use scoped_pool::Pool; use std::fmt; -use std::io::{self, BufReader, BufWriter, Write}; +use std::io::{self, BufReader, BufWriter}; use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; use std::sync::{Condvar, Mutex}; use std::sync::mpsc::{Receiver, Sender, TryRecvError, channel}; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use std::thread::{self, JoinHandle}; -use super::{Packet, Result}; +use super::{Deserialize, Error, Packet, Result, Serialize}; struct ConnectionHandler<'a, S> where S: Serve @@ -44,7 +43,7 @@ impl<'a, S> ConnectionHandler<'a, S> where S: Serve { let (tx, rx) = channel(); scope.execute(|| Self::write(rx, write_stream, inflight_rpcs)); loop { - match bincode::serde::deserialize_from(read_stream, bincode::SizeLimit::Infinite) { + match read_stream.deserialize() { Ok(Packet { rpc_id, message, }) => { inflight_rpcs.increment(); let tx = tx.clone(); @@ -61,8 +60,7 @@ impl<'a, S> ConnectionHandler<'a, S> where S: Serve { break; } } - Err(bincode::serde::DeserializeError::IoError(ref err)) - if Self::timed_out(err.kind()) => { + Err(Error::Io(ref err)) if Self::timed_out(err.kind()) => { if !shutdown.load(Ordering::SeqCst) { info!("ConnectionHandler: read timed out ({:?}). Server not \ shutdown, so retrying read.", @@ -103,16 +101,8 @@ impl<'a, S> ConnectionHandler<'a, S> where S: Serve { return; } Ok(reply_packet) => { - if let Err(e) = - bincode::serde::serialize_into(stream, - &reply_packet, - bincode::SizeLimit::Infinite) { - warn!("Writer: failed to write reply to Client: {:?}", - e); - } - if let Err(e) = stream.flush() { - warn!("Writer: failed to flush reply to Client: {:?}", - e); + if let Err(e) = stream.serialize(&reply_packet) { + warn!("Writer: failed to write reply to Client: {:?}", e); } inflight_rpcs.decrement(); }