diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 2bdcb4e..af65265 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -27,6 +27,7 @@ /// T9 |-----(Receive OK)------------------------------------------------->| /// T10 | | | /// T11 | |<--------------(Publish OK)------| +use anyhow::anyhow; use futures::{ channel::oneshot, future::{self, AbortHandle}, @@ -39,7 +40,6 @@ use std::{ io, net::SocketAddr, sync::{Arc, Mutex, RwLock}, - time::Duration, }; use subscriber::Subscriber as _; use tarpc::{ @@ -59,7 +59,6 @@ pub mod subscriber { } pub mod publisher { - #[tarpc::service] pub trait Publisher { async fn publish(topic: String, message: String); @@ -98,14 +97,22 @@ impl Subscriber { async fn connect( publisher_addr: impl ToSocketAddrs, topics: Vec, - ) -> io::Result { + ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default()).await?; let local_addr = publisher.local_addr()?; - let (handler, abort_handle) = future::abortable( - server::BaseChannel::with_defaults(publisher) - .respond_with(Subscriber { local_addr, topics }.serve()) - .execute(), - ); + let mut handler = server::BaseChannel::with_defaults(publisher) + .respond_with(Subscriber { local_addr, topics }.serve()); + // The first request is for the topics being subscriibed to. + match handler.next().await { + Some(init_topics) => init_topics?.await, + None => { + return Err(anyhow!( + "[{}] Server never initialized the subscriber.", + local_addr + )) + } + }; + let (handler, abort_handle) = future::abortable(handler.execute()); tokio::spawn(async move { match handler.await { Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr), @@ -168,7 +175,6 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - info!("[{}] subscriber connected.", subscriber_addr); let tarpc::client::NewClient { client: mut subscriber, @@ -183,32 +189,12 @@ impl Publisher { ); let (ready_tx, ready) = oneshot::channel(); - let me = self.clone(); - tokio::spawn(async move { - match dispatch.await { - Ok(()) => info!("[{:?}] subscriber connection closed", subscriber_addr), - Err(e) => info!( - "[{:?}] subscriber connection broken: {:?}", - subscriber_addr, e - ), - } - // Don't clean up the subscriber until initialization is done. - let _ = ready.await; - if let Some(subscription) = me.clients.lock().unwrap().remove(&subscriber_addr) - { - let mut subscriptions = me.subscriptions.write().unwrap(); - for topic in subscription.topics { - let subscribers = subscriptions.get_mut(&topic).unwrap(); - subscribers.remove(&subscriber_addr); - if subscribers.is_empty() { - subscriptions.remove(&topic); - } - } - } - }); + self.clone() + .start_subscriber_gc(subscriber_addr, dispatch, ready); // Populate the topics if let Ok(topics) = subscriber.topics(context::current()).await { + info!("[{}] subscribed to topics: {:?}", subscriber_addr, topics); let mut subscriptions = self.subscriptions.write().unwrap(); for topic in topics { subscriptions @@ -224,6 +210,35 @@ impl Publisher { Ok(new_subscriber_addr) } + + fn start_subscriber_gc( + self, + subscriber_addr: SocketAddr, + client_dispatch: impl Future> + Send + 'static, + subscriber_ready: oneshot::Receiver<()>, + ) { + tokio::spawn(async move { + match client_dispatch.await { + Ok(()) => info!("[{:?}] subscriber connection closed", subscriber_addr), + Err(e) => info!( + "[{:?}] subscriber connection broken: {:?}", + subscriber_addr, e + ), + } + // Don't clean up the subscriber until initialization is done. + let _ = subscriber_ready.await; + if let Some(subscription) = self.clients.lock().unwrap().remove(&subscriber_addr) { + let mut subscriptions = self.subscriptions.write().unwrap(); + for topic in subscription.topics { + let subscribers = subscriptions.get_mut(&topic).unwrap(); + subscribers.remove(&subscriber_addr); + if subscribers.is_empty() { + subscriptions.remove(&topic); + } + } + } + }); + } } #[tarpc::server] @@ -234,19 +249,13 @@ impl publisher::Publisher for Publisher { None => return, Some(subscriptions) => subscriptions.clone(), }; - tokio::spawn(async move { - let mut publications = Vec::new(); - for client in subscribers.values_mut() { - publications.push(client.receive( - context::current(), - topic.clone(), - message.clone(), - )); - } - // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until - // subscribers ack. Of course, a lot would be different in a real pubsub :) - future::join_all(publications).await; - }); + let mut publications = Vec::new(); + for client in subscribers.values_mut() { + publications.push(client.receive(context::current(), topic.clone(), message.clone())); + } + // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until + // subscribers ack. Of course, a lot would be different in a real pubsub :) + future::join_all(publications).await; } } @@ -280,9 +289,6 @@ async fn main() -> anyhow::Result<()> { ) .spawn()?; - // Wait a moment for subscribers to get set up. - tokio::time::delay_for(Duration::from_millis(25)).await; - publisher .publish(context::current(), "calculus".into(), "sqrt(2)".into()) .await?; @@ -299,9 +305,6 @@ async fn main() -> anyhow::Result<()> { .publish(context::current(), "history".into(), "napoleon".to_string()) .await?; - // Wait a moment for the last publication. - tokio::time::delay_for(Duration::from_millis(25)).await; - info!("done."); Ok(())