From dbf7113cf38b24b3b44b96b07bc5e77318307101 Mon Sep 17 00:00:00 2001 From: Adam Wright Date: Fri, 8 Jan 2016 00:07:53 -0800 Subject: [PATCH 1/6] WIP, doesn't compile :( --- Cargo.toml | 1 + src/lib.rs | 218 +++++++++++++++++++-------------------- src/multi_tcp/Cargo.toml | 6 ++ src/multi_tcp/src/lib.rs | 135 ++++++++++++++++++++++++ 4 files changed, 246 insertions(+), 114 deletions(-) create mode 100644 src/multi_tcp/Cargo.toml create mode 100644 src/multi_tcp/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 31a32ab..2ff0216 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,4 @@ authors = ["Adam Wright "] serde = "*" serde_json = "*" serde_macros = "*" +multi_tcp = { path = "src/multi_tcp" } diff --git a/src/lib.rs b/src/lib.rs index 7d169cf..ad0343c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,13 @@ #![feature(custom_derive, plugin)] #![plugin(serde_macros)] +extern crate multi_tcp; extern crate serde; extern crate serde_json; use std::io; use std::convert; use std::collections::HashMap; -use std::error::Error as StdError; use std::net::{ TcpListener, TcpStream, @@ -18,7 +18,9 @@ use std::sync::{ }; use std::sync::mpsc::{ channel, + sync_channel, Sender, + SyncSender, Receiver, }; use std::thread; @@ -45,32 +47,40 @@ impl convert::From for Error { } impl convert::From> for Error { - fn from(err: sync::mpsc::SendError) -> Error { + fn from(_: sync::mpsc::SendError) -> Error { Error::Sender } } pub type Result = std::result::Result; -pub fn handle_conn(mut conn: TcpStream, f: F) -> Result<()> +pub fn handle_conn( + mut conn: TcpStream, + f: F) -> Result<()> where Request: serde::de::Deserialize, - Response: serde::ser::Serialize, - F: Fn(&Request) -> Result + Reply: serde::ser::Serialize, + F: 'static + Serve { let request: Request = try!(serde_json::from_reader(&mut conn)); - let response = try!(f(&request)); + let response = try!(f.serve(&request)); try!(serde_json::to_writer(&mut conn, &response)); Ok(()) } -pub fn serve(listener: TcpListener) -> Error { +pub fn serve(listener: TcpListener, f: F) -> Error + where Request: serde::de::Deserialize, + Reply: serde::ser::Serialize, + F: 'static + Serve, +{ for conn in listener.incoming() { let conn = match conn { Err(err) => return convert::From::from(err), Ok(c) => c, }; + println!("received connection"); + let f = f.clone(); thread::spawn(move || { - if let Err(err) = handle_conn(conn, |a| handle_impl(a)) { + if let Err(err) = handle_conn(conn, f) { println!("error handling connection: {:?}", err); } }); @@ -78,130 +88,74 @@ pub fn serve(listener: TcpListener) -> Error { Error::Impossible } +pub trait Serve : Sync + Send + Clone { + fn serve(&self, request: &Request) -> io::Result; +} + #[derive(Serialize, Deserialize)] struct Packet { - seq: u64, + id: u64, message: T, } -// Generated code - -#[derive(Serialize, Deserialize)] -struct A; -#[derive(Serialize, Deserialize)] -struct B; - -fn handle_impl(a: &A) -> Result { - Ok(B) -} - -struct InnerClient { - stream: TcpStream, - seq: u64, - outstanding_messages: HashMap>, -} - -struct RPC { +struct Handle { id: u64, - request: Request, - reply: Sender, -} - -struct RequestHandle { - id: u64, - request: Request, -} - -struct ReplyHandle { - id: u64, - reply: Sender, -} - -struct ReplyPacket { - id: u64, - message: Reply, -} - -fn message_reader( - mut stream: TcpStream, - replies: Sender>) -> Result<()> - where Reply: serde::de::Deserialize -{ - loop { - let id = try!(serde_json::from_reader(&mut stream)); - let reply_message = try!(serde_json::from_reader(&mut stream)); - let packet = ReplyPacket{ - id: id, - message: reply_message, - }; - try!(replies.send(ReceiverMessage::Packet(packet))); - } + sender: Sender, } enum ReceiverMessage { - Handle(ReplyHandle), - Packet(ReplyPacket), + Handle(Handle), + Packet(Packet), } -fn receiver(messages: Receiver>) -> Result<()> -{ - let mut ready_handles: HashMap> = HashMap::new(); - let mut ready_packets: HashMap> = HashMap::new(); +fn receiver(messages: Receiver>) -> Result<()> { + let mut ready_handles: HashMap> = HashMap::new(); for message in messages.into_iter() { match message { ReceiverMessage::Handle(handle) => { - if let Some(packet) = ready_packets.remove(&handle.id) { - try!(handle.reply.send(packet.message)); - } else { - ready_handles.insert(handle.id, handle); - } + ready_handles.insert(handle.id, handle); }, ReceiverMessage::Packet(packet) => { - if let Some(handle) = ready_handles.remove(&packet.id) { - try!(handle.reply.send(packet.message)); - } else { - ready_packets.insert(packet.id, packet); - } + let handle = ready_handles.remove(&packet.id).unwrap(); + try!(handle.sender.send(packet.message)); } - } } Ok(()) } -fn message_writer( - mut stream: TcpStream, - requests: Receiver>) -> Result<()> - where Request: serde::ser::Serialize -{ - for request_handle in requests.into_iter() { - try!(serde_json::to_writer(&mut stream, &request_handle.id)); - try!(serde_json::to_writer(&mut stream, &request_handle.request)); - } - Ok(()) -} - -struct Client { +pub struct Client { next_id: Mutex, - writer_tx: Sender>, - handles_tx: Sender>, + writer: multi_tcp::MultiStream, serde_json::Error>, + handles_tx: SyncSender>, } impl Client where Request: serde::ser::Serialize + Clone + Send + 'static, Reply: serde::de::Deserialize + Send + 'static { - fn new(stream: TcpStream) -> Result { - let write_stream = try!(stream.try_clone()); - let (requests_tx, requests_rx) = channel(); - let (handles_tx, receiver_rx) = channel(); - let replies_tx = handles_tx.clone(); - thread::spawn(move || message_writer(write_stream, requests_rx).unwrap()); - thread::spawn(move || message_reader(stream, replies_tx).unwrap()); - thread::spawn(move || receiver(receiver_rx).unwrap()); + pub fn new(stream: TcpStream) -> Result { + let (handles_tx, receiver_rx) = sync_channel(0); + let writer = multi_tcp::MultiStream::with_sync_sender( + stream, + |stream, packet: &Packet| { + try!(serde_json::to_writer(stream, &packet.id)); + try!(serde_json::to_writer(stream, &packet.message)); + Ok(()) + }, + |stream| { + let id = try!(serde_json::from_reader(stream)); + let reply = try!(serde_json::from_reader(stream)); + Ok(ReceiverMessage::Packet(Packet{ + id: id, + message: reply, + })) + }, + handles_tx.clone()); + thread::spawn(move || receiver(receiver_rx)); Ok(Client{ next_id: Mutex::new(0), - writer_tx: requests_tx, + writer: writer, handles_tx: handles_tx, }) } @@ -212,31 +166,67 @@ impl Client *id } - fn rpc(&self, request: &Request) -> Result { + pub fn rpc(&self, request: &Request) -> Result { let (tx, rx) = channel(); let id = self.get_next_id(); - try!(self.writer_tx.send(RequestHandle{ + println!("indicate that we're weaiting"); + try!(self.handles_tx.send(ReceiverMessage::Handle(Handle{ id: id, - request: request.clone(), - })); - try!(self.handles_tx.send(ReceiverMessage::Handle(ReplyHandle{ - id: id, - reply: tx, + sender: tx, }))); + println!("write the request to the wire"); + try!(self.writer.write(Packet{ + id: id, + message: request.clone(), + })); + println!("wait for the response"); Ok(rx.recv().unwrap()) } } -/* #[cfg(test)] mod test { - use adamrpc::*; + use super::*; + use std::thread; + use std::net::{TcpStream, TcpListener}; + use std::io; + + fn pair() -> (TcpStream, TcpListener) { + let addr = "127.0.0.1:9000"; + // Do this one first so that we don't get connection refused :) + let listener = TcpListener::bind(addr).unwrap(); + (TcpStream::connect(addr).unwrap(), listener) + } + + #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] + enum Request { + Increment + } + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + enum Reply { + Increment + } + + #[derive(Clone)] + struct Server; + + impl Serve for Server { + fn serve(&self, _: &Request) -> io::Result { + Ok(Reply::Increment) + } + } #[test] fn test() { - let listener = TcpListener::bind("127.0.0.1:9000").expect("listener"); - let server = - let stream = TcpStream::connect + let (client_stream, server_streams) = pair(); + println!("starting server!"); + thread::spawn(|| { + serve(server_streams, Server) + }); + println!("making client"); + let client: Client = Client::new(client_stream).unwrap(); + println!("hi there"); + assert_eq!(Reply::Increment, client.rpc(&Request::Increment).unwrap()); } } -*/ diff --git a/src/multi_tcp/Cargo.toml b/src/multi_tcp/Cargo.toml new file mode 100644 index 0000000..2fac3ad --- /dev/null +++ b/src/multi_tcp/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "multi_tcp" +version = "0.1.0" +authors = ["Adam Wright "] + +[dependencies] diff --git a/src/multi_tcp/src/lib.rs b/src/multi_tcp/src/lib.rs new file mode 100644 index 0000000..4bccdde --- /dev/null +++ b/src/multi_tcp/src/lib.rs @@ -0,0 +1,135 @@ +use std::fmt; +use std::net::TcpStream; +use std::thread; +use std::sync::mpsc::{ + channel, + sync_channel, + Sender, + SyncSender, + Receiver, +}; + +fn read(mut stream: TcpStream, decode: F, tx: SyncSender) + where F: Send + 'static + Fn(&mut TcpStream) -> Result, + T: Send + 'static, + E: fmt::Debug + Send + 'static +{ + loop { + let t = decode(&mut stream).unwrap(); + if let Err(_) = tx.send(t) { + break; + } + } +} + +struct SendHelper { + value: T, + result: Sender>, +} + +fn write(mut stream: TcpStream, encode: F) -> Sender> + where F: Send + 'static + Fn(&mut TcpStream, &T) -> Result<(), E>, + T: Send + 'static, + E: Send + 'static +{ + let (tx, rx) = channel(); + thread::spawn(move || { + loop { + let helper: SendHelper = match rx.recv() { + Ok(h) => h, + Err(_) => { + break; + } + }; + helper.result.send(encode(&mut stream, &helper.value)).unwrap(); + } + }); + tx +} + +pub struct MultiStream { + tx: Sender>, +} + +impl MultiStream + where Request: Send + 'static, + E: fmt::Debug + Send + 'static +{ + pub fn new( + stream: TcpStream, + encode: F, + decode: G) -> (Self, Receiver) + where Reply: Send + 'static, + F: Send + 'static + Fn(&mut TcpStream, &Request) -> Result<(), E>, + G: Send + 'static + Fn(&mut TcpStream) -> Result + { + let read_stream = stream.try_clone().unwrap(); + let ms = MultiStream{tx: write(stream, encode)}; + let (reply_tx, reply_rx) = sync_channel(0); + thread::spawn(move || read(read_stream, decode, reply_tx)); + (ms, reply_rx) + } + + pub fn with_sync_sender( + stream: TcpStream, + encode: F, + decode: G, + reply_tx: SyncSender) -> Self + where Reply: Send + 'static, + F: Send + 'static + Fn(&mut TcpStream, &Request) -> Result<(), E>, + G: Send + 'static + Fn(&mut TcpStream) -> Result + { + let read_stream = stream.try_clone().unwrap(); + thread::spawn(move || read(read_stream, decode, reply_tx)); + MultiStream{tx: write(stream, encode)} + } + + + pub fn write(&self, value: Request) -> Result<(), E> { + let my_tx = self.tx.clone(); + let (reply_tx, reply_rx) = channel(); + let helper = SendHelper{ + value: value, + result: reply_tx, + }; + my_tx.send(helper).unwrap(); + reply_rx.recv().unwrap() + } +} + +#[cfg(test)] +mod test { + use super::MultiStream; + use std::net::{TcpStream, TcpListener}; + use std::sync::mpsc::Receiver; + use std::io::{Write, Read}; + + fn pair() -> (TcpStream, Receiver) { + let addr = "127.0.0.1:9000"; + let recv_stream = listen(TcpListener::bind(addr).unwrap()); + (TcpStream::connect(addr).unwrap(), recv_stream) + } + + fn write_byte(stream: &mut TcpStream, v: u8) -> Result<(), ()> { + stream.write(&[v]).unwrap(); + Ok(()) + } + + fn read_byte(stream: &mut TcpStream) -> Result { + let mut buf = [0u8]; + stream.read_exact(&mut buf[..]).unwrap(); + Ok(buf[0]) + } + + #[test] + fn test_thing() { + let (stream, listener) = pair(); + let (ms, reader) : (MultiStream, Receiver) = + MultiStream::new(stream, |s, v| write_byte(s, *v), |s| read_byte(s)); + ms.write(5).expect("writing 5"); + let mut srv_stream = listener.accept().unwrap().0; + assert_eq!(5, read_byte(&mut srv_stream).expect("read 5")); + write_byte(&mut srv_stream, 10).expect("write 10"); + assert_eq!(10, reader.recv().expect("reading 10")); + } +} From 579d3909e54d714fa4a0f0caeb2cc83fe375ea8c Mon Sep 17 00:00:00 2001 From: Adam Wright Date: Fri, 8 Jan 2016 02:39:40 -0800 Subject: [PATCH 2/6] I made le test pass --- src/lib.rs | 64 +++++++++++++++++++--------------------- src/multi_tcp/src/lib.rs | 5 ++-- 2 files changed, 34 insertions(+), 35 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ad0343c..36c14ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,9 @@ extern crate multi_tcp; extern crate serde; extern crate serde_json; -use std::io; +use serde::Deserialize; +use std::fmt; +use std::io::{self, Read}; use std::convert; use std::collections::HashMap; use std::net::{ @@ -36,7 +38,10 @@ pub enum Error { impl convert::From for Error { fn from(err: serde_json::Error) -> Error { - Error::Json(err) + match err { + serde_json::Error::IoError(err) => Error::Io(err), + err => Error::Json(err), + } } } @@ -55,21 +60,27 @@ impl convert::From> for Error { pub type Result = std::result::Result; pub fn handle_conn( - mut conn: TcpStream, + mut stream: TcpStream, f: F) -> Result<()> - where Request: serde::de::Deserialize, - Reply: serde::ser::Serialize, + where Request: fmt::Debug + serde::de::Deserialize, + Reply: fmt::Debug + serde::ser::Serialize, F: 'static + Serve { - let request: Request = try!(serde_json::from_reader(&mut conn)); - let response = try!(f.serve(&request)); - try!(serde_json::to_writer(&mut conn, &response)); + let read_stream = try!(stream.try_clone()); + let mut de = serde_json::Deserializer::new(read_stream.bytes()); + let request_packet: Packet = try!(Packet::deserialize(&mut de)); + let reply = try!(f.serve(&request_packet.message)); + let reply_packet = Packet{ + id: request_packet.id, + message: reply, + }; + try!(serde_json::to_writer(&mut stream, &reply_packet)); Ok(()) } pub fn serve(listener: TcpListener, f: F) -> Error - where Request: serde::de::Deserialize, - Reply: serde::ser::Serialize, + where Request: fmt::Debug + serde::de::Deserialize, + Reply: fmt::Debug + serde::ser::Serialize, F: 'static + Serve, { for conn in listener.incoming() { @@ -77,7 +88,6 @@ pub fn serve(listener: TcpListener, f: F) -> Error Err(err) => return convert::From::from(err), Ok(c) => c, }; - println!("received connection"); let f = f.clone(); thread::spawn(move || { if let Err(err) = handle_conn(conn, f) { @@ -92,7 +102,7 @@ pub trait Serve : Sync + Send + Clone { fn serve(&self, request: &Request) -> io::Result; } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct Packet { id: u64, message: T, @@ -126,7 +136,7 @@ fn receiver(messages: Receiver>) -> Result<()> { pub struct Client { next_id: Mutex, - writer: multi_tcp::MultiStream, serde_json::Error>, + writer: multi_tcp::MultiStream, Error>, handles_tx: SyncSender>, } @@ -136,20 +146,16 @@ impl Client { pub fn new(stream: TcpStream) -> Result { let (handles_tx, receiver_rx) = sync_channel(0); - let writer = multi_tcp::MultiStream::with_sync_sender( + let writer: multi_tcp::MultiStream, Error> + = multi_tcp::MultiStream::with_sync_sender( stream, - |stream, packet: &Packet| { - try!(serde_json::to_writer(stream, &packet.id)); - try!(serde_json::to_writer(stream, &packet.message)); + |stream: &mut TcpStream, packet: &Packet| { + try!(serde_json::to_writer(stream, packet)); Ok(()) }, - |stream| { - let id = try!(serde_json::from_reader(stream)); - let reply = try!(serde_json::from_reader(stream)); - Ok(ReceiverMessage::Packet(Packet{ - id: id, - message: reply, - })) + |mut stream| { + let packet = try!(serde_json::from_reader(&mut stream)); + Ok(ReceiverMessage::Packet(packet)) }, handles_tx.clone()); thread::spawn(move || receiver(receiver_rx)); @@ -169,17 +175,14 @@ impl Client pub fn rpc(&self, request: &Request) -> Result { let (tx, rx) = channel(); let id = self.get_next_id(); - println!("indicate that we're weaiting"); try!(self.handles_tx.send(ReceiverMessage::Handle(Handle{ id: id, sender: tx, }))); - println!("write the request to the wire"); try!(self.writer.write(Packet{ id: id, message: request.clone(), })); - println!("wait for the response"); Ok(rx.recv().unwrap()) } } @@ -220,13 +223,8 @@ mod test { #[test] fn test() { let (client_stream, server_streams) = pair(); - println!("starting server!"); - thread::spawn(|| { - serve(server_streams, Server) - }); - println!("making client"); + thread::spawn(|| serve(server_streams, Server)); let client: Client = Client::new(client_stream).unwrap(); - println!("hi there"); assert_eq!(Reply::Increment, client.rpc(&Request::Increment).unwrap()); } } diff --git a/src/multi_tcp/src/lib.rs b/src/multi_tcp/src/lib.rs index 4bccdde..7b8183d 100644 --- a/src/multi_tcp/src/lib.rs +++ b/src/multi_tcp/src/lib.rs @@ -15,7 +15,7 @@ fn read(mut stream: TcpStream, decode: F, tx: SyncSender) E: fmt::Debug + Send + 'static { loop { - let t = decode(&mut stream).unwrap(); + let t = decode(&mut stream).expect("I couldn't do the thing"); if let Err(_) = tx.send(t) { break; } @@ -41,7 +41,8 @@ fn write(mut stream: TcpStream, encode: F) -> Sender> break; } }; - helper.result.send(encode(&mut stream, &helper.value)).unwrap(); + helper.result.send(encode(&mut stream, &helper.value)) + .expect("died trying to send the result to the helper"); } }); tx From 44b3765d7027ba1226a0e08a483c861a7ddd1e05 Mon Sep 17 00:00:00 2001 From: Adam Wright Date: Fri, 8 Jan 2016 04:15:10 -0800 Subject: [PATCH 3/6] Removed a bunch of over-engineered code --- Cargo.toml | 1 - src/lib.rs | 84 +++++++++++++++--------- src/multi_tcp/Cargo.toml | 6 -- src/multi_tcp/src/lib.rs | 136 --------------------------------------- 4 files changed, 52 insertions(+), 175 deletions(-) delete mode 100644 src/multi_tcp/Cargo.toml delete mode 100644 src/multi_tcp/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 2ff0216..31a32ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,4 +7,3 @@ authors = ["Adam Wright "] serde = "*" serde_json = "*" serde_macros = "*" -multi_tcp = { path = "src/multi_tcp" } diff --git a/src/lib.rs b/src/lib.rs index 36c14ea..f3e9af0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,6 @@ #![feature(custom_derive, plugin)] #![plugin(serde_macros)] -extern crate multi_tcp; extern crate serde; extern crate serde_json; @@ -134,57 +133,78 @@ fn receiver(messages: Receiver>) -> Result<()> { Ok(()) } -pub struct Client { - next_id: Mutex, - writer: multi_tcp::MultiStream, Error>, - handles_tx: SyncSender>, +fn reader(mut stream: TcpStream, decode: F, tx: SyncSender) + where F: Send + 'static + Fn(&mut TcpStream) -> Result, + T: Send + 'static +{ + loop { + let t = decode(&mut stream).expect("I couldn't do the thing"); + if let Err(_) = tx.send(t) { + break; + } + } } -impl Client - where Request: serde::ser::Serialize + Clone + Send + 'static, - Reply: serde::de::Deserialize + Send + 'static +fn increment(cur_id: &mut u64) -> u64 { + let id = *cur_id; + *cur_id += 1; + id +} + +struct SyncedClientState{ + next_id: u64, + stream: TcpStream, +} + +pub struct Client { + synced_state: Mutex, + handles_tx: SyncSender>, + reader_guard: thread::JoinHandle<()>, +} + +impl Client + where Reply: serde::de::Deserialize + Send + 'static { pub fn new(stream: TcpStream) -> Result { let (handles_tx, receiver_rx) = sync_channel(0); - let writer: multi_tcp::MultiStream, Error> - = multi_tcp::MultiStream::with_sync_sender( - stream, - |stream: &mut TcpStream, packet: &Packet| { - try!(serde_json::to_writer(stream, packet)); - Ok(()) - }, - |mut stream| { - let packet = try!(serde_json::from_reader(&mut stream)); - Ok(ReceiverMessage::Packet(packet)) - }, - handles_tx.clone()); + let decode = |mut stream: &mut TcpStream| { + let packet = try!(serde_json::from_reader(&mut stream)); + Ok(ReceiverMessage::Packet(packet)) + }; + let read_stream = try!(stream.try_clone()); + let reader_handles_tx = handles_tx.clone(); + let guard = thread::spawn(move || reader(read_stream, decode, reader_handles_tx)); thread::spawn(move || receiver(receiver_rx)); Ok(Client{ - next_id: Mutex::new(0), - writer: writer, + synced_state: Mutex::new(SyncedClientState{ + next_id: 0, + stream: stream, + }), + reader_guard: guard, handles_tx: handles_tx, }) } - fn get_next_id(&self) -> u64 { - let mut id = self.next_id.lock().unwrap(); - *id += 1; - *id - } - - pub fn rpc(&self, request: &Request) -> Result { + pub fn rpc(&self, request: &Request) -> Result + where Request: serde::ser::Serialize + Clone + Send + 'static + { let (tx, rx) = channel(); - let id = self.get_next_id(); + let mut state = self.synced_state.lock().unwrap(); + let id = increment(&mut state.next_id); try!(self.handles_tx.send(ReceiverMessage::Handle(Handle{ id: id, sender: tx, }))); - try!(self.writer.write(Packet{ + try!(serde_json::to_writer(&mut state.stream, &Packet{ id: id, message: request.clone(), })); Ok(rx.recv().unwrap()) } + + pub fn join(self) { + self.reader_guard.join().unwrap(); + } } #[cfg(test)] @@ -224,7 +244,7 @@ mod test { fn test() { let (client_stream, server_streams) = pair(); thread::spawn(|| serve(server_streams, Server)); - let client: Client = Client::new(client_stream).unwrap(); + let client = Client::new(client_stream).unwrap(); assert_eq!(Reply::Increment, client.rpc(&Request::Increment).unwrap()); } } diff --git a/src/multi_tcp/Cargo.toml b/src/multi_tcp/Cargo.toml deleted file mode 100644 index 2fac3ad..0000000 --- a/src/multi_tcp/Cargo.toml +++ /dev/null @@ -1,6 +0,0 @@ -[package] -name = "multi_tcp" -version = "0.1.0" -authors = ["Adam Wright "] - -[dependencies] diff --git a/src/multi_tcp/src/lib.rs b/src/multi_tcp/src/lib.rs deleted file mode 100644 index 7b8183d..0000000 --- a/src/multi_tcp/src/lib.rs +++ /dev/null @@ -1,136 +0,0 @@ -use std::fmt; -use std::net::TcpStream; -use std::thread; -use std::sync::mpsc::{ - channel, - sync_channel, - Sender, - SyncSender, - Receiver, -}; - -fn read(mut stream: TcpStream, decode: F, tx: SyncSender) - where F: Send + 'static + Fn(&mut TcpStream) -> Result, - T: Send + 'static, - E: fmt::Debug + Send + 'static -{ - loop { - let t = decode(&mut stream).expect("I couldn't do the thing"); - if let Err(_) = tx.send(t) { - break; - } - } -} - -struct SendHelper { - value: T, - result: Sender>, -} - -fn write(mut stream: TcpStream, encode: F) -> Sender> - where F: Send + 'static + Fn(&mut TcpStream, &T) -> Result<(), E>, - T: Send + 'static, - E: Send + 'static -{ - let (tx, rx) = channel(); - thread::spawn(move || { - loop { - let helper: SendHelper = match rx.recv() { - Ok(h) => h, - Err(_) => { - break; - } - }; - helper.result.send(encode(&mut stream, &helper.value)) - .expect("died trying to send the result to the helper"); - } - }); - tx -} - -pub struct MultiStream { - tx: Sender>, -} - -impl MultiStream - where Request: Send + 'static, - E: fmt::Debug + Send + 'static -{ - pub fn new( - stream: TcpStream, - encode: F, - decode: G) -> (Self, Receiver) - where Reply: Send + 'static, - F: Send + 'static + Fn(&mut TcpStream, &Request) -> Result<(), E>, - G: Send + 'static + Fn(&mut TcpStream) -> Result - { - let read_stream = stream.try_clone().unwrap(); - let ms = MultiStream{tx: write(stream, encode)}; - let (reply_tx, reply_rx) = sync_channel(0); - thread::spawn(move || read(read_stream, decode, reply_tx)); - (ms, reply_rx) - } - - pub fn with_sync_sender( - stream: TcpStream, - encode: F, - decode: G, - reply_tx: SyncSender) -> Self - where Reply: Send + 'static, - F: Send + 'static + Fn(&mut TcpStream, &Request) -> Result<(), E>, - G: Send + 'static + Fn(&mut TcpStream) -> Result - { - let read_stream = stream.try_clone().unwrap(); - thread::spawn(move || read(read_stream, decode, reply_tx)); - MultiStream{tx: write(stream, encode)} - } - - - pub fn write(&self, value: Request) -> Result<(), E> { - let my_tx = self.tx.clone(); - let (reply_tx, reply_rx) = channel(); - let helper = SendHelper{ - value: value, - result: reply_tx, - }; - my_tx.send(helper).unwrap(); - reply_rx.recv().unwrap() - } -} - -#[cfg(test)] -mod test { - use super::MultiStream; - use std::net::{TcpStream, TcpListener}; - use std::sync::mpsc::Receiver; - use std::io::{Write, Read}; - - fn pair() -> (TcpStream, Receiver) { - let addr = "127.0.0.1:9000"; - let recv_stream = listen(TcpListener::bind(addr).unwrap()); - (TcpStream::connect(addr).unwrap(), recv_stream) - } - - fn write_byte(stream: &mut TcpStream, v: u8) -> Result<(), ()> { - stream.write(&[v]).unwrap(); - Ok(()) - } - - fn read_byte(stream: &mut TcpStream) -> Result { - let mut buf = [0u8]; - stream.read_exact(&mut buf[..]).unwrap(); - Ok(buf[0]) - } - - #[test] - fn test_thing() { - let (stream, listener) = pair(); - let (ms, reader) : (MultiStream, Receiver) = - MultiStream::new(stream, |s, v| write_byte(s, *v), |s| read_byte(s)); - ms.write(5).expect("writing 5"); - let mut srv_stream = listener.accept().unwrap().0; - assert_eq!(5, read_byte(&mut srv_stream).expect("read 5")); - write_byte(&mut srv_stream, 10).expect("write 10"); - assert_eq!(10, reader.recv().expect("reading 10")); - } -} From c62d66839dc7be342509c093d7879995267ad95c Mon Sep 17 00:00:00 2001 From: Adam Wright Date: Fri, 8 Jan 2016 04:31:51 -0800 Subject: [PATCH 4/6] Join the client, and update reader to handle EOF --- src/lib.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index f3e9af0..222e609 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -137,10 +137,13 @@ fn reader(mut stream: TcpStream, decode: F, tx: SyncSender) where F: Send + 'static + Fn(&mut TcpStream) -> Result, T: Send + 'static { + use serde_json::Error::SyntaxError; + use serde_json::ErrorCode::EOFWhileParsingValue; loop { - let t = decode(&mut stream).expect("I couldn't do the thing"); - if let Err(_) = tx.send(t) { - break; + match decode(&mut stream) { + Ok(t) => tx.send(t).unwrap(), + Err(Error::Json(SyntaxError(EOFWhileParsingValue, _, _))) => break, + Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } } @@ -246,5 +249,6 @@ mod test { thread::spawn(|| serve(server_streams, Server)); let client = Client::new(client_stream).unwrap(); assert_eq!(Reply::Increment, client.rpc(&Request::Increment).unwrap()); + client.join(); } } From eac0e56be772f6fa8ee67425e8332b8c0623b522 Mon Sep 17 00:00:00 2001 From: Adam Wright Date: Fri, 8 Jan 2016 05:49:00 -0800 Subject: [PATCH 5/6] A number of improvements - Support non clonable Serve objects by wrapping in an Arc - Support multiple RPCs per connection - Support cleanish shutdown --- src/lib.rs | 119 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 75 insertions(+), 44 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 222e609..b81ed00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,12 +10,14 @@ use std::io::{self, Read}; use std::convert; use std::collections::HashMap; use std::net::{ + self, TcpListener, TcpStream, }; use std::sync::{ self, Mutex, + Arc, }; use std::sync::mpsc::{ channel, @@ -24,6 +26,7 @@ use std::sync::mpsc::{ SyncSender, Receiver, }; +use std::time; use std::thread; #[derive(Debug)] @@ -58,26 +61,30 @@ impl convert::From> for Error { pub type Result = std::result::Result; -pub fn handle_conn( - mut stream: TcpStream, - f: F) -> Result<()> +pub fn handle_conn(mut stream: TcpStream, f: Arc) -> Result<()> where Request: fmt::Debug + serde::de::Deserialize, Reply: fmt::Debug + serde::ser::Serialize, - F: 'static + Serve + F: Serve { let read_stream = try!(stream.try_clone()); let mut de = serde_json::Deserializer::new(read_stream.bytes()); - let request_packet: Packet = try!(Packet::deserialize(&mut de)); - let reply = try!(f.serve(&request_packet.message)); - let reply_packet = Packet{ - id: request_packet.id, - message: reply, - }; - try!(serde_json::to_writer(&mut stream, &reply_packet)); + loop { + println!("read"); + let request_packet: Packet = try!(Packet::deserialize(&mut de)); + match request_packet { + Packet::Shutdown => break, + Packet::Message(id, message) => { + let reply = try!(f.serve(&message)); + let reply_packet = Packet::Message(id, reply); + println!("write"); + try!(serde_json::to_writer(&mut stream, &reply_packet)); + }, + } + } Ok(()) } -pub fn serve(listener: TcpListener, f: F) -> Error +pub fn serve(listener: TcpListener, f: Arc) -> Error where Request: fmt::Debug + serde::de::Deserialize, Reply: fmt::Debug + serde::ser::Serialize, F: 'static + Serve, @@ -97,14 +104,14 @@ pub fn serve(listener: TcpListener, f: F) -> Error Error::Impossible } -pub trait Serve : Sync + Send + Clone { +pub trait Serve: Send + Sync { fn serve(&self, request: &Request) -> io::Result; } #[derive(Debug, Clone, Serialize, Deserialize)] -struct Packet { - id: u64, - message: T, +enum Packet { + Message(u64, T), + Shutdown, } struct Handle { @@ -115,6 +122,7 @@ struct Handle { enum ReceiverMessage { Handle(Handle), Packet(Packet), + Shutdown, } fn receiver(messages: Receiver>) -> Result<()> { @@ -124,25 +132,32 @@ fn receiver(messages: Receiver>) -> Result<()> { ReceiverMessage::Handle(handle) => { ready_handles.insert(handle.id, handle); }, - ReceiverMessage::Packet(packet) => { - let handle = ready_handles.remove(&packet.id).unwrap(); - try!(handle.sender.send(packet.message)); + ReceiverMessage::Packet(Packet::Shutdown) => break, + ReceiverMessage::Packet(Packet::Message(id, message)) => { + let handle = ready_handles.remove(&id).unwrap(); + try!(handle.sender.send(message)); } + ReceiverMessage::Shutdown => break, } } Ok(()) } -fn reader(mut stream: TcpStream, decode: F, tx: SyncSender) - where F: Send + 'static + Fn(&mut TcpStream) -> Result, - T: Send + 'static +fn reader(stream: TcpStream, tx: SyncSender>) + where Reply: serde::Deserialize { use serde_json::Error::SyntaxError; use serde_json::ErrorCode::EOFWhileParsingValue; + let mut de = serde_json::Deserializer::new(stream.bytes()); loop { - match decode(&mut stream) { - Ok(t) => tx.send(t).unwrap(), - Err(Error::Json(SyntaxError(EOFWhileParsingValue, _, _))) => break, + match Packet::deserialize(&mut de) { + Ok(packet) =>{ + println!("send!"); + tx.send(ReceiverMessage::Packet(packet)).unwrap(); + }, + // TODO: This shutdown logic is janky.. What's the right way to do this? + Err(SyntaxError(EOFWhileParsingValue, _, _)) => break, + Err(SyntaxError(ExpectedValue, _, _)) => break, Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } @@ -170,13 +185,10 @@ impl Client { pub fn new(stream: TcpStream) -> Result { let (handles_tx, receiver_rx) = sync_channel(0); - let decode = |mut stream: &mut TcpStream| { - let packet = try!(serde_json::from_reader(&mut stream)); - Ok(ReceiverMessage::Packet(packet)) - }; let read_stream = try!(stream.try_clone()); + try!(read_stream.set_read_timeout(Some(time::Duration::from_millis(50)))); let reader_handles_tx = handles_tx.clone(); - let guard = thread::spawn(move || reader(read_stream, decode, reader_handles_tx)); + let guard = thread::spawn(move || reader(read_stream, reader_handles_tx)); thread::spawn(move || receiver(receiver_rx)); Ok(Client{ synced_state: Mutex::new(SyncedClientState{ @@ -198,24 +210,28 @@ impl Client id: id, sender: tx, }))); - try!(serde_json::to_writer(&mut state.stream, &Packet{ - id: id, - message: request.clone(), - })); + let packet = Packet::Message(id, request.clone()); + try!(serde_json::to_writer(&mut state.stream, &packet)); Ok(rx.recv().unwrap()) } - pub fn join(self) { + pub fn join(self) -> Result<()> { + let mut state = self.synced_state.lock().unwrap(); + let packet: Packet = Packet::Shutdown; + try!(serde_json::to_writer(&mut state.stream, &packet)); + try!(state.stream.shutdown(net::Shutdown::Both)); self.reader_guard.join().unwrap(); + Ok(()) } } #[cfg(test)] mod test { use super::*; - use std::thread; - use std::net::{TcpStream, TcpListener}; use std::io; + use std::net::{TcpStream, TcpListener}; + use std::sync::{Arc, Mutex}; + use std::thread; fn pair() -> (TcpStream, TcpListener) { let addr = "127.0.0.1:9000"; @@ -231,24 +247,39 @@ mod test { #[derive(Debug, PartialEq, Serialize, Deserialize)] enum Reply { - Increment + Increment(u64) } - #[derive(Clone)] - struct Server; + struct Server { + counter: Mutex, + } impl Serve for Server { fn serve(&self, _: &Request) -> io::Result { - Ok(Reply::Increment) + let mut counter = self.counter.lock().unwrap(); + let reply = Reply::Increment(*counter); + *counter += 1; + Ok(reply) + } + } + + impl Server { + fn count(&self) -> u64 { + *self.counter.lock().unwrap() } } #[test] fn test() { let (client_stream, server_streams) = pair(); - thread::spawn(|| serve(server_streams, Server)); + let server = Arc::new(Server{counter: Mutex::new(0)}); + let thread_server = server.clone(); + thread::spawn(move || serve(server_streams, thread_server)); let client = Client::new(client_stream).unwrap(); - assert_eq!(Reply::Increment, client.rpc(&Request::Increment).unwrap()); - client.join(); + assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); + assert_eq!(1, server.count()); + assert_eq!(Reply::Increment(1), client.rpc(&Request::Increment).unwrap()); + assert_eq!(2, server.count()); + client.join::().unwrap(); } } From 3a3e2d1e4dcdae90cabaf32bbfbf760b813433df Mon Sep 17 00:00:00 2001 From: Adam Wright Date: Fri, 8 Jan 2016 19:56:02 -0800 Subject: [PATCH 6/6] Really have a non-clone thing? tests are a mess though --- src/lib.rs | 83 +++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 70 insertions(+), 13 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b81ed00..0571d95 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(const_fn)] #![feature(custom_derive, plugin)] #![plugin(serde_macros)] @@ -157,7 +158,6 @@ fn reader(stream: TcpStream, tx: SyncSender>) }, // TODO: This shutdown logic is janky.. What's the right way to do this? Err(SyntaxError(EOFWhileParsingValue, _, _)) => break, - Err(SyntaxError(ExpectedValue, _, _)) => break, Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } @@ -169,14 +169,14 @@ fn increment(cur_id: &mut u64) -> u64 { id } -struct SyncedClientState{ +struct SyncedClientState { next_id: u64, stream: TcpStream, + handles_tx: SyncSender>, } pub struct Client { - synced_state: Mutex, - handles_tx: SyncSender>, + synced_state: Mutex>, reader_guard: thread::JoinHandle<()>, } @@ -194,9 +194,9 @@ impl Client synced_state: Mutex::new(SyncedClientState{ next_id: 0, stream: stream, + handles_tx: handles_tx, }), reader_guard: guard, - handles_tx: handles_tx, }) } @@ -206,7 +206,7 @@ impl Client let (tx, rx) = channel(); let mut state = self.synced_state.lock().unwrap(); let id = increment(&mut state.next_id); - try!(self.handles_tx.send(ReceiverMessage::Handle(Handle{ + try!(state.handles_tx.send(ReceiverMessage::Handle(Handle{ id: id, sender: tx, }))); @@ -229,15 +229,20 @@ impl Client mod test { use super::*; use std::io; - use std::net::{TcpStream, TcpListener}; - use std::sync::{Arc, Mutex}; + use std::net::{TcpStream, TcpListener, SocketAddr}; + use std::str::FromStr; + use std::sync::{Arc, Mutex, Barrier}; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::thread; + const port: AtomicUsize = AtomicUsize::new(10000); + fn pair() -> (TcpStream, TcpListener) { - let addr = "127.0.0.1:9000"; + let addr = format!("127.0.0.1:{}", port.fetch_add(1, Ordering::SeqCst)); + println!("what the fuck {}", &addr); // Do this one first so that we don't get connection refused :) - let listener = TcpListener::bind(addr).unwrap(); - (TcpStream::connect(addr).unwrap(), listener) + let listener = TcpListener::bind(&*addr).unwrap(); + (TcpStream::connect(&*addr).unwrap(), listener) } #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] @@ -264,6 +269,10 @@ mod test { } impl Server { + fn new() -> Server { + Server{counter: Mutex::new(0)} + } + fn count(&self) -> u64 { *self.counter.lock().unwrap() } @@ -272,14 +281,62 @@ mod test { #[test] fn test() { let (client_stream, server_streams) = pair(); - let server = Arc::new(Server{counter: Mutex::new(0)}); + let server = Arc::new(Server::new()); let thread_server = server.clone(); - thread::spawn(move || serve(server_streams, thread_server)); + let guard = thread::spawn(move || serve(server_streams, thread_server)); let client = Client::new(client_stream).unwrap(); assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); assert_eq!(1, server.count()); assert_eq!(Reply::Increment(1), client.rpc(&Request::Increment).unwrap()); assert_eq!(2, server.count()); client.join::().unwrap(); + guard.join(); + } + + struct BarrierServer { + barrier: Barrier, + inner: Server, + } + + impl Serve for BarrierServer { + fn serve(&self, request: &Request) -> io::Result { + self.barrier.wait(); + let reply = try!(self.inner.serve(request)); + Ok(reply) + } + } + + impl BarrierServer { + fn new(n: usize) -> BarrierServer { + BarrierServer{barrier: Barrier::new(n), inner: Server::new()} + } + + fn count(&self) -> u64 { + self.inner.count() + } + } + + #[test] + fn test_concurrent() { + let (client_stream, server_streams) = pair(); + let server = Arc::new(BarrierServer::new(10)); + let thread_server = server.clone(); + let guard = thread::spawn(move || serve(server_streams, thread_server)); + let client: Arc> = Arc::new(Client::new(client_stream).unwrap()); + let mut join_handles = vec![]; + for _ in 0..10 { + let my_client = client.clone(); + join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap())); + } + for handle in join_handles.into_iter() { + handle.join(); + } + assert_eq!(10, server.count()); + let client = match Arc::try_unwrap(client) { + Err(_) => panic!("couldn't unwrap arc"), + Ok(c) => c, + }; + client.join::().unwrap(); + guard.join(); } }