From a67145724305c1c32d26073276a82eb7e5062576 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Wed, 29 Jul 2020 22:51:04 -0700 Subject: [PATCH] Add topics to PubSub example --- tarpc/examples/pubsub.rs | 167 +++++++++++++++++++++++++++++---------- 1 file changed, 125 insertions(+), 42 deletions(-) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index a4d487d..2bdcb4e 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -13,7 +13,7 @@ /// Subscriber service. /// /// - Publishers connect to the server on the "publisher"port, and once connected, they send -/// messages to the server to be broadcast via a Publisher service. +/// messages to the server to be publish via a Publisher service. /// /// Subscriber Publisher PubSub Server /// T1 | | | @@ -27,8 +27,8 @@ /// T9 |-----(Receive OK)------------------------------------------------->| /// T10 | | | /// T11 | |<--------------(Publish OK)------| - use futures::{ + channel::oneshot, future::{self, AbortHandle}, prelude::*, }; @@ -38,7 +38,8 @@ use std::{ collections::HashMap, io, net::SocketAddr, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, RwLock}, + time::Duration, }; use subscriber::Subscriber as _; use tarpc::{ @@ -52,26 +53,36 @@ use tokio_serde::formats::Json; pub mod subscriber { #[tarpc::service] pub trait Subscriber { - async fn receive(message: String); + async fn topics() -> Vec; + async fn receive(topic: String, message: String); } } pub mod publisher { + #[tarpc::service] pub trait Publisher { - async fn broadcast(message: String); + async fn publish(topic: String, message: String); } } #[derive(Clone, Debug)] struct Subscriber { local_addr: SocketAddr, + topics: Vec, } #[tarpc::server] impl subscriber::Subscriber for Subscriber { - async fn receive(self, _: context::Context, message: String) { - info!("[{}] received message: {}", self.local_addr, message); + async fn topics(self, _: context::Context) -> Vec { + self.topics.clone() + } + + async fn receive(self, _: context::Context, topic: String, message: String) { + info!( + "[{}] received message on topic '{}': {}", + self.local_addr, topic, message + ); } } @@ -84,12 +95,15 @@ impl Drop for SubscriberHandle { } impl Subscriber { - async fn connect(publisher_addr: impl ToSocketAddrs) -> io::Result { + async fn connect( + publisher_addr: impl ToSocketAddrs, + topics: Vec, + ) -> io::Result { let publisher = tcp::connect(publisher_addr, Json::default()).await?; let local_addr = publisher.local_addr()?; let (handler, abort_handle) = future::abortable( server::BaseChannel::with_defaults(publisher) - .respond_with(Subscriber { local_addr }.serve()) + .respond_with(Subscriber { local_addr, topics }.serve()) .execute(), ); tokio::spawn(async move { @@ -101,9 +115,16 @@ impl Subscriber { } } +#[derive(Debug)] +struct Subscription { + subscriber: subscriber::SubscriberClient, + topics: Vec, +} + #[derive(Clone, Debug)] struct Publisher { - clients: Arc>>, + clients: Arc>>, + subscriptions: Arc>>>, } struct PublisherAddrs { @@ -117,7 +138,7 @@ impl Publisher { let publisher_addrs = PublisherAddrs { publisher: connecting_publishers.local_addr(), - subscriptions: Self::start_subscription_manager(self.clients.clone()).await?, + subscriptions: self.clone().start_subscription_manager().await?, }; info!("[{}] listening for publishers.", publisher_addrs.publisher); @@ -137,9 +158,7 @@ impl Publisher { Ok(publisher_addrs) } - async fn start_subscription_manager( - clients: Arc>>, - ) -> io::Result { + async fn start_subscription_manager(self) -> io::Result { let mut connecting_subscribers = tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())); @@ -152,12 +171,19 @@ impl Publisher { info!("[{}] subscriber connected.", subscriber_addr); let tarpc::client::NewClient { - client: subscriber, + client: mut subscriber, dispatch, } = subscriber::SubscriberClient::new(client::Config::default(), conn); - clients.lock().unwrap().insert(subscriber_addr, subscriber); + self.clients.lock().unwrap().insert( + subscriber_addr, + Subscription { + subscriber: subscriber.clone(), + topics: Vec::new(), + }, + ); - let dropped_clients = clients.clone(); + let (ready_tx, ready) = oneshot::channel(); + let me = self.clone(); tokio::spawn(async move { match dispatch.await { Ok(()) => info!("[{:?}] subscriber connection closed", subscriber_addr), @@ -166,8 +192,33 @@ impl Publisher { subscriber_addr, e ), } - dropped_clients.lock().unwrap().remove(&subscriber_addr); + // Don't clean up the subscriber until initialization is done. + let _ = ready.await; + if let Some(subscription) = me.clients.lock().unwrap().remove(&subscriber_addr) + { + let mut subscriptions = me.subscriptions.write().unwrap(); + for topic in subscription.topics { + let subscribers = subscriptions.get_mut(&topic).unwrap(); + subscribers.remove(&subscriber_addr); + if subscribers.is_empty() { + subscriptions.remove(&topic); + } + } + } }); + + // Populate the topics + if let Ok(topics) = subscriber.topics(context::current()).await { + let mut subscriptions = self.subscriptions.write().unwrap(); + for topic in topics { + subscriptions + .entry(topic) + .or_insert_with(HashMap::new) + .insert(subscriber_addr, subscriber.clone()); + } + } + // Signal that initialization is done. + ready_tx.send(()).unwrap(); } }); @@ -177,25 +228,51 @@ impl Publisher { #[tarpc::server] impl publisher::Publisher for Publisher { - async fn broadcast(self, _: context::Context, message: String) { - info!("received message to broadcast."); - let mut clients = self.clients.lock().unwrap().clone(); - let mut publications = Vec::new(); - for client in clients.values_mut() { - publications.push(client.receive(context::current(), message.clone())); - } - // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until - // subscribers ack. Of course, a lot would be different in a real pubsub :) - future::join_all(publications).await; + async fn publish(self, _: context::Context, topic: String, message: String) { + info!("received message to publish."); + let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { + None => return, + Some(subscriptions) => subscriptions.clone(), + }; + tokio::spawn(async move { + let mut publications = Vec::new(); + for client in subscribers.values_mut() { + publications.push(client.receive( + context::current(), + topic.clone(), + message.clone(), + )); + } + // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until + // subscribers ack. Of course, a lot would be different in a real pubsub :) + future::join_all(publications).await; + }); } } #[tokio::main] -async fn main() -> io::Result<()> { +async fn main() -> anyhow::Result<()> { env_logger::init(); let clients = Arc::new(Mutex::new(HashMap::new())); - let addrs = Publisher { clients }.start().await?; + let addrs = Publisher { + clients, + subscriptions: Arc::new(RwLock::new(HashMap::new())), + } + .start() + .await?; + + let _subscriber0 = Subscriber::connect( + addrs.subscriptions, + vec!["calculus".into(), "cool shorts".into()], + ) + .await?; + + let _subscriber1 = Subscriber::connect( + addrs.subscriptions, + vec!["cool shorts".into(), "history".into()], + ) + .await?; let mut publisher = publisher::PublisherClient::new( client::Config::default(), @@ -203,22 +280,28 @@ async fn main() -> io::Result<()> { ) .spawn()?; - let _subscriber0 = Subscriber::connect(addrs.subscriptions).await?; - publisher - .broadcast(context::current(), "hello to one".to_string()) - .await?; - - let _subscriber1 = Subscriber::connect(addrs.subscriptions).await?; - publisher - .broadcast(context::current(), "hello to all".to_string()) - .await?; - - drop(_subscriber0); + // Wait a moment for subscribers to get set up. + tokio::time::delay_for(Duration::from_millis(25)).await; publisher - .broadcast(context::current(), "hello to who?".to_string()) + .publish(context::current(), "calculus".into(), "sqrt(2)".into()) .await?; + publisher + .publish( + context::current(), + "cool shorts".into(), + "hello to all".into(), + ) + .await?; + + publisher + .publish(context::current(), "history".into(), "napoleon".to_string()) + .await?; + + // Wait a moment for the last publication. + tokio::time::delay_for(Duration::from_millis(25)).await; + info!("done."); Ok(())