From 3a3e2d1e4dcdae90cabaf32bbfbf760b813433df Mon Sep 17 00:00:00 2001 From: Adam Wright Date: Fri, 8 Jan 2016 19:56:02 -0800 Subject: [PATCH] Really have a non-clone thing? tests are a mess though --- src/lib.rs | 83 +++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 70 insertions(+), 13 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b81ed00..0571d95 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(const_fn)] #![feature(custom_derive, plugin)] #![plugin(serde_macros)] @@ -157,7 +158,6 @@ fn reader(stream: TcpStream, tx: SyncSender>) }, // TODO: This shutdown logic is janky.. What's the right way to do this? Err(SyntaxError(EOFWhileParsingValue, _, _)) => break, - Err(SyntaxError(ExpectedValue, _, _)) => break, Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } @@ -169,14 +169,14 @@ fn increment(cur_id: &mut u64) -> u64 { id } -struct SyncedClientState{ +struct SyncedClientState { next_id: u64, stream: TcpStream, + handles_tx: SyncSender>, } pub struct Client { - synced_state: Mutex, - handles_tx: SyncSender>, + synced_state: Mutex>, reader_guard: thread::JoinHandle<()>, } @@ -194,9 +194,9 @@ impl Client synced_state: Mutex::new(SyncedClientState{ next_id: 0, stream: stream, + handles_tx: handles_tx, }), reader_guard: guard, - handles_tx: handles_tx, }) } @@ -206,7 +206,7 @@ impl Client let (tx, rx) = channel(); let mut state = self.synced_state.lock().unwrap(); let id = increment(&mut state.next_id); - try!(self.handles_tx.send(ReceiverMessage::Handle(Handle{ + try!(state.handles_tx.send(ReceiverMessage::Handle(Handle{ id: id, sender: tx, }))); @@ -229,15 +229,20 @@ impl Client mod test { use super::*; use std::io; - use std::net::{TcpStream, TcpListener}; - use std::sync::{Arc, Mutex}; + use std::net::{TcpStream, TcpListener, SocketAddr}; + use std::str::FromStr; + use std::sync::{Arc, Mutex, Barrier}; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::thread; + const port: AtomicUsize = AtomicUsize::new(10000); + fn pair() -> (TcpStream, TcpListener) { - let addr = "127.0.0.1:9000"; + let addr = format!("127.0.0.1:{}", port.fetch_add(1, Ordering::SeqCst)); + println!("what the fuck {}", &addr); // Do this one first so that we don't get connection refused :) - let listener = TcpListener::bind(addr).unwrap(); - (TcpStream::connect(addr).unwrap(), listener) + let listener = TcpListener::bind(&*addr).unwrap(); + (TcpStream::connect(&*addr).unwrap(), listener) } #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] @@ -264,6 +269,10 @@ mod test { } impl Server { + fn new() -> Server { + Server{counter: Mutex::new(0)} + } + fn count(&self) -> u64 { *self.counter.lock().unwrap() } @@ -272,14 +281,62 @@ mod test { #[test] fn test() { let (client_stream, server_streams) = pair(); - let server = Arc::new(Server{counter: Mutex::new(0)}); + let server = Arc::new(Server::new()); let thread_server = server.clone(); - thread::spawn(move || serve(server_streams, thread_server)); + let guard = thread::spawn(move || serve(server_streams, thread_server)); let client = Client::new(client_stream).unwrap(); assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); assert_eq!(1, server.count()); assert_eq!(Reply::Increment(1), client.rpc(&Request::Increment).unwrap()); assert_eq!(2, server.count()); client.join::().unwrap(); + guard.join(); + } + + struct BarrierServer { + barrier: Barrier, + inner: Server, + } + + impl Serve for BarrierServer { + fn serve(&self, request: &Request) -> io::Result { + self.barrier.wait(); + let reply = try!(self.inner.serve(request)); + Ok(reply) + } + } + + impl BarrierServer { + fn new(n: usize) -> BarrierServer { + BarrierServer{barrier: Barrier::new(n), inner: Server::new()} + } + + fn count(&self) -> u64 { + self.inner.count() + } + } + + #[test] + fn test_concurrent() { + let (client_stream, server_streams) = pair(); + let server = Arc::new(BarrierServer::new(10)); + let thread_server = server.clone(); + let guard = thread::spawn(move || serve(server_streams, thread_server)); + let client: Arc> = Arc::new(Client::new(client_stream).unwrap()); + let mut join_handles = vec![]; + for _ in 0..10 { + let my_client = client.clone(); + join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap())); + } + for handle in join_handles.into_iter() { + handle.join(); + } + assert_eq!(10, server.count()); + let client = match Arc::try_unwrap(client) { + Err(_) => panic!("couldn't unwrap arc"), + Ok(c) => c, + }; + client.join::().unwrap(); + guard.join(); } }