From dbf7113cf38b24b3b44b96b07bc5e77318307101 Mon Sep 17 00:00:00 2001 From: Adam Wright Date: Fri, 8 Jan 2016 00:07:53 -0800 Subject: [PATCH] WIP, doesn't compile :( --- Cargo.toml | 1 + src/lib.rs | 218 +++++++++++++++++++-------------------- src/multi_tcp/Cargo.toml | 6 ++ src/multi_tcp/src/lib.rs | 135 ++++++++++++++++++++++++ 4 files changed, 246 insertions(+), 114 deletions(-) create mode 100644 src/multi_tcp/Cargo.toml create mode 100644 src/multi_tcp/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 31a32ab..2ff0216 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,4 @@ authors = ["Adam Wright "] serde = "*" serde_json = "*" serde_macros = "*" +multi_tcp = { path = "src/multi_tcp" } diff --git a/src/lib.rs b/src/lib.rs index 7d169cf..ad0343c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,13 @@ #![feature(custom_derive, plugin)] #![plugin(serde_macros)] +extern crate multi_tcp; extern crate serde; extern crate serde_json; use std::io; use std::convert; use std::collections::HashMap; -use std::error::Error as StdError; use std::net::{ TcpListener, TcpStream, @@ -18,7 +18,9 @@ use std::sync::{ }; use std::sync::mpsc::{ channel, + sync_channel, Sender, + SyncSender, Receiver, }; use std::thread; @@ -45,32 +47,40 @@ impl convert::From for Error { } impl convert::From> for Error { - fn from(err: sync::mpsc::SendError) -> Error { + fn from(_: sync::mpsc::SendError) -> Error { Error::Sender } } pub type Result = std::result::Result; -pub fn handle_conn(mut conn: TcpStream, f: F) -> Result<()> +pub fn handle_conn( + mut conn: TcpStream, + f: F) -> Result<()> where Request: serde::de::Deserialize, - Response: serde::ser::Serialize, - F: Fn(&Request) -> Result + Reply: serde::ser::Serialize, + F: 'static + Serve { let request: Request = try!(serde_json::from_reader(&mut conn)); - let response = try!(f(&request)); + let response = try!(f.serve(&request)); try!(serde_json::to_writer(&mut conn, &response)); Ok(()) } -pub fn serve(listener: TcpListener) -> Error { +pub fn serve(listener: TcpListener, f: F) -> Error + where Request: serde::de::Deserialize, + Reply: serde::ser::Serialize, + F: 'static + Serve, +{ for conn in listener.incoming() { let conn = match conn { Err(err) => return convert::From::from(err), Ok(c) => c, }; + println!("received connection"); + let f = f.clone(); thread::spawn(move || { - if let Err(err) = handle_conn(conn, |a| handle_impl(a)) { + if let Err(err) = handle_conn(conn, f) { println!("error handling connection: {:?}", err); } }); @@ -78,130 +88,74 @@ pub fn serve(listener: TcpListener) -> Error { Error::Impossible } +pub trait Serve : Sync + Send + Clone { + fn serve(&self, request: &Request) -> io::Result; +} + #[derive(Serialize, Deserialize)] struct Packet { - seq: u64, + id: u64, message: T, } -// Generated code - -#[derive(Serialize, Deserialize)] -struct A; -#[derive(Serialize, Deserialize)] -struct B; - -fn handle_impl(a: &A) -> Result { - Ok(B) -} - -struct InnerClient { - stream: TcpStream, - seq: u64, - outstanding_messages: HashMap>, -} - -struct RPC { +struct Handle { id: u64, - request: Request, - reply: Sender, -} - -struct RequestHandle { - id: u64, - request: Request, -} - -struct ReplyHandle { - id: u64, - reply: Sender, -} - -struct ReplyPacket { - id: u64, - message: Reply, -} - -fn message_reader( - mut stream: TcpStream, - replies: Sender>) -> Result<()> - where Reply: serde::de::Deserialize -{ - loop { - let id = try!(serde_json::from_reader(&mut stream)); - let reply_message = try!(serde_json::from_reader(&mut stream)); - let packet = ReplyPacket{ - id: id, - message: reply_message, - }; - try!(replies.send(ReceiverMessage::Packet(packet))); - } + sender: Sender, } enum ReceiverMessage { - Handle(ReplyHandle), - Packet(ReplyPacket), + Handle(Handle), + Packet(Packet), } -fn receiver(messages: Receiver>) -> Result<()> -{ - let mut ready_handles: HashMap> = HashMap::new(); - let mut ready_packets: HashMap> = HashMap::new(); +fn receiver(messages: Receiver>) -> Result<()> { + let mut ready_handles: HashMap> = HashMap::new(); for message in messages.into_iter() { match message { ReceiverMessage::Handle(handle) => { - if let Some(packet) = ready_packets.remove(&handle.id) { - try!(handle.reply.send(packet.message)); - } else { - ready_handles.insert(handle.id, handle); - } + ready_handles.insert(handle.id, handle); }, ReceiverMessage::Packet(packet) => { - if let Some(handle) = ready_handles.remove(&packet.id) { - try!(handle.reply.send(packet.message)); - } else { - ready_packets.insert(packet.id, packet); - } + let handle = ready_handles.remove(&packet.id).unwrap(); + try!(handle.sender.send(packet.message)); } - } } Ok(()) } -fn message_writer( - mut stream: TcpStream, - requests: Receiver>) -> Result<()> - where Request: serde::ser::Serialize -{ - for request_handle in requests.into_iter() { - try!(serde_json::to_writer(&mut stream, &request_handle.id)); - try!(serde_json::to_writer(&mut stream, &request_handle.request)); - } - Ok(()) -} - -struct Client { +pub struct Client { next_id: Mutex, - writer_tx: Sender>, - handles_tx: Sender>, + writer: multi_tcp::MultiStream, serde_json::Error>, + handles_tx: SyncSender>, } impl Client where Request: serde::ser::Serialize + Clone + Send + 'static, Reply: serde::de::Deserialize + Send + 'static { - fn new(stream: TcpStream) -> Result { - let write_stream = try!(stream.try_clone()); - let (requests_tx, requests_rx) = channel(); - let (handles_tx, receiver_rx) = channel(); - let replies_tx = handles_tx.clone(); - thread::spawn(move || message_writer(write_stream, requests_rx).unwrap()); - thread::spawn(move || message_reader(stream, replies_tx).unwrap()); - thread::spawn(move || receiver(receiver_rx).unwrap()); + pub fn new(stream: TcpStream) -> Result { + let (handles_tx, receiver_rx) = sync_channel(0); + let writer = multi_tcp::MultiStream::with_sync_sender( + stream, + |stream, packet: &Packet| { + try!(serde_json::to_writer(stream, &packet.id)); + try!(serde_json::to_writer(stream, &packet.message)); + Ok(()) + }, + |stream| { + let id = try!(serde_json::from_reader(stream)); + let reply = try!(serde_json::from_reader(stream)); + Ok(ReceiverMessage::Packet(Packet{ + id: id, + message: reply, + })) + }, + handles_tx.clone()); + thread::spawn(move || receiver(receiver_rx)); Ok(Client{ next_id: Mutex::new(0), - writer_tx: requests_tx, + writer: writer, handles_tx: handles_tx, }) } @@ -212,31 +166,67 @@ impl Client *id } - fn rpc(&self, request: &Request) -> Result { + pub fn rpc(&self, request: &Request) -> Result { let (tx, rx) = channel(); let id = self.get_next_id(); - try!(self.writer_tx.send(RequestHandle{ + println!("indicate that we're weaiting"); + try!(self.handles_tx.send(ReceiverMessage::Handle(Handle{ id: id, - request: request.clone(), - })); - try!(self.handles_tx.send(ReceiverMessage::Handle(ReplyHandle{ - id: id, - reply: tx, + sender: tx, }))); + println!("write the request to the wire"); + try!(self.writer.write(Packet{ + id: id, + message: request.clone(), + })); + println!("wait for the response"); Ok(rx.recv().unwrap()) } } -/* #[cfg(test)] mod test { - use adamrpc::*; + use super::*; + use std::thread; + use std::net::{TcpStream, TcpListener}; + use std::io; + + fn pair() -> (TcpStream, TcpListener) { + let addr = "127.0.0.1:9000"; + // Do this one first so that we don't get connection refused :) + let listener = TcpListener::bind(addr).unwrap(); + (TcpStream::connect(addr).unwrap(), listener) + } + + #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] + enum Request { + Increment + } + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + enum Reply { + Increment + } + + #[derive(Clone)] + struct Server; + + impl Serve for Server { + fn serve(&self, _: &Request) -> io::Result { + Ok(Reply::Increment) + } + } #[test] fn test() { - let listener = TcpListener::bind("127.0.0.1:9000").expect("listener"); - let server = - let stream = TcpStream::connect + let (client_stream, server_streams) = pair(); + println!("starting server!"); + thread::spawn(|| { + serve(server_streams, Server) + }); + println!("making client"); + let client: Client = Client::new(client_stream).unwrap(); + println!("hi there"); + assert_eq!(Reply::Increment, client.rpc(&Request::Increment).unwrap()); } } -*/ diff --git a/src/multi_tcp/Cargo.toml b/src/multi_tcp/Cargo.toml new file mode 100644 index 0000000..2fac3ad --- /dev/null +++ b/src/multi_tcp/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "multi_tcp" +version = "0.1.0" +authors = ["Adam Wright "] + +[dependencies] diff --git a/src/multi_tcp/src/lib.rs b/src/multi_tcp/src/lib.rs new file mode 100644 index 0000000..4bccdde --- /dev/null +++ b/src/multi_tcp/src/lib.rs @@ -0,0 +1,135 @@ +use std::fmt; +use std::net::TcpStream; +use std::thread; +use std::sync::mpsc::{ + channel, + sync_channel, + Sender, + SyncSender, + Receiver, +}; + +fn read(mut stream: TcpStream, decode: F, tx: SyncSender) + where F: Send + 'static + Fn(&mut TcpStream) -> Result, + T: Send + 'static, + E: fmt::Debug + Send + 'static +{ + loop { + let t = decode(&mut stream).unwrap(); + if let Err(_) = tx.send(t) { + break; + } + } +} + +struct SendHelper { + value: T, + result: Sender>, +} + +fn write(mut stream: TcpStream, encode: F) -> Sender> + where F: Send + 'static + Fn(&mut TcpStream, &T) -> Result<(), E>, + T: Send + 'static, + E: Send + 'static +{ + let (tx, rx) = channel(); + thread::spawn(move || { + loop { + let helper: SendHelper = match rx.recv() { + Ok(h) => h, + Err(_) => { + break; + } + }; + helper.result.send(encode(&mut stream, &helper.value)).unwrap(); + } + }); + tx +} + +pub struct MultiStream { + tx: Sender>, +} + +impl MultiStream + where Request: Send + 'static, + E: fmt::Debug + Send + 'static +{ + pub fn new( + stream: TcpStream, + encode: F, + decode: G) -> (Self, Receiver) + where Reply: Send + 'static, + F: Send + 'static + Fn(&mut TcpStream, &Request) -> Result<(), E>, + G: Send + 'static + Fn(&mut TcpStream) -> Result + { + let read_stream = stream.try_clone().unwrap(); + let ms = MultiStream{tx: write(stream, encode)}; + let (reply_tx, reply_rx) = sync_channel(0); + thread::spawn(move || read(read_stream, decode, reply_tx)); + (ms, reply_rx) + } + + pub fn with_sync_sender( + stream: TcpStream, + encode: F, + decode: G, + reply_tx: SyncSender) -> Self + where Reply: Send + 'static, + F: Send + 'static + Fn(&mut TcpStream, &Request) -> Result<(), E>, + G: Send + 'static + Fn(&mut TcpStream) -> Result + { + let read_stream = stream.try_clone().unwrap(); + thread::spawn(move || read(read_stream, decode, reply_tx)); + MultiStream{tx: write(stream, encode)} + } + + + pub fn write(&self, value: Request) -> Result<(), E> { + let my_tx = self.tx.clone(); + let (reply_tx, reply_rx) = channel(); + let helper = SendHelper{ + value: value, + result: reply_tx, + }; + my_tx.send(helper).unwrap(); + reply_rx.recv().unwrap() + } +} + +#[cfg(test)] +mod test { + use super::MultiStream; + use std::net::{TcpStream, TcpListener}; + use std::sync::mpsc::Receiver; + use std::io::{Write, Read}; + + fn pair() -> (TcpStream, Receiver) { + let addr = "127.0.0.1:9000"; + let recv_stream = listen(TcpListener::bind(addr).unwrap()); + (TcpStream::connect(addr).unwrap(), recv_stream) + } + + fn write_byte(stream: &mut TcpStream, v: u8) -> Result<(), ()> { + stream.write(&[v]).unwrap(); + Ok(()) + } + + fn read_byte(stream: &mut TcpStream) -> Result { + let mut buf = [0u8]; + stream.read_exact(&mut buf[..]).unwrap(); + Ok(buf[0]) + } + + #[test] + fn test_thing() { + let (stream, listener) = pair(); + let (ms, reader) : (MultiStream, Receiver) = + MultiStream::new(stream, |s, v| write_byte(s, *v), |s| read_byte(s)); + ms.write(5).expect("writing 5"); + let mut srv_stream = listener.accept().unwrap().0; + assert_eq!(5, read_byte(&mut srv_stream).expect("read 5")); + write_byte(&mut srv_stream, 10).expect("write 10"); + assert_eq!(10, reader.recv().expect("reading 10")); + } +}