Merge branch 'master' of ssh://git.adam-wright.net:10022/shaladdle/adamrpc-rs

This commit is contained in:
Tim Kuehn
2016-01-08 20:25:00 -08:00

View File

@@ -1,26 +1,33 @@
#![feature(const_fn)]
#![feature(custom_derive, plugin)]
#![plugin(serde_macros)]
extern crate serde;
extern crate serde_json;
use std::io;
use serde::Deserialize;
use std::fmt;
use std::io::{self, Read};
use std::convert;
use std::collections::HashMap;
use std::error::Error as StdError;
use std::net::{
self,
TcpListener,
TcpStream,
};
use std::sync::{
self,
Mutex,
Arc,
};
use std::sync::mpsc::{
channel,
sync_channel,
Sender,
SyncSender,
Receiver,
};
use std::time;
use std::thread;
#[derive(Debug)]
@@ -34,7 +41,10 @@ pub enum Error {
impl convert::From<serde_json::Error> for Error {
fn from(err: serde_json::Error) -> Error {
Error::Json(err)
match err {
serde_json::Error::IoError(err) => Error::Io(err),
err => Error::Json(err),
}
}
}
@@ -45,32 +55,49 @@ impl convert::From<io::Error> for Error {
}
impl<T> convert::From<sync::mpsc::SendError<T>> for Error {
fn from(err: sync::mpsc::SendError<T>) -> Error {
fn from(_: sync::mpsc::SendError<T>) -> Error {
Error::Sender
}
}
pub type Result<T> = std::result::Result<T, Error>;
pub fn handle_conn<F, Request, Response>(mut conn: TcpStream, f: F) -> Result<()>
where Request: serde::de::Deserialize,
Response: serde::ser::Serialize,
F: Fn(&Request) -> Result<Response>
pub fn handle_conn<F, Request, Reply>(mut stream: TcpStream, f: Arc<F>) -> Result<()>
where Request: fmt::Debug + serde::de::Deserialize,
Reply: fmt::Debug + serde::ser::Serialize,
F: Serve<Request, Reply>
{
let request: Request = try!(serde_json::from_reader(&mut conn));
let response = try!(f(&request));
try!(serde_json::to_writer(&mut conn, &response));
let read_stream = try!(stream.try_clone());
let mut de = serde_json::Deserializer::new(read_stream.bytes());
loop {
println!("read");
let request_packet: Packet<Request> = try!(Packet::deserialize(&mut de));
match request_packet {
Packet::Shutdown => break,
Packet::Message(id, message) => {
let reply = try!(f.serve(&message));
let reply_packet = Packet::Message(id, reply);
println!("write");
try!(serde_json::to_writer(&mut stream, &reply_packet));
},
}
}
Ok(())
}
pub fn serve(listener: TcpListener) -> Error {
pub fn serve<F, Request, Reply>(listener: TcpListener, f: Arc<F>) -> Error
where Request: fmt::Debug + serde::de::Deserialize,
Reply: fmt::Debug + serde::ser::Serialize,
F: 'static + Serve<Request, Reply>,
{
for conn in listener.incoming() {
let conn = match conn {
Err(err) => return convert::From::from(err),
Ok(c) => c,
};
let f = f.clone();
thread::spawn(move || {
if let Err(err) = handle_conn(conn, |a| handle_impl(a)) {
if let Err(err) = handle_conn(conn, f) {
println!("error handling connection: {:?}", err);
}
});
@@ -78,165 +105,238 @@ pub fn serve(listener: TcpListener) -> Error {
Error::Impossible
}
#[derive(Serialize, Deserialize)]
struct Packet<T> {
seq: u64,
message: T,
pub trait Serve<Request, Reply>: Send + Sync {
fn serve(&self, request: &Request) -> io::Result<Reply>;
}
// Generated code
#[derive(Serialize, Deserialize)]
struct A;
#[derive(Serialize, Deserialize)]
struct B;
fn handle_impl(a: &A) -> Result<B> {
Ok(B)
#[derive(Debug, Clone, Serialize, Deserialize)]
enum Packet<T> {
Message(u64, T),
Shutdown,
}
struct InnerClient {
stream: TcpStream,
seq: u64,
outstanding_messages: HashMap<u64, Sender<()>>,
}
struct RPC<Request, Reply> {
struct Handle<T> {
id: u64,
request: Request,
reply: Sender<Reply>,
}
struct RequestHandle<Request> {
id: u64,
request: Request,
}
struct ReplyHandle<Reply> {
id: u64,
reply: Sender<Reply>,
}
struct ReplyPacket<Reply> {
id: u64,
message: Reply,
}
fn message_reader<Reply>(
mut stream: TcpStream,
replies: Sender<ReceiverMessage<Reply>>) -> Result<()>
where Reply: serde::de::Deserialize
{
loop {
let id = try!(serde_json::from_reader(&mut stream));
let reply_message = try!(serde_json::from_reader(&mut stream));
let packet = ReplyPacket{
id: id,
message: reply_message,
};
try!(replies.send(ReceiverMessage::Packet(packet)));
}
sender: Sender<T>,
}
enum ReceiverMessage<Reply> {
Handle(ReplyHandle<Reply>),
Packet(ReplyPacket<Reply>),
Handle(Handle<Reply>),
Packet(Packet<Reply>),
Shutdown,
}
fn receiver<Reply>(messages: Receiver<ReceiverMessage<Reply>>) -> Result<()>
{
let mut ready_handles: HashMap<u64, ReplyHandle<Reply>> = HashMap::new();
let mut ready_packets: HashMap<u64, ReplyPacket<Reply>> = HashMap::new();
fn receiver<Reply>(messages: Receiver<ReceiverMessage<Reply>>) -> Result<()> {
let mut ready_handles: HashMap<u64, Handle<Reply>> = HashMap::new();
for message in messages.into_iter() {
match message {
ReceiverMessage::Handle(handle) => {
if let Some(packet) = ready_packets.remove(&handle.id) {
try!(handle.reply.send(packet.message));
} else {
ready_handles.insert(handle.id, handle);
}
ready_handles.insert(handle.id, handle);
},
ReceiverMessage::Packet(packet) => {
if let Some(handle) = ready_handles.remove(&packet.id) {
try!(handle.reply.send(packet.message));
} else {
ready_packets.insert(packet.id, packet);
}
ReceiverMessage::Packet(Packet::Shutdown) => break,
ReceiverMessage::Packet(Packet::Message(id, message)) => {
let handle = ready_handles.remove(&id).unwrap();
try!(handle.sender.send(message));
}
ReceiverMessage::Shutdown => break,
}
}
Ok(())
}
fn message_writer<Request>(
mut stream: TcpStream,
requests: Receiver<RequestHandle<Request>>) -> Result<()>
where Request: serde::ser::Serialize
fn reader<Reply>(stream: TcpStream, tx: SyncSender<ReceiverMessage<Reply>>)
where Reply: serde::Deserialize
{
for request_handle in requests.into_iter() {
try!(serde_json::to_writer(&mut stream, &request_handle.id));
try!(serde_json::to_writer(&mut stream, &request_handle.request));
use serde_json::Error::SyntaxError;
use serde_json::ErrorCode::EOFWhileParsingValue;
let mut de = serde_json::Deserializer::new(stream.bytes());
loop {
match Packet::deserialize(&mut de) {
Ok(packet) =>{
println!("send!");
tx.send(ReceiverMessage::Packet(packet)).unwrap();
},
// TODO: This shutdown logic is janky.. What's the right way to do this?
Err(SyntaxError(EOFWhileParsingValue, _, _)) => break,
Err(err) => panic!("unexpected error while parsing!: {:?}", err),
}
}
Ok(())
}
struct Client<Request, Reply> {
next_id: Mutex<u64>,
writer_tx: Sender<RequestHandle<Request>>,
handles_tx: Sender<ReceiverMessage<Reply>>,
fn increment(cur_id: &mut u64) -> u64 {
let id = *cur_id;
*cur_id += 1;
id
}
impl<Request, Reply> Client<Request, Reply>
where Request: serde::ser::Serialize + Clone + Send + 'static,
Reply: serde::de::Deserialize + Send + 'static
struct SyncedClientState<Reply> {
next_id: u64,
stream: TcpStream,
handles_tx: SyncSender<ReceiverMessage<Reply>>,
}
pub struct Client<Reply> {
synced_state: Mutex<SyncedClientState<Reply>>,
reader_guard: thread::JoinHandle<()>,
}
impl<Reply> Client<Reply>
where Reply: serde::de::Deserialize + Send + 'static
{
fn new(stream: TcpStream) -> Result<Self> {
let write_stream = try!(stream.try_clone());
let (requests_tx, requests_rx) = channel();
let (handles_tx, receiver_rx) = channel();
let replies_tx = handles_tx.clone();
thread::spawn(move || message_writer(write_stream, requests_rx).unwrap());
thread::spawn(move || message_reader(stream, replies_tx).unwrap());
thread::spawn(move || receiver(receiver_rx).unwrap());
pub fn new(stream: TcpStream) -> Result<Self> {
let (handles_tx, receiver_rx) = sync_channel(0);
let read_stream = try!(stream.try_clone());
try!(read_stream.set_read_timeout(Some(time::Duration::from_millis(50))));
let reader_handles_tx = handles_tx.clone();
let guard = thread::spawn(move || reader(read_stream, reader_handles_tx));
thread::spawn(move || receiver(receiver_rx));
Ok(Client{
next_id: Mutex::new(0),
writer_tx: requests_tx,
handles_tx: handles_tx,
synced_state: Mutex::new(SyncedClientState{
next_id: 0,
stream: stream,
handles_tx: handles_tx,
}),
reader_guard: guard,
})
}
fn get_next_id(&self) -> u64 {
let mut id = self.next_id.lock().unwrap();
*id += 1;
*id
pub fn rpc<Request>(&self, request: &Request) -> Result<Reply>
where Request: serde::ser::Serialize + Clone + Send + 'static
{
let (tx, rx) = channel();
let mut state = self.synced_state.lock().unwrap();
let id = increment(&mut state.next_id);
try!(state.handles_tx.send(ReceiverMessage::Handle(Handle{
id: id,
sender: tx,
})));
let packet = Packet::Message(id, request.clone());
try!(serde_json::to_writer(&mut state.stream, &packet));
Ok(rx.recv().unwrap())
}
fn rpc(&self, request: &Request) -> Result<Reply> {
let (tx, rx) = channel();
let id = self.get_next_id();
try!(self.writer_tx.send(RequestHandle{
id: id,
request: request.clone(),
}));
try!(self.handles_tx.send(ReceiverMessage::Handle(ReplyHandle{
id: id,
reply: tx,
})));
Ok(rx.recv().unwrap())
pub fn join<Request: serde::Serialize>(self) -> Result<()> {
let mut state = self.synced_state.lock().unwrap();
let packet: Packet<Request> = Packet::Shutdown;
try!(serde_json::to_writer(&mut state.stream, &packet));
try!(state.stream.shutdown(net::Shutdown::Both));
self.reader_guard.join().unwrap();
Ok(())
}
}
/*
#[cfg(test)]
mod test {
use adamrpc::*;
use super::*;
use std::io;
use std::net::{TcpStream, TcpListener, SocketAddr};
use std::str::FromStr;
use std::sync::{Arc, Mutex, Barrier};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
const port: AtomicUsize = AtomicUsize::new(10000);
fn pair() -> (TcpStream, TcpListener) {
let addr = format!("127.0.0.1:{}", port.fetch_add(1, Ordering::SeqCst));
println!("what the fuck {}", &addr);
// Do this one first so that we don't get connection refused :)
let listener = TcpListener::bind(&*addr).unwrap();
(TcpStream::connect(&*addr).unwrap(), listener)
}
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
enum Request {
Increment
}
#[derive(Debug, PartialEq, Serialize, Deserialize)]
enum Reply {
Increment(u64)
}
struct Server {
counter: Mutex<u64>,
}
impl Serve<Request, Reply> for Server {
fn serve(&self, _: &Request) -> io::Result<Reply> {
let mut counter = self.counter.lock().unwrap();
let reply = Reply::Increment(*counter);
*counter += 1;
Ok(reply)
}
}
impl Server {
fn new() -> Server {
Server{counter: Mutex::new(0)}
}
fn count(&self) -> u64 {
*self.counter.lock().unwrap()
}
}
#[test]
fn test() {
let listener = TcpListener::bind("127.0.0.1:9000").expect("listener");
let server =
let stream = TcpStream::connect
let (client_stream, server_streams) = pair();
let server = Arc::new(Server::new());
let thread_server = server.clone();
let guard = thread::spawn(move || serve(server_streams, thread_server));
let client = Client::new(client_stream).unwrap();
assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap());
assert_eq!(1, server.count());
assert_eq!(Reply::Increment(1), client.rpc(&Request::Increment).unwrap());
assert_eq!(2, server.count());
client.join::<Request>().unwrap();
guard.join();
}
struct BarrierServer {
barrier: Barrier,
inner: Server,
}
impl Serve<Request, Reply> for BarrierServer {
fn serve(&self, request: &Request) -> io::Result<Reply> {
self.barrier.wait();
let reply = try!(self.inner.serve(request));
Ok(reply)
}
}
impl BarrierServer {
fn new(n: usize) -> BarrierServer {
BarrierServer{barrier: Barrier::new(n), inner: Server::new()}
}
fn count(&self) -> u64 {
self.inner.count()
}
}
#[test]
fn test_concurrent() {
let (client_stream, server_streams) = pair();
let server = Arc::new(BarrierServer::new(10));
let thread_server = server.clone();
let guard = thread::spawn(move || serve(server_streams, thread_server));
let client: Arc<Client<Reply>> = Arc::new(Client::new(client_stream).unwrap());
let mut join_handles = vec![];
for _ in 0..10 {
let my_client = client.clone();
join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap()));
}
for handle in join_handles.into_iter() {
handle.join();
}
assert_eq!(10, server.count());
let client = match Arc::try_unwrap(client) {
Err(_) => panic!("couldn't unwrap arc"),
Ok(c) => c,
};
client.join::<Request>().unwrap();
guard.join();
}
}
*/