diff --git a/src/lib.rs b/src/lib.rs index 7d169cf..0571d95 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,26 +1,33 @@ +#![feature(const_fn)] #![feature(custom_derive, plugin)] #![plugin(serde_macros)] extern crate serde; extern crate serde_json; -use std::io; +use serde::Deserialize; +use std::fmt; +use std::io::{self, Read}; use std::convert; use std::collections::HashMap; -use std::error::Error as StdError; use std::net::{ + self, TcpListener, TcpStream, }; use std::sync::{ self, Mutex, + Arc, }; use std::sync::mpsc::{ channel, + sync_channel, Sender, + SyncSender, Receiver, }; +use std::time; use std::thread; #[derive(Debug)] @@ -34,7 +41,10 @@ pub enum Error { impl convert::From for Error { fn from(err: serde_json::Error) -> Error { - Error::Json(err) + match err { + serde_json::Error::IoError(err) => Error::Io(err), + err => Error::Json(err), + } } } @@ -45,32 +55,49 @@ impl convert::From for Error { } impl convert::From> for Error { - fn from(err: sync::mpsc::SendError) -> Error { + fn from(_: sync::mpsc::SendError) -> Error { Error::Sender } } pub type Result = std::result::Result; -pub fn handle_conn(mut conn: TcpStream, f: F) -> Result<()> - where Request: serde::de::Deserialize, - Response: serde::ser::Serialize, - F: Fn(&Request) -> Result +pub fn handle_conn(mut stream: TcpStream, f: Arc) -> Result<()> + where Request: fmt::Debug + serde::de::Deserialize, + Reply: fmt::Debug + serde::ser::Serialize, + F: Serve { - let request: Request = try!(serde_json::from_reader(&mut conn)); - let response = try!(f(&request)); - try!(serde_json::to_writer(&mut conn, &response)); + let read_stream = try!(stream.try_clone()); + let mut de = serde_json::Deserializer::new(read_stream.bytes()); + loop { + println!("read"); + let request_packet: Packet = try!(Packet::deserialize(&mut de)); + match request_packet { + Packet::Shutdown => break, + Packet::Message(id, message) => { + let reply = try!(f.serve(&message)); + let reply_packet = Packet::Message(id, reply); + println!("write"); + try!(serde_json::to_writer(&mut stream, &reply_packet)); + }, + } + } Ok(()) } -pub fn serve(listener: TcpListener) -> Error { +pub fn serve(listener: TcpListener, f: Arc) -> Error + where Request: fmt::Debug + serde::de::Deserialize, + Reply: fmt::Debug + serde::ser::Serialize, + F: 'static + Serve, +{ for conn in listener.incoming() { let conn = match conn { Err(err) => return convert::From::from(err), Ok(c) => c, }; + let f = f.clone(); thread::spawn(move || { - if let Err(err) = handle_conn(conn, |a| handle_impl(a)) { + if let Err(err) = handle_conn(conn, f) { println!("error handling connection: {:?}", err); } }); @@ -78,165 +105,238 @@ pub fn serve(listener: TcpListener) -> Error { Error::Impossible } -#[derive(Serialize, Deserialize)] -struct Packet { - seq: u64, - message: T, +pub trait Serve: Send + Sync { + fn serve(&self, request: &Request) -> io::Result; } -// Generated code - -#[derive(Serialize, Deserialize)] -struct A; -#[derive(Serialize, Deserialize)] -struct B; - -fn handle_impl(a: &A) -> Result { - Ok(B) +#[derive(Debug, Clone, Serialize, Deserialize)] +enum Packet { + Message(u64, T), + Shutdown, } -struct InnerClient { - stream: TcpStream, - seq: u64, - outstanding_messages: HashMap>, -} - -struct RPC { +struct Handle { id: u64, - request: Request, - reply: Sender, -} - -struct RequestHandle { - id: u64, - request: Request, -} - -struct ReplyHandle { - id: u64, - reply: Sender, -} - -struct ReplyPacket { - id: u64, - message: Reply, -} - -fn message_reader( - mut stream: TcpStream, - replies: Sender>) -> Result<()> - where Reply: serde::de::Deserialize -{ - loop { - let id = try!(serde_json::from_reader(&mut stream)); - let reply_message = try!(serde_json::from_reader(&mut stream)); - let packet = ReplyPacket{ - id: id, - message: reply_message, - }; - try!(replies.send(ReceiverMessage::Packet(packet))); - } + sender: Sender, } enum ReceiverMessage { - Handle(ReplyHandle), - Packet(ReplyPacket), + Handle(Handle), + Packet(Packet), + Shutdown, } -fn receiver(messages: Receiver>) -> Result<()> -{ - let mut ready_handles: HashMap> = HashMap::new(); - let mut ready_packets: HashMap> = HashMap::new(); +fn receiver(messages: Receiver>) -> Result<()> { + let mut ready_handles: HashMap> = HashMap::new(); for message in messages.into_iter() { match message { ReceiverMessage::Handle(handle) => { - if let Some(packet) = ready_packets.remove(&handle.id) { - try!(handle.reply.send(packet.message)); - } else { - ready_handles.insert(handle.id, handle); - } + ready_handles.insert(handle.id, handle); }, - ReceiverMessage::Packet(packet) => { - if let Some(handle) = ready_handles.remove(&packet.id) { - try!(handle.reply.send(packet.message)); - } else { - ready_packets.insert(packet.id, packet); - } + ReceiverMessage::Packet(Packet::Shutdown) => break, + ReceiverMessage::Packet(Packet::Message(id, message)) => { + let handle = ready_handles.remove(&id).unwrap(); + try!(handle.sender.send(message)); } - + ReceiverMessage::Shutdown => break, } } Ok(()) } -fn message_writer( - mut stream: TcpStream, - requests: Receiver>) -> Result<()> - where Request: serde::ser::Serialize +fn reader(stream: TcpStream, tx: SyncSender>) + where Reply: serde::Deserialize { - for request_handle in requests.into_iter() { - try!(serde_json::to_writer(&mut stream, &request_handle.id)); - try!(serde_json::to_writer(&mut stream, &request_handle.request)); + use serde_json::Error::SyntaxError; + use serde_json::ErrorCode::EOFWhileParsingValue; + let mut de = serde_json::Deserializer::new(stream.bytes()); + loop { + match Packet::deserialize(&mut de) { + Ok(packet) =>{ + println!("send!"); + tx.send(ReceiverMessage::Packet(packet)).unwrap(); + }, + // TODO: This shutdown logic is janky.. What's the right way to do this? + Err(SyntaxError(EOFWhileParsingValue, _, _)) => break, + Err(err) => panic!("unexpected error while parsing!: {:?}", err), + } } - Ok(()) } -struct Client { - next_id: Mutex, - writer_tx: Sender>, - handles_tx: Sender>, +fn increment(cur_id: &mut u64) -> u64 { + let id = *cur_id; + *cur_id += 1; + id } -impl Client - where Request: serde::ser::Serialize + Clone + Send + 'static, - Reply: serde::de::Deserialize + Send + 'static +struct SyncedClientState { + next_id: u64, + stream: TcpStream, + handles_tx: SyncSender>, +} + +pub struct Client { + synced_state: Mutex>, + reader_guard: thread::JoinHandle<()>, +} + +impl Client + where Reply: serde::de::Deserialize + Send + 'static { - fn new(stream: TcpStream) -> Result { - let write_stream = try!(stream.try_clone()); - let (requests_tx, requests_rx) = channel(); - let (handles_tx, receiver_rx) = channel(); - let replies_tx = handles_tx.clone(); - thread::spawn(move || message_writer(write_stream, requests_rx).unwrap()); - thread::spawn(move || message_reader(stream, replies_tx).unwrap()); - thread::spawn(move || receiver(receiver_rx).unwrap()); + pub fn new(stream: TcpStream) -> Result { + let (handles_tx, receiver_rx) = sync_channel(0); + let read_stream = try!(stream.try_clone()); + try!(read_stream.set_read_timeout(Some(time::Duration::from_millis(50)))); + let reader_handles_tx = handles_tx.clone(); + let guard = thread::spawn(move || reader(read_stream, reader_handles_tx)); + thread::spawn(move || receiver(receiver_rx)); Ok(Client{ - next_id: Mutex::new(0), - writer_tx: requests_tx, - handles_tx: handles_tx, + synced_state: Mutex::new(SyncedClientState{ + next_id: 0, + stream: stream, + handles_tx: handles_tx, + }), + reader_guard: guard, }) } - fn get_next_id(&self) -> u64 { - let mut id = self.next_id.lock().unwrap(); - *id += 1; - *id + pub fn rpc(&self, request: &Request) -> Result + where Request: serde::ser::Serialize + Clone + Send + 'static + { + let (tx, rx) = channel(); + let mut state = self.synced_state.lock().unwrap(); + let id = increment(&mut state.next_id); + try!(state.handles_tx.send(ReceiverMessage::Handle(Handle{ + id: id, + sender: tx, + }))); + let packet = Packet::Message(id, request.clone()); + try!(serde_json::to_writer(&mut state.stream, &packet)); + Ok(rx.recv().unwrap()) } - fn rpc(&self, request: &Request) -> Result { - let (tx, rx) = channel(); - let id = self.get_next_id(); - try!(self.writer_tx.send(RequestHandle{ - id: id, - request: request.clone(), - })); - try!(self.handles_tx.send(ReceiverMessage::Handle(ReplyHandle{ - id: id, - reply: tx, - }))); - Ok(rx.recv().unwrap()) + pub fn join(self) -> Result<()> { + let mut state = self.synced_state.lock().unwrap(); + let packet: Packet = Packet::Shutdown; + try!(serde_json::to_writer(&mut state.stream, &packet)); + try!(state.stream.shutdown(net::Shutdown::Both)); + self.reader_guard.join().unwrap(); + Ok(()) } } -/* #[cfg(test)] mod test { - use adamrpc::*; + use super::*; + use std::io; + use std::net::{TcpStream, TcpListener, SocketAddr}; + use std::str::FromStr; + use std::sync::{Arc, Mutex, Barrier}; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::thread; + + const port: AtomicUsize = AtomicUsize::new(10000); + + fn pair() -> (TcpStream, TcpListener) { + let addr = format!("127.0.0.1:{}", port.fetch_add(1, Ordering::SeqCst)); + println!("what the fuck {}", &addr); + // Do this one first so that we don't get connection refused :) + let listener = TcpListener::bind(&*addr).unwrap(); + (TcpStream::connect(&*addr).unwrap(), listener) + } + + #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] + enum Request { + Increment + } + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + enum Reply { + Increment(u64) + } + + struct Server { + counter: Mutex, + } + + impl Serve for Server { + fn serve(&self, _: &Request) -> io::Result { + let mut counter = self.counter.lock().unwrap(); + let reply = Reply::Increment(*counter); + *counter += 1; + Ok(reply) + } + } + + impl Server { + fn new() -> Server { + Server{counter: Mutex::new(0)} + } + + fn count(&self) -> u64 { + *self.counter.lock().unwrap() + } + } #[test] fn test() { - let listener = TcpListener::bind("127.0.0.1:9000").expect("listener"); - let server = - let stream = TcpStream::connect + let (client_stream, server_streams) = pair(); + let server = Arc::new(Server::new()); + let thread_server = server.clone(); + let guard = thread::spawn(move || serve(server_streams, thread_server)); + let client = Client::new(client_stream).unwrap(); + assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); + assert_eq!(1, server.count()); + assert_eq!(Reply::Increment(1), client.rpc(&Request::Increment).unwrap()); + assert_eq!(2, server.count()); + client.join::().unwrap(); + guard.join(); + } + + struct BarrierServer { + barrier: Barrier, + inner: Server, + } + + impl Serve for BarrierServer { + fn serve(&self, request: &Request) -> io::Result { + self.barrier.wait(); + let reply = try!(self.inner.serve(request)); + Ok(reply) + } + } + + impl BarrierServer { + fn new(n: usize) -> BarrierServer { + BarrierServer{barrier: Barrier::new(n), inner: Server::new()} + } + + fn count(&self) -> u64 { + self.inner.count() + } + } + + #[test] + fn test_concurrent() { + let (client_stream, server_streams) = pair(); + let server = Arc::new(BarrierServer::new(10)); + let thread_server = server.clone(); + let guard = thread::spawn(move || serve(server_streams, thread_server)); + let client: Arc> = Arc::new(Client::new(client_stream).unwrap()); + let mut join_handles = vec![]; + for _ in 0..10 { + let my_client = client.clone(); + join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap())); + } + for handle in join_handles.into_iter() { + handle.join(); + } + assert_eq!(10, server.count()); + let client = match Arc::try_unwrap(client) { + Err(_) => panic!("couldn't unwrap arc"), + Ok(c) => c, + }; + client.join::().unwrap(); + guard.join(); } } -*/