Modify serve_async to expose the bound address

serve_async was taking a SocketAddr, and then binding to it. This is fine if
your'e always sure of the exact address you want to bind to, but in some cases
you don't know.

One such case is when you want the OS to assign you and ephemeral port number,
like we do in our tests. In this case, you pass 0.0.0.0:0 as the address, and
then call bind. After that, you don't know which address the listener bound to,
so we can't make the subsequent call to TcpStream::connect without getting a
weird error.

This is fixed by the ServeHandle object exposing a local_addr() method, which
returns the address that the listener bound to.
This commit is contained in:
Adam Wright
2016-01-10 02:29:06 -08:00
parent fae09e3fed
commit 56bd362fb1
2 changed files with 29 additions and 49 deletions

View File

@@ -81,34 +81,40 @@ fn handle_conn<F, Request, Reply>(stream: TcpStream, f: F) -> Result<()>
}
pub struct Shutdown {
pub struct ServeHandle {
tx: Sender<()>,
join_handle: JoinHandle<()>,
addr: SocketAddr,
}
impl Shutdown {
impl ServeHandle {
pub fn wait(self) {
self.join_handle.join().unwrap();
}
pub fn local_addr(&self) -> &SocketAddr {
&self.addr
}
pub fn shutdown(self) {
self.tx.send(()).expect(&line!().to_string());
TcpStream::connect(&self.addr).unwrap();
self.join_handle.join().expect(&line!().to_string());
if let Ok(_) = TcpStream::connect(self.addr) {
self.join_handle.join().expect(&line!().to_string());
} else {
warn!("Best effort shutdown of serve thread failed");
}
}
}
pub fn serve_async<A, F, Request, Reply>(addr: A, f: F) -> io::Result<Shutdown>
pub fn serve_async<A, F, Request, Reply>(addr: A, f: F) -> io::Result<ServeHandle>
where A: ToSocketAddrs,
Request: 'static + fmt::Debug + Send + serde::de::Deserialize + fmt::Debug + serde::ser::Serialize,
Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize,
Reply: 'static + fmt::Debug + serde::ser::Serialize,
F: 'static + Clone + Serve<Request, Reply>,
F: 'static + Clone + Send + Serve<Request, Reply>,
{
let addr = addr.to_socket_addrs().unwrap().next().unwrap();
let listener = try!(TcpListener::bind(&addr));
let addr = try!(listener.local_addr());
info!("Spinning up server on {:?}", addr);
let listener = try!(TcpListener::bind(addr.clone()));
let (die_tx, die_rx) = channel();
let join_handle = thread::spawn(move || {
for conn in listener.incoming() {
@@ -135,7 +141,7 @@ pub fn serve_async<A, F, Request, Reply>(addr: A, f: F) -> io::Result<Shutdown>
});
}
});
Ok(Shutdown {
Ok(ServeHandle {
tx: die_tx,
join_handle: join_handle,
addr: addr.clone(),
@@ -260,21 +266,11 @@ impl<Request, Reply> Drop for Client<Request, Reply>
#[cfg(test)]
mod test {
use serde;
use super::*;
use std::fmt;
use std::net::{TcpStream, SocketAddr, ToSocketAddrs};
use std::net::TcpStream;
use std::sync::{Arc, Mutex, Barrier};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
const PORT: AtomicUsize = AtomicUsize::new(10000);
fn next_addr() -> SocketAddr {
let addr = format!("127.0.0.1:{}", PORT.fetch_add(1, Ordering::SeqCst));
addr.to_socket_addrs().unwrap().next().unwrap()
}
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
enum Request {
Increment,
@@ -308,38 +304,22 @@ mod test {
}
}
fn serve_on_any_addr<F, Request, Reply>(server: F) -> (SocketAddr, Shutdown)
where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + fmt::Debug + serde::ser::Serialize,
Reply: 'static + fmt::Debug + Send + serde::ser::Serialize,
F: 'static + Clone + Serve<Request, Reply>
{
let mut addr;
let mut shutdown;
while let &Err(_) = {
addr = next_addr();
shutdown = serve_async(&addr, server.clone());
&shutdown
} {
}
(addr, shutdown.unwrap())
}
#[test]
fn test_handle() {
let server = Arc::new(Server::new());
let (addr, shutdown) = serve_on_any_addr(server.clone());
let client_stream = TcpStream::connect(&addr).unwrap();
let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap();
let client_stream = TcpStream::connect(serve_handle.local_addr()).unwrap();
let client: Client<Request, Reply> = Client::new(client_stream)
.expect(&line!().to_string());
drop(client);
shutdown.shutdown();
serve_handle.shutdown();
}
#[test]
fn test() {
let server = Arc::new(Server::new());
let (addr, shutdown) = serve_on_any_addr(server.clone());
let client_stream = TcpStream::connect(&addr).unwrap();
let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap();
let client_stream = TcpStream::connect(serve_handle.local_addr()).unwrap();
let client = Client::new(client_stream).unwrap();
assert_eq!(Reply::Increment(0),
client.rpc(&Request::Increment).unwrap());
@@ -348,7 +328,7 @@ mod test {
client.rpc(&Request::Increment).unwrap());
assert_eq!(2, server.count());
drop(client);
shutdown.shutdown();
serve_handle.shutdown();
}
struct BarrierServer {
@@ -379,8 +359,8 @@ mod test {
#[test]
fn test_concurrent() {
let server = Arc::new(BarrierServer::new(10));
let (addr, shutdown) = serve_on_any_addr(server.clone());
let client_stream = TcpStream::connect(&addr).unwrap();
let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap();
let client_stream = TcpStream::connect(serve_handle.local_addr()).unwrap();
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(client_stream).unwrap());
let mut join_handles = vec![];
for _ in 0..10 {
@@ -396,6 +376,6 @@ mod test {
Ok(c) => c,
};
drop(client);
shutdown.shutdown();
serve_handle.shutdown();
}
}

View File

@@ -54,7 +54,7 @@ macro_rules! rpc_service { ($server:ident:
use std::sync::Arc;
use tarpc::{
self,
Shutdown,
ServeHandle,
serve_async,
};
@@ -151,7 +151,7 @@ macro_rules! rpc_service { ($server:ident:
}
#[doc="Start a running service."]
pub fn serve<A, S>(addr: A, service: S) -> Result<Shutdown>
pub fn serve<A, S>(addr: A, service: S) -> Result<ServeHandle>
where A: ToSocketAddrs,
S: 'static + Service
{