From ebd245a93dd02b88539df06fbc0c2c2e27944ff0 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Tue, 28 Jul 2020 22:06:44 -0700 Subject: [PATCH] Rewrite pubsub example to have the subscriber connect to the publisher. Fixes https://github.com/google/tarpc/issues/313 --- tarpc/examples/pubsub.rs | 199 ++++++++++++++++++++++----------------- 1 file changed, 110 insertions(+), 89 deletions(-) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 532b845..00b208b 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -4,20 +4,25 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use futures::{future, prelude::*}; +use futures::{ + future::{self, AbortHandle}, + prelude::*, +}; +use log::info; use publisher::Publisher as _; use std::{ collections::HashMap, io, net::SocketAddr, sync::{Arc, Mutex}, - time::Duration, }; use subscriber::Subscriber as _; use tarpc::{ client, context, - server::{self, Handler}, + serde_transport::tcp, + server::{self, Channel}, }; +use tokio::net::ToSocketAddrs; use tokio_serde::formats::Json; pub mod subscriber { @@ -28,90 +33,130 @@ pub mod subscriber { } pub mod publisher { - use std::net::SocketAddr; - #[tarpc::service] pub trait Publisher { async fn broadcast(message: String); - async fn subscribe(id: u32, address: SocketAddr) -> Result<(), String>; - async fn unsubscribe(id: u32); } } #[derive(Clone, Debug)] struct Subscriber { - id: u32, + local_addr: SocketAddr, } #[tarpc::server] impl subscriber::Subscriber for Subscriber { async fn receive(self, _: context::Context, message: String) { - eprintln!("{} received message: {}", self.id, message); + info!("{} received message: {}", self.local_addr, message); + } +} + +struct SubscriberHandle(AbortHandle); + +impl Drop for SubscriberHandle { + fn drop(&mut self) { + self.0.abort(); } } impl Subscriber { - async fn listen(id: u32, config: server::Config) -> io::Result { - let incoming = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) - .await? - .filter_map(|r| future::ready(r.ok())); - let addr = incoming.get_ref().local_addr(); - tokio::spawn( - server::new(config) - .incoming(incoming) - .take(1) - .respond_with(Subscriber { id }.serve()), + async fn connect(publisher_addr: impl ToSocketAddrs) -> io::Result { + let publisher = tcp::connect(publisher_addr, Json::default()).await?; + let local_addr = publisher.local_addr()?; + let (handler, abort_handle) = future::abortable( + server::BaseChannel::with_defaults(publisher) + .respond_with(Subscriber { local_addr }.serve()) + .execute(), ); - Ok(addr) + tokio::spawn(handler); + Ok(SubscriberHandle(abort_handle)) } } #[derive(Clone, Debug)] struct Publisher { - clients: Arc>>, + clients: Arc>>, +} + +struct PublisherAddrs { + publisher: SocketAddr, + subscriptions: SocketAddr, } impl Publisher { - fn new() -> Publisher { - Publisher { - clients: Arc::new(Mutex::new(HashMap::new())), - } + async fn start(self) -> io::Result { + let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; + + let publisher_addrs = PublisherAddrs { + publisher: connecting_publishers.local_addr(), + subscriptions: Self::start_subscription_manager(self.clients.clone()).await?, + }; + + tokio::spawn(async move { + // Because this is just an example, we know there will only be one publisher. In more + // realistic code, this would be a loop to continually accept new publisher + // connections. + let publisher = connecting_publishers.next().await.unwrap().unwrap(); + info!("[{}] publisher connected.", publisher.peer_addr().unwrap()); + + server::BaseChannel::with_defaults(publisher) + .respond_with(self.serve()) + .execute() + .await + }); + + Ok(publisher_addrs) + } + + async fn start_subscription_manager( + clients: Arc>>, + ) -> io::Result { + let mut connecting_subscribers = tcp::listen("localhost:0", Json::default) + .await? + .filter_map(|r| future::ready(r.ok())); + let new_subscriber_addr = connecting_subscribers.get_ref().local_addr(); + + tokio::spawn(async move { + while let Some(conn) = connecting_subscribers.next().await { + let subscriber_addr = conn.peer_addr().unwrap(); + info!("[{}] subscriber connected.", subscriber_addr); + + let tarpc::client::NewClient { + client: subscriber, + dispatch, + } = subscriber::SubscriberClient::new(client::Config::default(), conn); + clients.lock().unwrap().insert(subscriber_addr, subscriber); + + let dropped_clients = clients.clone(); + tokio::spawn(async move { + match dispatch.await { + Ok(()) => info!("[{:?}] subscriber connection closed", subscriber_addr), + Err(e) => info!( + "[{:?}] subscriber connection broken: {}", + subscriber_addr, e + ), + } + dropped_clients.lock().unwrap().remove(&subscriber_addr); + }); + } + }); + + Ok(new_subscriber_addr) } } #[tarpc::server] impl publisher::Publisher for Publisher { async fn broadcast(self, _: context::Context, message: String) { + info!("received message to broadcast."); let mut clients = self.clients.lock().unwrap().clone(); + let mut publications = Vec::new(); for client in clients.values_mut() { - // Ignore failing subscribers. In a real pubsub, - // you'd want to continually retry until subscribers - // ack. - let _ = client.receive(context::current(), message.clone()).await; - } - } - - async fn subscribe(self, _: context::Context, id: u32, addr: SocketAddr) -> Result<(), String> { - let conn = tarpc::serde_transport::tcp::connect(addr, Json::default()) - .await - .map_err(|e| e.to_string())?; - let subscriber = subscriber::SubscriberClient::new(client::Config::default(), conn) - .spawn() - .map_err(|e| e.to_string())?; - eprintln!("Subscribing {}.", id); - self.clients.lock().unwrap().insert(id, subscriber); - Ok(()) - } - - async fn unsubscribe(self, _: context::Context, id: u32) { - eprintln!("Unsubscribing {}", id); - let mut clients = self.clients.lock().unwrap(); - if clients.remove(&id).is_none() { - eprintln!( - "Client {} not found. Existings clients: {:?}", - id, &*clients - ); + publications.push(client.receive(context::current(), message.clone())); } + // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until + // subscribers ack. Of course, a lot would be different in a real pubsub :) + future::join_all(publications).await; } } @@ -119,50 +164,26 @@ impl publisher::Publisher for Publisher { async fn main() -> io::Result<()> { env_logger::init(); - let transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) - .await? - .filter_map(|r| future::ready(r.ok())); - let publisher_addr = transport.get_ref().local_addr(); - tokio::spawn( - transport - .take(1) - .map(server::BaseChannel::with_defaults) - .respond_with(Publisher::new().serve()), - ); + let clients = Arc::new(Mutex::new(HashMap::new())); + let addrs = Publisher { clients }.start().await?; - let subscriber1 = Subscriber::listen(0, server::Config::default()).await?; - let subscriber2 = Subscriber::listen(1, server::Config::default()).await?; + let mut publisher = publisher::PublisherClient::new( + client::Config::default(), + tcp::connect(addrs.publisher, Json::default()).await?, + ) + .spawn()?; - let publisher_conn = tarpc::serde_transport::tcp::connect(publisher_addr, Json::default()); - let publisher_conn = publisher_conn.await?; - let mut publisher = - publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?; + let _subscriber0 = Subscriber::connect(addrs.subscriptions).await?; + publisher + .broadcast(context::current(), "hello to one".to_string()) + .await?; - if let Err(e) = publisher - .subscribe(context::current(), 0, subscriber1) - .await? - { - eprintln!("Couldn't subscribe subscriber 0: {}", e); - } - if let Err(e) = publisher - .subscribe(context::current(), 1, subscriber2) - .await? - { - eprintln!("Couldn't subscribe subscriber 1: {}", e); - } - - println!("Broadcasting..."); + let _subscriber1 = Subscriber::connect(addrs.subscriptions).await?; publisher .broadcast(context::current(), "hello to all".to_string()) .await?; - publisher.unsubscribe(context::current(), 1).await?; - publisher - .broadcast(context::current(), "hi again".to_string()) - .await?; - drop(publisher); - tokio::time::delay_for(Duration::from_millis(100)).await; - println!("Done."); + info!("done."); Ok(()) }