mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-08 04:21:03 +01:00
A number of improvements
- Support non clonable Serve objects by wrapping in an Arc - Support multiple RPCs per connection - Support cleanish shutdown
This commit is contained in:
119
src/lib.rs
119
src/lib.rs
@@ -10,12 +10,14 @@ use std::io::{self, Read};
|
||||
use std::convert;
|
||||
use std::collections::HashMap;
|
||||
use std::net::{
|
||||
self,
|
||||
TcpListener,
|
||||
TcpStream,
|
||||
};
|
||||
use std::sync::{
|
||||
self,
|
||||
Mutex,
|
||||
Arc,
|
||||
};
|
||||
use std::sync::mpsc::{
|
||||
channel,
|
||||
@@ -24,6 +26,7 @@ use std::sync::mpsc::{
|
||||
SyncSender,
|
||||
Receiver,
|
||||
};
|
||||
use std::time;
|
||||
use std::thread;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -58,26 +61,30 @@ impl<T> convert::From<sync::mpsc::SendError<T>> for Error {
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
pub fn handle_conn<F, Request, Reply>(
|
||||
mut stream: TcpStream,
|
||||
f: F) -> Result<()>
|
||||
pub fn handle_conn<F, Request, Reply>(mut stream: TcpStream, f: Arc<F>) -> Result<()>
|
||||
where Request: fmt::Debug + serde::de::Deserialize,
|
||||
Reply: fmt::Debug + serde::ser::Serialize,
|
||||
F: 'static + Serve<Request, Reply>
|
||||
F: Serve<Request, Reply>
|
||||
{
|
||||
let read_stream = try!(stream.try_clone());
|
||||
let mut de = serde_json::Deserializer::new(read_stream.bytes());
|
||||
let request_packet: Packet<Request> = try!(Packet::deserialize(&mut de));
|
||||
let reply = try!(f.serve(&request_packet.message));
|
||||
let reply_packet = Packet{
|
||||
id: request_packet.id,
|
||||
message: reply,
|
||||
};
|
||||
try!(serde_json::to_writer(&mut stream, &reply_packet));
|
||||
loop {
|
||||
println!("read");
|
||||
let request_packet: Packet<Request> = try!(Packet::deserialize(&mut de));
|
||||
match request_packet {
|
||||
Packet::Shutdown => break,
|
||||
Packet::Message(id, message) => {
|
||||
let reply = try!(f.serve(&message));
|
||||
let reply_packet = Packet::Message(id, reply);
|
||||
println!("write");
|
||||
try!(serde_json::to_writer(&mut stream, &reply_packet));
|
||||
},
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn serve<F, Request, Reply>(listener: TcpListener, f: F) -> Error
|
||||
pub fn serve<F, Request, Reply>(listener: TcpListener, f: Arc<F>) -> Error
|
||||
where Request: fmt::Debug + serde::de::Deserialize,
|
||||
Reply: fmt::Debug + serde::ser::Serialize,
|
||||
F: 'static + Serve<Request, Reply>,
|
||||
@@ -97,14 +104,14 @@ pub fn serve<F, Request, Reply>(listener: TcpListener, f: F) -> Error
|
||||
Error::Impossible
|
||||
}
|
||||
|
||||
pub trait Serve<Request, Reply> : Sync + Send + Clone {
|
||||
pub trait Serve<Request, Reply>: Send + Sync {
|
||||
fn serve(&self, request: &Request) -> io::Result<Reply>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Packet<T> {
|
||||
id: u64,
|
||||
message: T,
|
||||
enum Packet<T> {
|
||||
Message(u64, T),
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
struct Handle<T> {
|
||||
@@ -115,6 +122,7 @@ struct Handle<T> {
|
||||
enum ReceiverMessage<Reply> {
|
||||
Handle(Handle<Reply>),
|
||||
Packet(Packet<Reply>),
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
fn receiver<Reply>(messages: Receiver<ReceiverMessage<Reply>>) -> Result<()> {
|
||||
@@ -124,25 +132,32 @@ fn receiver<Reply>(messages: Receiver<ReceiverMessage<Reply>>) -> Result<()> {
|
||||
ReceiverMessage::Handle(handle) => {
|
||||
ready_handles.insert(handle.id, handle);
|
||||
},
|
||||
ReceiverMessage::Packet(packet) => {
|
||||
let handle = ready_handles.remove(&packet.id).unwrap();
|
||||
try!(handle.sender.send(packet.message));
|
||||
ReceiverMessage::Packet(Packet::Shutdown) => break,
|
||||
ReceiverMessage::Packet(Packet::Message(id, message)) => {
|
||||
let handle = ready_handles.remove(&id).unwrap();
|
||||
try!(handle.sender.send(message));
|
||||
}
|
||||
ReceiverMessage::Shutdown => break,
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn reader<T, F>(mut stream: TcpStream, decode: F, tx: SyncSender<T>)
|
||||
where F: Send + 'static + Fn(&mut TcpStream) -> Result<T>,
|
||||
T: Send + 'static
|
||||
fn reader<Reply>(stream: TcpStream, tx: SyncSender<ReceiverMessage<Reply>>)
|
||||
where Reply: serde::Deserialize
|
||||
{
|
||||
use serde_json::Error::SyntaxError;
|
||||
use serde_json::ErrorCode::EOFWhileParsingValue;
|
||||
let mut de = serde_json::Deserializer::new(stream.bytes());
|
||||
loop {
|
||||
match decode(&mut stream) {
|
||||
Ok(t) => tx.send(t).unwrap(),
|
||||
Err(Error::Json(SyntaxError(EOFWhileParsingValue, _, _))) => break,
|
||||
match Packet::deserialize(&mut de) {
|
||||
Ok(packet) =>{
|
||||
println!("send!");
|
||||
tx.send(ReceiverMessage::Packet(packet)).unwrap();
|
||||
},
|
||||
// TODO: This shutdown logic is janky.. What's the right way to do this?
|
||||
Err(SyntaxError(EOFWhileParsingValue, _, _)) => break,
|
||||
Err(SyntaxError(ExpectedValue, _, _)) => break,
|
||||
Err(err) => panic!("unexpected error while parsing!: {:?}", err),
|
||||
}
|
||||
}
|
||||
@@ -170,13 +185,10 @@ impl<Reply> Client<Reply>
|
||||
{
|
||||
pub fn new(stream: TcpStream) -> Result<Self> {
|
||||
let (handles_tx, receiver_rx) = sync_channel(0);
|
||||
let decode = |mut stream: &mut TcpStream| {
|
||||
let packet = try!(serde_json::from_reader(&mut stream));
|
||||
Ok(ReceiverMessage::Packet(packet))
|
||||
};
|
||||
let read_stream = try!(stream.try_clone());
|
||||
try!(read_stream.set_read_timeout(Some(time::Duration::from_millis(50))));
|
||||
let reader_handles_tx = handles_tx.clone();
|
||||
let guard = thread::spawn(move || reader(read_stream, decode, reader_handles_tx));
|
||||
let guard = thread::spawn(move || reader(read_stream, reader_handles_tx));
|
||||
thread::spawn(move || receiver(receiver_rx));
|
||||
Ok(Client{
|
||||
synced_state: Mutex::new(SyncedClientState{
|
||||
@@ -198,24 +210,28 @@ impl<Reply> Client<Reply>
|
||||
id: id,
|
||||
sender: tx,
|
||||
})));
|
||||
try!(serde_json::to_writer(&mut state.stream, &Packet{
|
||||
id: id,
|
||||
message: request.clone(),
|
||||
}));
|
||||
let packet = Packet::Message(id, request.clone());
|
||||
try!(serde_json::to_writer(&mut state.stream, &packet));
|
||||
Ok(rx.recv().unwrap())
|
||||
}
|
||||
|
||||
pub fn join(self) {
|
||||
pub fn join<Request: serde::Serialize>(self) -> Result<()> {
|
||||
let mut state = self.synced_state.lock().unwrap();
|
||||
let packet: Packet<Request> = Packet::Shutdown;
|
||||
try!(serde_json::to_writer(&mut state.stream, &packet));
|
||||
try!(state.stream.shutdown(net::Shutdown::Both));
|
||||
self.reader_guard.join().unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use std::thread;
|
||||
use std::net::{TcpStream, TcpListener};
|
||||
use std::io;
|
||||
use std::net::{TcpStream, TcpListener};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
|
||||
fn pair() -> (TcpStream, TcpListener) {
|
||||
let addr = "127.0.0.1:9000";
|
||||
@@ -231,24 +247,39 @@ mod test {
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
enum Reply {
|
||||
Increment
|
||||
Increment(u64)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Server;
|
||||
struct Server {
|
||||
counter: Mutex<u64>,
|
||||
}
|
||||
|
||||
impl Serve<Request, Reply> for Server {
|
||||
fn serve(&self, _: &Request) -> io::Result<Reply> {
|
||||
Ok(Reply::Increment)
|
||||
let mut counter = self.counter.lock().unwrap();
|
||||
let reply = Reply::Increment(*counter);
|
||||
*counter += 1;
|
||||
Ok(reply)
|
||||
}
|
||||
}
|
||||
|
||||
impl Server {
|
||||
fn count(&self) -> u64 {
|
||||
*self.counter.lock().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
let (client_stream, server_streams) = pair();
|
||||
thread::spawn(|| serve(server_streams, Server));
|
||||
let server = Arc::new(Server{counter: Mutex::new(0)});
|
||||
let thread_server = server.clone();
|
||||
thread::spawn(move || serve(server_streams, thread_server));
|
||||
let client = Client::new(client_stream).unwrap();
|
||||
assert_eq!(Reply::Increment, client.rpc(&Request::Increment).unwrap());
|
||||
client.join();
|
||||
assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap());
|
||||
assert_eq!(1, server.count());
|
||||
assert_eq!(Reply::Increment(1), client.rpc(&Request::Increment).unwrap());
|
||||
assert_eq!(2, server.count());
|
||||
client.join::<Request>().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user