mirror of
https://github.com/OMGeeky/tarpc.git
synced 2025-12-26 17:02:32 +01:00
Rewrite pubsub example to have the subscriber connect to the publisher.
Fixes https://github.com/google/tarpc/issues/313
This commit is contained in:
@@ -4,20 +4,25 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use futures::{future, prelude::*};
|
||||
use futures::{
|
||||
future::{self, AbortHandle},
|
||||
prelude::*,
|
||||
};
|
||||
use log::info;
|
||||
use publisher::Publisher as _;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
net::SocketAddr,
|
||||
sync::{Arc, Mutex},
|
||||
time::Duration,
|
||||
};
|
||||
use subscriber::Subscriber as _;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Handler},
|
||||
serde_transport::tcp,
|
||||
server::{self, Channel},
|
||||
};
|
||||
use tokio::net::ToSocketAddrs;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
pub mod subscriber {
|
||||
@@ -28,90 +33,130 @@ pub mod subscriber {
|
||||
}
|
||||
|
||||
pub mod publisher {
|
||||
use std::net::SocketAddr;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait Publisher {
|
||||
async fn broadcast(message: String);
|
||||
async fn subscribe(id: u32, address: SocketAddr) -> Result<(), String>;
|
||||
async fn unsubscribe(id: u32);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Subscriber {
|
||||
id: u32,
|
||||
local_addr: SocketAddr,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl subscriber::Subscriber for Subscriber {
|
||||
async fn receive(self, _: context::Context, message: String) {
|
||||
eprintln!("{} received message: {}", self.id, message);
|
||||
info!("{} received message: {}", self.local_addr, message);
|
||||
}
|
||||
}
|
||||
|
||||
struct SubscriberHandle(AbortHandle);
|
||||
|
||||
impl Drop for SubscriberHandle {
|
||||
fn drop(&mut self) {
|
||||
self.0.abort();
|
||||
}
|
||||
}
|
||||
|
||||
impl Subscriber {
|
||||
async fn listen(id: u32, config: server::Config) -> io::Result<SocketAddr> {
|
||||
let incoming = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let addr = incoming.get_ref().local_addr();
|
||||
tokio::spawn(
|
||||
server::new(config)
|
||||
.incoming(incoming)
|
||||
.take(1)
|
||||
.respond_with(Subscriber { id }.serve()),
|
||||
async fn connect(publisher_addr: impl ToSocketAddrs) -> io::Result<SubscriberHandle> {
|
||||
let publisher = tcp::connect(publisher_addr, Json::default()).await?;
|
||||
let local_addr = publisher.local_addr()?;
|
||||
let (handler, abort_handle) = future::abortable(
|
||||
server::BaseChannel::with_defaults(publisher)
|
||||
.respond_with(Subscriber { local_addr }.serve())
|
||||
.execute(),
|
||||
);
|
||||
Ok(addr)
|
||||
tokio::spawn(handler);
|
||||
Ok(SubscriberHandle(abort_handle))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Publisher {
|
||||
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
|
||||
clients: Arc<Mutex<HashMap<SocketAddr, subscriber::SubscriberClient>>>,
|
||||
}
|
||||
|
||||
struct PublisherAddrs {
|
||||
publisher: SocketAddr,
|
||||
subscriptions: SocketAddr,
|
||||
}
|
||||
|
||||
impl Publisher {
|
||||
fn new() -> Publisher {
|
||||
Publisher {
|
||||
clients: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
async fn start(self) -> io::Result<PublisherAddrs> {
|
||||
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
|
||||
|
||||
let publisher_addrs = PublisherAddrs {
|
||||
publisher: connecting_publishers.local_addr(),
|
||||
subscriptions: Self::start_subscription_manager(self.clients.clone()).await?,
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Because this is just an example, we know there will only be one publisher. In more
|
||||
// realistic code, this would be a loop to continually accept new publisher
|
||||
// connections.
|
||||
let publisher = connecting_publishers.next().await.unwrap().unwrap();
|
||||
info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
|
||||
|
||||
server::BaseChannel::with_defaults(publisher)
|
||||
.respond_with(self.serve())
|
||||
.execute()
|
||||
.await
|
||||
});
|
||||
|
||||
Ok(publisher_addrs)
|
||||
}
|
||||
|
||||
async fn start_subscription_manager(
|
||||
clients: Arc<Mutex<HashMap<SocketAddr, subscriber::SubscriberClient>>>,
|
||||
) -> io::Result<SocketAddr> {
|
||||
let mut connecting_subscribers = tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(conn) = connecting_subscribers.next().await {
|
||||
let subscriber_addr = conn.peer_addr().unwrap();
|
||||
info!("[{}] subscriber connected.", subscriber_addr);
|
||||
|
||||
let tarpc::client::NewClient {
|
||||
client: subscriber,
|
||||
dispatch,
|
||||
} = subscriber::SubscriberClient::new(client::Config::default(), conn);
|
||||
clients.lock().unwrap().insert(subscriber_addr, subscriber);
|
||||
|
||||
let dropped_clients = clients.clone();
|
||||
tokio::spawn(async move {
|
||||
match dispatch.await {
|
||||
Ok(()) => info!("[{:?}] subscriber connection closed", subscriber_addr),
|
||||
Err(e) => info!(
|
||||
"[{:?}] subscriber connection broken: {}",
|
||||
subscriber_addr, e
|
||||
),
|
||||
}
|
||||
dropped_clients.lock().unwrap().remove(&subscriber_addr);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
Ok(new_subscriber_addr)
|
||||
}
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl publisher::Publisher for Publisher {
|
||||
async fn broadcast(self, _: context::Context, message: String) {
|
||||
info!("received message to broadcast.");
|
||||
let mut clients = self.clients.lock().unwrap().clone();
|
||||
let mut publications = Vec::new();
|
||||
for client in clients.values_mut() {
|
||||
// Ignore failing subscribers. In a real pubsub,
|
||||
// you'd want to continually retry until subscribers
|
||||
// ack.
|
||||
let _ = client.receive(context::current(), message.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn subscribe(self, _: context::Context, id: u32, addr: SocketAddr) -> Result<(), String> {
|
||||
let conn = tarpc::serde_transport::tcp::connect(addr, Json::default())
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
let subscriber = subscriber::SubscriberClient::new(client::Config::default(), conn)
|
||||
.spawn()
|
||||
.map_err(|e| e.to_string())?;
|
||||
eprintln!("Subscribing {}.", id);
|
||||
self.clients.lock().unwrap().insert(id, subscriber);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unsubscribe(self, _: context::Context, id: u32) {
|
||||
eprintln!("Unsubscribing {}", id);
|
||||
let mut clients = self.clients.lock().unwrap();
|
||||
if clients.remove(&id).is_none() {
|
||||
eprintln!(
|
||||
"Client {} not found. Existings clients: {:?}",
|
||||
id, &*clients
|
||||
);
|
||||
publications.push(client.receive(context::current(), message.clone()));
|
||||
}
|
||||
// Ignore failing subscribers. In a real pubsub, you'd want to continually retry until
|
||||
// subscribers ack. Of course, a lot would be different in a real pubsub :)
|
||||
future::join_all(publications).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,50 +164,26 @@ impl publisher::Publisher for Publisher {
|
||||
async fn main() -> io::Result<()> {
|
||||
env_logger::init();
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let publisher_addr = transport.get_ref().local_addr();
|
||||
tokio::spawn(
|
||||
transport
|
||||
.take(1)
|
||||
.map(server::BaseChannel::with_defaults)
|
||||
.respond_with(Publisher::new().serve()),
|
||||
);
|
||||
let clients = Arc::new(Mutex::new(HashMap::new()));
|
||||
let addrs = Publisher { clients }.start().await?;
|
||||
|
||||
let subscriber1 = Subscriber::listen(0, server::Config::default()).await?;
|
||||
let subscriber2 = Subscriber::listen(1, server::Config::default()).await?;
|
||||
let mut publisher = publisher::PublisherClient::new(
|
||||
client::Config::default(),
|
||||
tcp::connect(addrs.publisher, Json::default()).await?,
|
||||
)
|
||||
.spawn()?;
|
||||
|
||||
let publisher_conn = tarpc::serde_transport::tcp::connect(publisher_addr, Json::default());
|
||||
let publisher_conn = publisher_conn.await?;
|
||||
let mut publisher =
|
||||
publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?;
|
||||
let _subscriber0 = Subscriber::connect(addrs.subscriptions).await?;
|
||||
publisher
|
||||
.broadcast(context::current(), "hello to one".to_string())
|
||||
.await?;
|
||||
|
||||
if let Err(e) = publisher
|
||||
.subscribe(context::current(), 0, subscriber1)
|
||||
.await?
|
||||
{
|
||||
eprintln!("Couldn't subscribe subscriber 0: {}", e);
|
||||
}
|
||||
if let Err(e) = publisher
|
||||
.subscribe(context::current(), 1, subscriber2)
|
||||
.await?
|
||||
{
|
||||
eprintln!("Couldn't subscribe subscriber 1: {}", e);
|
||||
}
|
||||
|
||||
println!("Broadcasting...");
|
||||
let _subscriber1 = Subscriber::connect(addrs.subscriptions).await?;
|
||||
publisher
|
||||
.broadcast(context::current(), "hello to all".to_string())
|
||||
.await?;
|
||||
publisher.unsubscribe(context::current(), 1).await?;
|
||||
publisher
|
||||
.broadcast(context::current(), "hi again".to_string())
|
||||
.await?;
|
||||
drop(publisher);
|
||||
|
||||
tokio::time::delay_for(Duration::from_millis(100)).await;
|
||||
println!("Done.");
|
||||
info!("done.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user