mirror of
https://github.com/OMGeeky/tarpc.git
synced 2025-12-28 07:12:05 +01:00
Really have a non-clone thing? tests are a mess though
This commit is contained in:
83
src/lib.rs
83
src/lib.rs
@@ -1,3 +1,4 @@
|
||||
#![feature(const_fn)]
|
||||
#![feature(custom_derive, plugin)]
|
||||
#![plugin(serde_macros)]
|
||||
|
||||
@@ -157,7 +158,6 @@ fn reader<Reply>(stream: TcpStream, tx: SyncSender<ReceiverMessage<Reply>>)
|
||||
},
|
||||
// TODO: This shutdown logic is janky.. What's the right way to do this?
|
||||
Err(SyntaxError(EOFWhileParsingValue, _, _)) => break,
|
||||
Err(SyntaxError(ExpectedValue, _, _)) => break,
|
||||
Err(err) => panic!("unexpected error while parsing!: {:?}", err),
|
||||
}
|
||||
}
|
||||
@@ -169,14 +169,14 @@ fn increment(cur_id: &mut u64) -> u64 {
|
||||
id
|
||||
}
|
||||
|
||||
struct SyncedClientState{
|
||||
struct SyncedClientState<Reply> {
|
||||
next_id: u64,
|
||||
stream: TcpStream,
|
||||
handles_tx: SyncSender<ReceiverMessage<Reply>>,
|
||||
}
|
||||
|
||||
pub struct Client<Reply> {
|
||||
synced_state: Mutex<SyncedClientState>,
|
||||
handles_tx: SyncSender<ReceiverMessage<Reply>>,
|
||||
synced_state: Mutex<SyncedClientState<Reply>>,
|
||||
reader_guard: thread::JoinHandle<()>,
|
||||
}
|
||||
|
||||
@@ -194,9 +194,9 @@ impl<Reply> Client<Reply>
|
||||
synced_state: Mutex::new(SyncedClientState{
|
||||
next_id: 0,
|
||||
stream: stream,
|
||||
handles_tx: handles_tx,
|
||||
}),
|
||||
reader_guard: guard,
|
||||
handles_tx: handles_tx,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -206,7 +206,7 @@ impl<Reply> Client<Reply>
|
||||
let (tx, rx) = channel();
|
||||
let mut state = self.synced_state.lock().unwrap();
|
||||
let id = increment(&mut state.next_id);
|
||||
try!(self.handles_tx.send(ReceiverMessage::Handle(Handle{
|
||||
try!(state.handles_tx.send(ReceiverMessage::Handle(Handle{
|
||||
id: id,
|
||||
sender: tx,
|
||||
})));
|
||||
@@ -229,15 +229,20 @@ impl<Reply> Client<Reply>
|
||||
mod test {
|
||||
use super::*;
|
||||
use std::io;
|
||||
use std::net::{TcpStream, TcpListener};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::net::{TcpStream, TcpListener, SocketAddr};
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, Mutex, Barrier};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::thread;
|
||||
|
||||
const port: AtomicUsize = AtomicUsize::new(10000);
|
||||
|
||||
fn pair() -> (TcpStream, TcpListener) {
|
||||
let addr = "127.0.0.1:9000";
|
||||
let addr = format!("127.0.0.1:{}", port.fetch_add(1, Ordering::SeqCst));
|
||||
println!("what the fuck {}", &addr);
|
||||
// Do this one first so that we don't get connection refused :)
|
||||
let listener = TcpListener::bind(addr).unwrap();
|
||||
(TcpStream::connect(addr).unwrap(), listener)
|
||||
let listener = TcpListener::bind(&*addr).unwrap();
|
||||
(TcpStream::connect(&*addr).unwrap(), listener)
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
|
||||
@@ -264,6 +269,10 @@ mod test {
|
||||
}
|
||||
|
||||
impl Server {
|
||||
fn new() -> Server {
|
||||
Server{counter: Mutex::new(0)}
|
||||
}
|
||||
|
||||
fn count(&self) -> u64 {
|
||||
*self.counter.lock().unwrap()
|
||||
}
|
||||
@@ -272,14 +281,62 @@ mod test {
|
||||
#[test]
|
||||
fn test() {
|
||||
let (client_stream, server_streams) = pair();
|
||||
let server = Arc::new(Server{counter: Mutex::new(0)});
|
||||
let server = Arc::new(Server::new());
|
||||
let thread_server = server.clone();
|
||||
thread::spawn(move || serve(server_streams, thread_server));
|
||||
let guard = thread::spawn(move || serve(server_streams, thread_server));
|
||||
let client = Client::new(client_stream).unwrap();
|
||||
assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap());
|
||||
assert_eq!(1, server.count());
|
||||
assert_eq!(Reply::Increment(1), client.rpc(&Request::Increment).unwrap());
|
||||
assert_eq!(2, server.count());
|
||||
client.join::<Request>().unwrap();
|
||||
guard.join();
|
||||
}
|
||||
|
||||
struct BarrierServer {
|
||||
barrier: Barrier,
|
||||
inner: Server,
|
||||
}
|
||||
|
||||
impl Serve<Request, Reply> for BarrierServer {
|
||||
fn serve(&self, request: &Request) -> io::Result<Reply> {
|
||||
self.barrier.wait();
|
||||
let reply = try!(self.inner.serve(request));
|
||||
Ok(reply)
|
||||
}
|
||||
}
|
||||
|
||||
impl BarrierServer {
|
||||
fn new(n: usize) -> BarrierServer {
|
||||
BarrierServer{barrier: Barrier::new(n), inner: Server::new()}
|
||||
}
|
||||
|
||||
fn count(&self) -> u64 {
|
||||
self.inner.count()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent() {
|
||||
let (client_stream, server_streams) = pair();
|
||||
let server = Arc::new(BarrierServer::new(10));
|
||||
let thread_server = server.clone();
|
||||
let guard = thread::spawn(move || serve(server_streams, thread_server));
|
||||
let client: Arc<Client<Reply>> = Arc::new(Client::new(client_stream).unwrap());
|
||||
let mut join_handles = vec![];
|
||||
for _ in 0..10 {
|
||||
let my_client = client.clone();
|
||||
join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap()));
|
||||
}
|
||||
for handle in join_handles.into_iter() {
|
||||
handle.join();
|
||||
}
|
||||
assert_eq!(10, server.count());
|
||||
let client = match Arc::try_unwrap(client) {
|
||||
Err(_) => panic!("couldn't unwrap arc"),
|
||||
Ok(c) => c,
|
||||
};
|
||||
client.join::<Request>().unwrap();
|
||||
guard.join();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user