Rewrite pubsub example to have the subscriber connect to the publisher.

Fixes https://github.com/google/tarpc/issues/313
This commit is contained in:
Tim Kuehn
2020-07-28 22:06:44 -07:00
parent 3ebc3b5845
commit ebd245a93d

View File

@@ -4,20 +4,25 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use futures::{future, prelude::*};
use futures::{
future::{self, AbortHandle},
prelude::*,
};
use log::info;
use publisher::Publisher as _;
use std::{
collections::HashMap,
io,
net::SocketAddr,
sync::{Arc, Mutex},
time::Duration,
};
use subscriber::Subscriber as _;
use tarpc::{
client, context,
server::{self, Handler},
serde_transport::tcp,
server::{self, Channel},
};
use tokio::net::ToSocketAddrs;
use tokio_serde::formats::Json;
pub mod subscriber {
@@ -28,90 +33,130 @@ pub mod subscriber {
}
pub mod publisher {
use std::net::SocketAddr;
#[tarpc::service]
pub trait Publisher {
async fn broadcast(message: String);
async fn subscribe(id: u32, address: SocketAddr) -> Result<(), String>;
async fn unsubscribe(id: u32);
}
}
#[derive(Clone, Debug)]
struct Subscriber {
id: u32,
local_addr: SocketAddr,
}
#[tarpc::server]
impl subscriber::Subscriber for Subscriber {
async fn receive(self, _: context::Context, message: String) {
eprintln!("{} received message: {}", self.id, message);
info!("{} received message: {}", self.local_addr, message);
}
}
struct SubscriberHandle(AbortHandle);
impl Drop for SubscriberHandle {
fn drop(&mut self) {
self.0.abort();
}
}
impl Subscriber {
async fn listen(id: u32, config: server::Config) -> io::Result<SocketAddr> {
let incoming = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let addr = incoming.get_ref().local_addr();
tokio::spawn(
server::new(config)
.incoming(incoming)
.take(1)
.respond_with(Subscriber { id }.serve()),
async fn connect(publisher_addr: impl ToSocketAddrs) -> io::Result<SubscriberHandle> {
let publisher = tcp::connect(publisher_addr, Json::default()).await?;
let local_addr = publisher.local_addr()?;
let (handler, abort_handle) = future::abortable(
server::BaseChannel::with_defaults(publisher)
.respond_with(Subscriber { local_addr }.serve())
.execute(),
);
Ok(addr)
tokio::spawn(handler);
Ok(SubscriberHandle(abort_handle))
}
}
#[derive(Clone, Debug)]
struct Publisher {
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
clients: Arc<Mutex<HashMap<SocketAddr, subscriber::SubscriberClient>>>,
}
struct PublisherAddrs {
publisher: SocketAddr,
subscriptions: SocketAddr,
}
impl Publisher {
fn new() -> Publisher {
Publisher {
clients: Arc::new(Mutex::new(HashMap::new())),
}
async fn start(self) -> io::Result<PublisherAddrs> {
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
let publisher_addrs = PublisherAddrs {
publisher: connecting_publishers.local_addr(),
subscriptions: Self::start_subscription_manager(self.clients.clone()).await?,
};
tokio::spawn(async move {
// Because this is just an example, we know there will only be one publisher. In more
// realistic code, this would be a loop to continually accept new publisher
// connections.
let publisher = connecting_publishers.next().await.unwrap().unwrap();
info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
server::BaseChannel::with_defaults(publisher)
.respond_with(self.serve())
.execute()
.await
});
Ok(publisher_addrs)
}
async fn start_subscription_manager(
clients: Arc<Mutex<HashMap<SocketAddr, subscriber::SubscriberClient>>>,
) -> io::Result<SocketAddr> {
let mut connecting_subscribers = tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
tokio::spawn(async move {
while let Some(conn) = connecting_subscribers.next().await {
let subscriber_addr = conn.peer_addr().unwrap();
info!("[{}] subscriber connected.", subscriber_addr);
let tarpc::client::NewClient {
client: subscriber,
dispatch,
} = subscriber::SubscriberClient::new(client::Config::default(), conn);
clients.lock().unwrap().insert(subscriber_addr, subscriber);
let dropped_clients = clients.clone();
tokio::spawn(async move {
match dispatch.await {
Ok(()) => info!("[{:?}] subscriber connection closed", subscriber_addr),
Err(e) => info!(
"[{:?}] subscriber connection broken: {}",
subscriber_addr, e
),
}
dropped_clients.lock().unwrap().remove(&subscriber_addr);
});
}
});
Ok(new_subscriber_addr)
}
}
#[tarpc::server]
impl publisher::Publisher for Publisher {
async fn broadcast(self, _: context::Context, message: String) {
info!("received message to broadcast.");
let mut clients = self.clients.lock().unwrap().clone();
let mut publications = Vec::new();
for client in clients.values_mut() {
// Ignore failing subscribers. In a real pubsub,
// you'd want to continually retry until subscribers
// ack.
let _ = client.receive(context::current(), message.clone()).await;
}
}
async fn subscribe(self, _: context::Context, id: u32, addr: SocketAddr) -> Result<(), String> {
let conn = tarpc::serde_transport::tcp::connect(addr, Json::default())
.await
.map_err(|e| e.to_string())?;
let subscriber = subscriber::SubscriberClient::new(client::Config::default(), conn)
.spawn()
.map_err(|e| e.to_string())?;
eprintln!("Subscribing {}.", id);
self.clients.lock().unwrap().insert(id, subscriber);
Ok(())
}
async fn unsubscribe(self, _: context::Context, id: u32) {
eprintln!("Unsubscribing {}", id);
let mut clients = self.clients.lock().unwrap();
if clients.remove(&id).is_none() {
eprintln!(
"Client {} not found. Existings clients: {:?}",
id, &*clients
);
publications.push(client.receive(context::current(), message.clone()));
}
// Ignore failing subscribers. In a real pubsub, you'd want to continually retry until
// subscribers ack. Of course, a lot would be different in a real pubsub :)
future::join_all(publications).await;
}
}
@@ -119,50 +164,26 @@ impl publisher::Publisher for Publisher {
async fn main() -> io::Result<()> {
env_logger::init();
let transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()));
let publisher_addr = transport.get_ref().local_addr();
tokio::spawn(
transport
.take(1)
.map(server::BaseChannel::with_defaults)
.respond_with(Publisher::new().serve()),
);
let clients = Arc::new(Mutex::new(HashMap::new()));
let addrs = Publisher { clients }.start().await?;
let subscriber1 = Subscriber::listen(0, server::Config::default()).await?;
let subscriber2 = Subscriber::listen(1, server::Config::default()).await?;
let mut publisher = publisher::PublisherClient::new(
client::Config::default(),
tcp::connect(addrs.publisher, Json::default()).await?,
)
.spawn()?;
let publisher_conn = tarpc::serde_transport::tcp::connect(publisher_addr, Json::default());
let publisher_conn = publisher_conn.await?;
let mut publisher =
publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?;
let _subscriber0 = Subscriber::connect(addrs.subscriptions).await?;
publisher
.broadcast(context::current(), "hello to one".to_string())
.await?;
if let Err(e) = publisher
.subscribe(context::current(), 0, subscriber1)
.await?
{
eprintln!("Couldn't subscribe subscriber 0: {}", e);
}
if let Err(e) = publisher
.subscribe(context::current(), 1, subscriber2)
.await?
{
eprintln!("Couldn't subscribe subscriber 1: {}", e);
}
println!("Broadcasting...");
let _subscriber1 = Subscriber::connect(addrs.subscriptions).await?;
publisher
.broadcast(context::current(), "hello to all".to_string())
.await?;
publisher.unsubscribe(context::current(), 1).await?;
publisher
.broadcast(context::current(), "hi again".to_string())
.await?;
drop(publisher);
tokio::time::delay_for(Duration::from_millis(100)).await;
println!("Done.");
info!("done.");
Ok(())
}