mirror of
https://github.com/OMGeeky/flucto-heisskleber.git
synced 2026-02-13 21:18:09 +01:00
Refactor/background tasks (#75)
* Add start, stop and __repr__ to sink and source types. * Merge conflicts on mqtt async pub and resampler. * Add start() and stop() functions to udp and zmq. Change tests accordingly. * Rename broker, ip, interface to common config name "host". * Updated "host" entry in config files. * Add lazyload to mqtt-source.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
broker: "10.47.36.1"
|
||||
host: "10.47.36.1"
|
||||
user: ""
|
||||
password: ""
|
||||
port: 1883
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
protocol : "tcp" # ipc protocol
|
||||
interface: "127.0.0.1" # the interface to bind to
|
||||
publisher_port : 5555 # port used by primary producers
|
||||
subscriber_port: 5556 # port used by primary consumers
|
||||
protocol: "tcp" # ipc protocol
|
||||
host: "127.0.0.1" # the interface to bind to
|
||||
publisher_port: 5555 # port used by primary producers
|
||||
subscriber_port: 5556 # port used by primary consumers
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .config import BaseConf
|
||||
from .parse import load_config
|
||||
from .parse import ConfigType, load_config
|
||||
|
||||
__all__ = ["load_config", "BaseConf"]
|
||||
__all__ = ["load_config", "BaseConf", "ConfigType"]
|
||||
|
||||
@@ -16,6 +16,15 @@ class ConsoleSink(Sink):
|
||||
else:
|
||||
print(verbose_topic + str(data))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})"
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncConsoleSink(AsyncSink):
|
||||
def __init__(self, pretty: bool = False, verbose: bool = False) -> None:
|
||||
@@ -29,6 +38,15 @@ class AsyncConsoleSink(AsyncSink):
|
||||
else:
|
||||
print(verbose_topic + str(data))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})"
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sink = ConsoleSink()
|
||||
|
||||
@@ -1,34 +1,98 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from queue import SimpleQueue
|
||||
from threading import Thread
|
||||
|
||||
from heisskleber.core.types import Serializable, Source
|
||||
from heisskleber.core.types import AsyncSource, Serializable, Source
|
||||
|
||||
|
||||
class ConsoleSource(Source):
|
||||
def __init__(self, topic: str | list[str] | tuple[str] = "console") -> None:
|
||||
self.topic = "console"
|
||||
def __init__(self, topic: str = "console") -> None:
|
||||
self.topic = topic
|
||||
self.queue = SimpleQueue()
|
||||
self.listener_daemon = Thread(target=self.listener_task, daemon=True)
|
||||
self.listener_daemon.start()
|
||||
self.pack = json.loads
|
||||
self.thread: Thread | None = None
|
||||
|
||||
def listener_task(self):
|
||||
while True:
|
||||
data = sys.stdin.readline()
|
||||
payload = self.pack(data)
|
||||
self.queue.put(payload)
|
||||
try:
|
||||
data = sys.stdin.readline()
|
||||
payload = self.pack(data)
|
||||
self.queue.put(payload)
|
||||
except json.decoder.JSONDecodeError:
|
||||
print("Invalid JSON")
|
||||
continue
|
||||
except ValueError:
|
||||
break
|
||||
print("listener task finished")
|
||||
|
||||
def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
if not self.thread:
|
||||
self.start()
|
||||
|
||||
data = self.queue.get()
|
||||
return self.topic, data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(topic={self.topic})"
|
||||
|
||||
def start(self) -> None:
|
||||
self.thread = Thread(target=self.listener_task, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.thread:
|
||||
sys.stdin.close()
|
||||
self.thread.join()
|
||||
|
||||
|
||||
class AsyncConsoleSource(AsyncSource):
|
||||
def __init__(self, topic: str = "console") -> None:
|
||||
self.topic = topic
|
||||
self.queue: asyncio.Queue[dict[str, Serializable]] = asyncio.Queue(maxsize=10)
|
||||
self.pack = json.loads
|
||||
self.task: asyncio.Task[None] | None = None
|
||||
|
||||
async def listener_task(self):
|
||||
while True:
|
||||
data = sys.stdin.readline()
|
||||
payload = self.pack(data)
|
||||
await self.queue.put(payload)
|
||||
|
||||
async def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
if not self.task:
|
||||
self.start()
|
||||
|
||||
data = await self.queue.get()
|
||||
return self.topic, data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(topic={self.topic})"
|
||||
|
||||
def start(self) -> None:
|
||||
self.task = asyncio.create_task(self.listener_task())
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
console_source = ConsoleSource()
|
||||
console_source.start()
|
||||
|
||||
while True:
|
||||
print(console_source.receive())
|
||||
time.sleep(1)
|
||||
print("Listening to console input.")
|
||||
|
||||
count = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
print(console_source.receive())
|
||||
time.sleep(1)
|
||||
count += 1
|
||||
print(count)
|
||||
except KeyboardInterrupt:
|
||||
print("Stopped")
|
||||
sys.exit(0)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from heisskleber.config.config import BaseConf
|
||||
from heisskleber.config import BaseConf
|
||||
|
||||
Serializable = Union[str, int, float]
|
||||
|
||||
@@ -23,12 +23,30 @@ class Sink(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def send(self, data: dict[str, Any], topic: str) -> None:
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Send data via the implemented output stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Source(ABC):
|
||||
"""
|
||||
@@ -53,25 +71,21 @@ class Source(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncSubscriber(ABC):
|
||||
"""
|
||||
AsyncSubscriber interface
|
||||
"""
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: Any, topic: str | list[str]) -> None:
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Initialize the subscriber with a topic and a configuration object.
|
||||
Start any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Blocking function to receive data from the implemented input stream.
|
||||
|
||||
Data is returned as a tuple of (topic, data).
|
||||
Stop any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -97,6 +111,24 @@ class AsyncSource(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncSink(ABC):
|
||||
"""
|
||||
@@ -118,3 +150,21 @@ class AsyncSink(ABC):
|
||||
Send data via the implemented output stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -9,7 +9,7 @@ class MqttConf(BaseConf):
|
||||
MQTT configuration class.
|
||||
"""
|
||||
|
||||
broker: str = "localhost"
|
||||
host: str = "localhost"
|
||||
user: str = ""
|
||||
password: str = ""
|
||||
port: int = 1883
|
||||
|
||||
@@ -34,8 +34,14 @@ class MqttBase:
|
||||
|
||||
def __init__(self, config: MqttConf) -> None:
|
||||
self.config = config
|
||||
self.client: mqtt_client | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
self.connect()
|
||||
self.client.loop_start()
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.client:
|
||||
self.client.loop_stop()
|
||||
|
||||
def connect(self) -> None:
|
||||
self.client = mqtt_client()
|
||||
@@ -52,7 +58,8 @@ class MqttBase:
|
||||
# the default certification authority of the system is used.
|
||||
self.client.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT)
|
||||
|
||||
self.client.connect(self.config.broker, self.config.port)
|
||||
self.client.connect(self.config.host, self.config.port)
|
||||
self.client.loop_start()
|
||||
|
||||
@staticmethod
|
||||
def _raise_if_thread_died() -> None:
|
||||
@@ -63,7 +70,7 @@ class MqttBase:
|
||||
# MQTT callbacks
|
||||
def _on_connect(self, client, userdata, flags, return_code) -> None:
|
||||
if return_code == 0:
|
||||
print(f"MQTT node connected to {self.config.broker}:{self.config.port}")
|
||||
print(f"MQTT node connected to {self.config.host}:{self.config.port}")
|
||||
else:
|
||||
print("Connection failed!")
|
||||
if self.config.verbose:
|
||||
@@ -84,4 +91,4 @@ class MqttBase:
|
||||
print(f"Received message: {message.payload!s}, topic: {message.topic}, qos: {message.qos}")
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.client.loop_stop()
|
||||
self.stop()
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import Sink
|
||||
from heisskleber.core.types import Serializable, Sink
|
||||
|
||||
from .config import MqttConf
|
||||
from .mqtt_base import MqttBase
|
||||
@@ -21,14 +19,26 @@ class MqttPublisher(MqttBase, Sink):
|
||||
super().__init__(config)
|
||||
self.pack = get_packer(config.packstyle)
|
||||
|
||||
def send(self, data: dict[str, Any], topic: str) -> None:
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
|
||||
Publishing is asynchronous
|
||||
"""
|
||||
if not self.client.is_connected():
|
||||
self.start()
|
||||
|
||||
self._raise_if_thread_died()
|
||||
|
||||
payload = self.pack(data)
|
||||
self.client.publish(topic, payload, qos=self.config.qos, retain=self.config.retain)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
|
||||
|
||||
def start(self) -> None:
|
||||
super().start()
|
||||
|
||||
def stop(self) -> None:
|
||||
super().stop()
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from asyncio import Queue, Task, create_task
|
||||
from asyncio import Queue, Task, create_task, sleep
|
||||
|
||||
import aiomqtt
|
||||
|
||||
@@ -22,8 +19,8 @@ class AsyncMqttPublisher(AsyncSink):
|
||||
def __init__(self, config: MqttConf) -> None:
|
||||
self.config = config
|
||||
self.pack = get_packer(config.packstyle)
|
||||
self._send_queue: Queue[tuple[dict[str, Serializable], str]] = Queue(maxsize=config.max_saved_messages)
|
||||
self._sender_task: Task[None] = create_task(self.send_work())
|
||||
self._send_queue: Queue[tuple[dict[str, Serializable], str]] = Queue()
|
||||
self._sender_task: Task[None] | None = None
|
||||
|
||||
async def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
@@ -32,6 +29,8 @@ class AsyncMqttPublisher(AsyncSink):
|
||||
|
||||
Publishing is asynchronous
|
||||
"""
|
||||
if not self._sender_task:
|
||||
self.start()
|
||||
|
||||
await self._send_queue.put((data, topic))
|
||||
|
||||
@@ -45,7 +44,7 @@ class AsyncMqttPublisher(AsyncSink):
|
||||
while True:
|
||||
try:
|
||||
async with aiomqtt.Client(
|
||||
hostname=self.config.broker,
|
||||
hostname=self.config.host,
|
||||
port=self.config.port,
|
||||
username=self.config.user,
|
||||
password=self.config.password,
|
||||
@@ -57,4 +56,15 @@ class AsyncMqttPublisher(AsyncSink):
|
||||
await client.publish(topic, payload)
|
||||
except aiomqtt.MqttError:
|
||||
print("Connection to MQTT broker failed. Retrying in 5 seconds")
|
||||
await asyncio.sleep(5)
|
||||
await sleep(5)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
|
||||
|
||||
def start(self) -> None:
|
||||
self._sender_task = create_task(self.send_work())
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._sender_task:
|
||||
self._sender_task.cancel()
|
||||
self._sender_task = None
|
||||
|
||||
@@ -22,9 +22,8 @@ class MqttSubscriber(MqttBase, Source):
|
||||
|
||||
def __init__(self, config: MqttConf, topics: str | list[str]) -> None:
|
||||
super().__init__(config)
|
||||
self.topics = topics
|
||||
self._message_queue: SimpleQueue[MQTTMessage] = SimpleQueue()
|
||||
self.subscribe(topics)
|
||||
self.client.on_message = self._on_message
|
||||
self.unpack = get_unpacker(config.packstyle)
|
||||
|
||||
def subscribe(self, topics: str | list[str] | tuple[str]) -> None:
|
||||
@@ -47,14 +46,29 @@ class MqttSubscriber(MqttBase, Source):
|
||||
Messages are saved in a stack, if no message is available, this function blocks.
|
||||
|
||||
Returns:
|
||||
tuple(topic: bytes, message: dict): the message received
|
||||
tuple(topic: str, message: dict): the message received
|
||||
"""
|
||||
if not self.client:
|
||||
self.start()
|
||||
|
||||
self._raise_if_thread_died()
|
||||
mqtt_message = self._message_queue.get(block=True, timeout=self.config.timeout_s)
|
||||
|
||||
message_returned = self.unpack(mqtt_message.payload.decode())
|
||||
return (mqtt_message.topic, message_returned)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
|
||||
|
||||
def start(self) -> None:
|
||||
super().start()
|
||||
self.subscribe(self.topics)
|
||||
self.client.on_message = self._on_message
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self) -> None:
|
||||
super().stop()
|
||||
|
||||
# callback to add incoming messages onto stack
|
||||
def _on_message(self, client, userdata, message) -> None:
|
||||
self._message_queue.put(message)
|
||||
|
||||
@@ -16,7 +16,7 @@ class AsyncMqttSubscriber(AsyncSource):
|
||||
def __init__(self, config: MqttConf, topic: str | list[str]) -> None:
|
||||
self.config: MqttConf = config
|
||||
self.client = Client(
|
||||
hostname=self.config.broker,
|
||||
hostname=self.config.host,
|
||||
port=self.config.port,
|
||||
username=self.config.user,
|
||||
password=self.config.password,
|
||||
@@ -24,17 +24,32 @@ class AsyncMqttSubscriber(AsyncSource):
|
||||
self.topics = topic
|
||||
self.unpack = get_unpacker(self.config.packstyle)
|
||||
self.message_queue: Queue[Message] = Queue(self.config.max_saved_messages)
|
||||
self._listener_task: Task = create_task(self.create_listener())
|
||||
self._listener_task: Task[None] | None = None
|
||||
|
||||
"""
|
||||
Await the newest message in the queue and return Tuple
|
||||
"""
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
|
||||
|
||||
def start(self) -> None:
|
||||
self._listener_task = create_task(self.run())
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._listener_task:
|
||||
self._listener_task.cancel()
|
||||
self._listener_task = None
|
||||
|
||||
async def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
mqtt_message: Message = await self.message_queue.get()
|
||||
"""
|
||||
Await the newest message in the queue and return Tuple
|
||||
"""
|
||||
if not self._listener_task:
|
||||
self.start()
|
||||
mqtt_message = await self.message_queue.get()
|
||||
return self._handle_message(mqtt_message)
|
||||
|
||||
async def create_listener(self):
|
||||
async def run(self):
|
||||
"""
|
||||
Handle the connection to MQTT broker and run the message loop.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
async with self.client:
|
||||
@@ -45,11 +60,10 @@ class AsyncMqttSubscriber(AsyncSource):
|
||||
print("Connection to MQTT failed. Retrying...")
|
||||
await sleep(1)
|
||||
|
||||
"""
|
||||
Listen to incoming messages asynchronously and put them into a queue
|
||||
"""
|
||||
|
||||
async def _listen_mqtt_loop(self) -> None:
|
||||
"""
|
||||
Listen to incoming messages asynchronously and put them into a queue
|
||||
"""
|
||||
async with self.client.messages() as messages:
|
||||
# async with self.client.filtered_messages(self.topics) as messages:
|
||||
async for message in messages:
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import argparse
|
||||
import sys
|
||||
from typing import Union
|
||||
from typing import Callable, Union
|
||||
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.console.sink import ConsoleSink
|
||||
from heisskleber.core.factories import _registered_sources
|
||||
|
||||
TopicType = Union[str, list[str]]
|
||||
from heisskleber.mqtt import MqttSubscriber
|
||||
from heisskleber.udp import UdpSubscriber
|
||||
from heisskleber.zmq import ZmqSubscriber
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
@@ -33,14 +34,12 @@ def parse_args() -> argparse.Namespace:
|
||||
"-H",
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Host or broker for MQTT, zmq and UDP.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-P",
|
||||
"--port",
|
||||
type=int,
|
||||
default=1883,
|
||||
help="Port or serial interface for MQTT, zmq and UDP.",
|
||||
)
|
||||
parser.add_argument("-v", "--verbose", action="store_true")
|
||||
@@ -49,7 +48,19 @@ def parse_args() -> argparse.Namespace:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def run() -> None:
|
||||
def keyboardexit(func) -> Callable:
|
||||
def wrapper(*args, **kwargs) -> Union[None, int]:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting...")
|
||||
sys.exit(0)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@keyboardexit
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
sink = ConsoleSink(pretty=args.pretty, verbose=args.verbose)
|
||||
|
||||
@@ -58,36 +69,27 @@ def run() -> None:
|
||||
try:
|
||||
config = load_config(conf_cls(), args.type, read_commandline=False)
|
||||
except FileNotFoundError:
|
||||
print(f"Using default config for {args.type}.")
|
||||
print(f"No config file found for {args.type}, using default values and user input.")
|
||||
config = conf_cls()
|
||||
|
||||
if args.port:
|
||||
config.port = args.port
|
||||
|
||||
if args.host:
|
||||
if args.type == "mqtt":
|
||||
config.broker = args.host
|
||||
elif args.type == "zmq":
|
||||
config.interface = args.host
|
||||
elif args.type == "udp":
|
||||
config.ip = args.host
|
||||
|
||||
if args.type == "zmq" and args.topic == "#":
|
||||
args.topic = ""
|
||||
|
||||
source = sub_cls(config, args.topic)
|
||||
if isinstance(source, (MqttSubscriber, UdpSubscriber)):
|
||||
source.config.host = args.host or source.config.host
|
||||
source.config.port = args.port or source.config.port
|
||||
elif isinstance(source, ZmqSubscriber):
|
||||
source.config.host = args.host or source.config.host
|
||||
source.config.subscriber_port = args.port or source.config.subscriber_port
|
||||
args.topic = "" if args.topic == "#" else args.topic
|
||||
elif isinstance(source, UdpSubscriber):
|
||||
source.config.port = args.port or source.config.port
|
||||
|
||||
source.start()
|
||||
sink.start()
|
||||
|
||||
while True:
|
||||
topic, data = source.receive()
|
||||
sink.send(data, topic)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
try:
|
||||
run()
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting...")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Callable, Optional
|
||||
|
||||
import serial
|
||||
@@ -11,6 +12,7 @@ from .config import SerialConf
|
||||
|
||||
|
||||
class SerialPublisher(Sink):
|
||||
serial_connection: serial.Serial
|
||||
"""
|
||||
Publisher for serial devices.
|
||||
Can be used everywhere that a flucto style publishing connection is required.
|
||||
@@ -30,19 +32,36 @@ class SerialPublisher(Sink):
|
||||
):
|
||||
self.config = config
|
||||
self.pack = pack_func if pack_func else get_packer("serial")
|
||||
self._connect()
|
||||
self.is_connected = False
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start the serial connection.
|
||||
"""
|
||||
try:
|
||||
self.serial_connection = serial.Serial(
|
||||
port=self.config.port,
|
||||
baudrate=self.config.baudrate,
|
||||
bytesize=self.config.bytesize,
|
||||
parity=serial.PARITY_NONE,
|
||||
stopbits=serial.STOPBITS_ONE,
|
||||
)
|
||||
except serial.SerialException:
|
||||
print(f"Failed to connect to serial device at port {self.config.port}")
|
||||
sys.exit(1)
|
||||
|
||||
def _connect(self) -> None:
|
||||
self.serial: serial.Serial = serial.Serial(
|
||||
port=self.config.port,
|
||||
baudrate=self.config.baudrate,
|
||||
bytesize=self.config.bytesize,
|
||||
parity=serial.PARITY_NONE,
|
||||
stopbits=serial.STOPBITS_ONE,
|
||||
)
|
||||
print(f"Successfully connected to serial device at port {self.config.port}")
|
||||
self.is_connected = True
|
||||
|
||||
def send(self, message: dict[str, Serializable], topic: str) -> None:
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop the serial connection.
|
||||
"""
|
||||
if hasattr(self, "serial_connection") and self.serial_connection.is_open:
|
||||
self.serial_connection.flush()
|
||||
self.serial_connection.close()
|
||||
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
@@ -52,16 +71,17 @@ class SerialPublisher(Sink):
|
||||
message : dict
|
||||
object to be serialized and sent via the serial connection. Usually a dict.
|
||||
"""
|
||||
payload = self.pack(message)
|
||||
self.serial.write(payload.encode(self.config.encoding))
|
||||
self.serial.flush()
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
|
||||
payload = self.pack(data)
|
||||
self.serial_connection.write(payload.encode(self.config.encoding))
|
||||
self.serial_connection.flush()
|
||||
if self.config.verbose:
|
||||
print(f"{topic}: {payload}")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SerialPublisher(port={self.config.port}, baudrate={self.config.baudrate}, bytezize={self.config.bytesize}, encoding={self.config.encoding})"
|
||||
|
||||
def __del__(self) -> None:
|
||||
if not hasattr(self, "serial"):
|
||||
return
|
||||
if not self.serial.is_open:
|
||||
return
|
||||
self.serial.flush()
|
||||
self.serial.close()
|
||||
self.stop()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Generator
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable
|
||||
|
||||
import serial
|
||||
|
||||
@@ -11,6 +10,7 @@ from .config import SerialConf
|
||||
|
||||
|
||||
class SerialSubscriber(Source):
|
||||
serial_connection: serial.Serial
|
||||
"""
|
||||
Subscriber for serial devices. Connects to a serial port and reads from it.
|
||||
|
||||
@@ -30,22 +30,38 @@ class SerialSubscriber(Source):
|
||||
self,
|
||||
config: SerialConf,
|
||||
topic: str | None = None,
|
||||
custom_unpack: Optional[Callable] = None, # noqa: UP007
|
||||
custom_unpack: Callable | None = None,
|
||||
):
|
||||
self.config = config
|
||||
self.topic = topic
|
||||
self.unpack = custom_unpack if custom_unpack else lambda x: x # types: ignore
|
||||
self._connect()
|
||||
self.is_connected = False
|
||||
|
||||
def _connect(self):
|
||||
self.serial: serial.Serial = serial.Serial(
|
||||
port=self.config.port,
|
||||
baudrate=self.config.baudrate,
|
||||
bytesize=self.config.bytesize,
|
||||
parity=serial.PARITY_NONE,
|
||||
stopbits=serial.STOPBITS_ONE,
|
||||
)
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start the serial connection.
|
||||
"""
|
||||
try:
|
||||
self.serial_connection = serial.Serial(
|
||||
port=self.config.port,
|
||||
baudrate=self.config.baudrate,
|
||||
bytesize=self.config.bytesize,
|
||||
parity=serial.PARITY_NONE,
|
||||
stopbits=serial.STOPBITS_ONE,
|
||||
)
|
||||
except serial.SerialException:
|
||||
print(f"Failed to connect to serial device at port {self.config.port}")
|
||||
sys.exit(1)
|
||||
print(f"Successfully connected to serial device at port {self.config.port}")
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop the serial connection.
|
||||
"""
|
||||
if hasattr(self, "serial_connection") and self.serial_connection.is_open:
|
||||
self.serial_connection.flush()
|
||||
self.serial_connection.close()
|
||||
|
||||
def receive(self) -> tuple[str, dict]:
|
||||
"""
|
||||
@@ -57,6 +73,9 @@ class SerialSubscriber(Source):
|
||||
topic is a placeholder to adhere to the Subscriber interface
|
||||
payload is a dictionary containing the data from the serial port
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
|
||||
# message is a string
|
||||
message = next(self.read_serial_port())
|
||||
# payload is a dictionary
|
||||
@@ -76,7 +95,7 @@ class SerialSubscriber(Source):
|
||||
buffer = ""
|
||||
while True:
|
||||
try:
|
||||
buffer = self.serial.readline().decode(self.config.encoding, "ignore")
|
||||
buffer = self.serial_connection.readline().decode(self.config.encoding, "ignore")
|
||||
yield buffer
|
||||
except UnicodeError as e:
|
||||
if self.config.verbose:
|
||||
@@ -84,10 +103,8 @@ class SerialSubscriber(Source):
|
||||
print(e)
|
||||
continue
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SerialPublisher(port={self.config.port}, baudrate={self.config.baudrate}, bytezize={self.config.bytesize}, encoding={self.config.encoding})"
|
||||
|
||||
def __del__(self) -> None:
|
||||
if not hasattr(self, "serial"):
|
||||
return
|
||||
if not self.serial.is_open:
|
||||
return
|
||||
self.serial.flush()
|
||||
self.serial.close()
|
||||
self.stop()
|
||||
|
||||
@@ -25,18 +25,31 @@ class Joint:
|
||||
self.output_queue: asyncio.Queue[dict[str, Serializable]] = asyncio.Queue()
|
||||
self.initialized = asyncio.Event()
|
||||
self.initalize_task = asyncio.create_task(self.sync())
|
||||
self.output_task = asyncio.create_task(self.output_work())
|
||||
self.combined_dict: dict[str, Serializable] = {}
|
||||
self.task: asyncio.Task[None] | None = None
|
||||
|
||||
"""
|
||||
Main interaction coroutine: Get next value out of the queue.
|
||||
"""
|
||||
def __repr__(self) -> str:
|
||||
return f"""Joint(resample_rate={self.conf.resample_rate},
|
||||
sources={len(self.resamplers)} of type(s): {{r.__class__.__name__ for r in self.resamplers}})"""
|
||||
|
||||
def start(self) -> None:
|
||||
self.task = asyncio.create_task(self.output_work())
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
|
||||
async def receive(self) -> dict[str, Any]:
|
||||
"""
|
||||
Main interaction coroutine: Get next value out of the queue.
|
||||
"""
|
||||
if not self.task:
|
||||
self.start()
|
||||
output = await self.output_queue.get()
|
||||
return output
|
||||
|
||||
async def sync(self) -> None:
|
||||
"""Synchronize the resamplers by pulling data from each until the timestamp is aligned. Retains first matching data."""
|
||||
print("Starting sync")
|
||||
datas = await asyncio.gather(*[source.receive() for source in self.resamplers])
|
||||
print("Got data")
|
||||
|
||||
@@ -10,5 +10,6 @@ class UdpConf(BaseConf):
|
||||
"""
|
||||
|
||||
port: int = 1234
|
||||
ip: str = "127.0.0.1"
|
||||
host: str = "127.0.0.1"
|
||||
packer: str = "json"
|
||||
max_queue_size: int = 1000
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import socket
|
||||
import sys
|
||||
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import Serializable, Sink
|
||||
@@ -8,15 +9,30 @@ from heisskleber.udp.config import UdpConf
|
||||
class UdpPublisher(Sink):
|
||||
def __init__(self, config: UdpConf) -> None:
|
||||
self.config = config
|
||||
self.ip = self.config.ip
|
||||
self.ip = self.config.host
|
||||
self.port = self.config.port
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.pack = get_packer(self.config.packer)
|
||||
self.is_connected = False
|
||||
|
||||
def send(self, message: dict[str, Serializable], topic: str) -> None:
|
||||
message["topic"] = topic
|
||||
payload = self.pack(message).encode("utf-8")
|
||||
def start(self) -> None:
|
||||
try:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
except OSError as e:
|
||||
print(f"failed to create socket: {e}")
|
||||
sys.exit(-1)
|
||||
else:
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self) -> None:
|
||||
self.socket.close()
|
||||
self.is_connected = True
|
||||
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
data["topic"] = topic
|
||||
payload = self.pack(data).encode("utf-8")
|
||||
self.socket.sendto(payload, (self.ip, self.port))
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.socket.close()
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
from queue import SimpleQueue
|
||||
from queue import Queue
|
||||
|
||||
from heisskleber.core.packer import get_unpacker
|
||||
from heisskleber.core.types import Serializable, Source
|
||||
@@ -11,15 +12,30 @@ class UdpSubscriber(Source):
|
||||
def __init__(self, config: UdpConf, topic: str | None = None):
|
||||
self.config = config
|
||||
self.topic = topic
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.bind((self.config.ip, self.config.port))
|
||||
self.unpacker = get_unpacker(self.config.packer)
|
||||
self._queue: SimpleQueue[tuple[str, dict[str, Serializable]]] = SimpleQueue()
|
||||
self._queue: Queue[tuple[str, dict[str, Serializable]]] = Queue(maxsize=self.config.max_queue_size)
|
||||
self._running = threading.Event()
|
||||
|
||||
def start(self) -> None:
|
||||
try:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
except OSError as e:
|
||||
print(f"failed to create socket: {e}")
|
||||
sys.exit(-1)
|
||||
self.socket.bind((self.config.host, self.config.port))
|
||||
self._running.set()
|
||||
self._thread: threading.Thread | None = None
|
||||
self._thread = threading.Thread(target=self._loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._running.clear()
|
||||
# if self._thread is not None:
|
||||
# self._thread.join()
|
||||
self.socket.close()
|
||||
|
||||
def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
if not self._running.is_set():
|
||||
self.start()
|
||||
return self._queue.get()
|
||||
|
||||
def _loop(self) -> None:
|
||||
@@ -33,15 +49,5 @@ class UdpSubscriber(Source):
|
||||
error_message = f"Error in UDP listener loop: {e}"
|
||||
print(error_message)
|
||||
|
||||
def start_loop(self) -> None:
|
||||
self._thread = threading.Thread(target=self._loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop_loop(self) -> None:
|
||||
self._running.clear()
|
||||
if self._thread is not None:
|
||||
self._thread.join()
|
||||
self.socket.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.stop_loop()
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
|
||||
@@ -6,15 +6,15 @@ from heisskleber.config import BaseConf
|
||||
@dataclass
|
||||
class ZmqConf(BaseConf):
|
||||
protocol: str = "tcp"
|
||||
interface: str = "127.0.0.1"
|
||||
host: str = "127.0.0.1"
|
||||
publisher_port: int = 5555
|
||||
subscriber_port: int = 5556
|
||||
packstyle: str = "json"
|
||||
|
||||
@property
|
||||
def publisher_address(self) -> str:
|
||||
return f"{self.protocol}://{self.interface}:{self.publisher_port}"
|
||||
return f"{self.protocol}://{self.host}:{self.publisher_port}"
|
||||
|
||||
@property
|
||||
def subscriber_address(self) -> str:
|
||||
return f"{self.protocol}://{self.interface}:{self.subscriber_port}"
|
||||
return f"{self.protocol}://{self.host}:{self.subscriber_port}"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
from typing import Callable
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@@ -10,45 +11,46 @@ from .config import ZmqConf
|
||||
|
||||
|
||||
class ZmqPublisher(Sink):
|
||||
"""
|
||||
Publisher that sends messages to a ZMQ PUB socket.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
pack : Callable
|
||||
The packer function to use for serializing the data.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
send(data : dict, topic : str):
|
||||
Send the data with the given topic.
|
||||
|
||||
start():
|
||||
Connect to the socket.
|
||||
|
||||
stop():
|
||||
Close the socket.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZmqConf):
|
||||
self.config = config
|
||||
|
||||
self.context = zmq.Context.instance()
|
||||
self.socket = self.context.socket(zmq.PUB)
|
||||
|
||||
self.pack = get_packer(config.packstyle)
|
||||
self.connect()
|
||||
|
||||
def connect(self) -> None:
|
||||
try:
|
||||
if self.config.verbose:
|
||||
print(f"connecting to {self.config.publisher_address}")
|
||||
self.socket.connect(self.config.publisher_address)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
self.is_connected = False
|
||||
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Take the data as a dict, serialize it with the given packer and send it to the zmq socket.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
payload = self.pack(data)
|
||||
if self.config.verbose:
|
||||
print(f"sending message {payload} to topic {topic}")
|
||||
self.socket.send_multipart([topic.encode(), payload.encode()])
|
||||
|
||||
def __del__(self):
|
||||
self.socket.close()
|
||||
|
||||
|
||||
class ZmqAsyncPublisher(AsyncSink):
|
||||
def __init__(self, config: ZmqConf):
|
||||
self.config = config
|
||||
|
||||
self.context = zmq.asyncio.Context.instance()
|
||||
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB)
|
||||
|
||||
self.pack = get_packer(config.packstyle)
|
||||
self.connect()
|
||||
|
||||
def connect(self) -> None:
|
||||
def start(self) -> None:
|
||||
"""Connect to the zmq socket."""
|
||||
try:
|
||||
if self.config.verbose:
|
||||
print(f"connecting to {self.config.publisher_address}")
|
||||
@@ -56,12 +58,72 @@ class ZmqAsyncPublisher(AsyncSink):
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
else:
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self) -> None:
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})"
|
||||
|
||||
|
||||
class ZmqAsyncPublisher(AsyncSink):
|
||||
"""
|
||||
Async publisher that sends messages to a ZMQ PUB socket.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
pack : Callable
|
||||
The packer function to use for serializing the data.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
send(data : dict, topic : str):
|
||||
Send the data with the given topic.
|
||||
|
||||
start():
|
||||
Connect to the socket.
|
||||
|
||||
stop():
|
||||
Close the socket.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZmqConf):
|
||||
self.config = config
|
||||
self.context = zmq.asyncio.Context.instance()
|
||||
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB)
|
||||
self.pack: Callable = get_packer(config.packstyle)
|
||||
self.is_connected = False
|
||||
|
||||
async def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Take the data as a dict, serialize it with the given packer and send it to the zmq socket.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
payload = self.pack(data)
|
||||
if self.config.verbose:
|
||||
print(f"sending message {payload} to topic {topic}")
|
||||
await self.socket.send_multipart([topic.encode(), payload.encode()])
|
||||
|
||||
def __del__(self):
|
||||
def start(self) -> None:
|
||||
"""Connect to the zmq socket."""
|
||||
try:
|
||||
if self.config.verbose:
|
||||
print(f"connecting to {self.config.publisher_address}")
|
||||
self.socket.connect(self.config.publisher_address)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
else:
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Close the zmq socket."""
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})"
|
||||
|
||||
@@ -12,34 +12,43 @@ from .config import ZmqConf
|
||||
|
||||
|
||||
class ZmqSubscriber(Source):
|
||||
def __init__(self, config: ZmqConf, topic: str):
|
||||
self.config = config
|
||||
"""
|
||||
Source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
unpack : Callable
|
||||
The unpacker function to use for deserializing the data.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
receive() -> tuple[str, dict]:
|
||||
Send the data with the given topic.
|
||||
|
||||
start():
|
||||
Connect to the socket.
|
||||
|
||||
stop():
|
||||
Close the socket.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZmqConf, topic: str | list[str]):
|
||||
"""
|
||||
Constructs new ZmqAsyncSubscriber instance.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
config : ZmqConf
|
||||
The configuration dataclass object for the zmq connection.
|
||||
topic : str
|
||||
The topic or list of topics to subscribe to.
|
||||
"""
|
||||
self.config = config
|
||||
self.topic = topic
|
||||
self.context = zmq.Context.instance()
|
||||
self.socket = self.context.socket(zmq.SUB)
|
||||
self.connect()
|
||||
self.subscribe(topic)
|
||||
|
||||
self.unpack = get_unpacker(config.packstyle)
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
# print(f"Connecting to { self.config.consumer_connection }")
|
||||
self.socket.connect(self.config.subscriber_address)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
|
||||
def _subscribe_single_topic(self, topic: str):
|
||||
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
|
||||
|
||||
def subscribe(self, topic: str | list[str] | tuple[str]):
|
||||
# Accepts single topic or list of topics
|
||||
if isinstance(topic, (list, tuple)):
|
||||
for t in topic:
|
||||
self._subscribe_single_topic(t)
|
||||
else:
|
||||
self._subscribe_single_topic(topic)
|
||||
self.is_connected = False
|
||||
|
||||
def receive(self) -> tuple[str, dict]:
|
||||
"""
|
||||
@@ -48,34 +57,26 @@ class ZmqSubscriber(Source):
|
||||
Returns:
|
||||
tuple(topic: str, message: dict): the message received
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
(topic, payload) = self.socket.recv_multipart()
|
||||
message = self.unpack(payload.decode())
|
||||
topic = topic.decode()
|
||||
return (topic, message)
|
||||
|
||||
def __del__(self):
|
||||
self.socket.close()
|
||||
|
||||
|
||||
class ZmqAsyncSubscriber(AsyncSource):
|
||||
def __init__(self, config: ZmqConf, topic: str):
|
||||
self.config = config
|
||||
self.context = zmq.asyncio.Context.instance()
|
||||
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB)
|
||||
self.connect()
|
||||
self.subscribe(topic)
|
||||
|
||||
self.unpack = get_unpacker(config.packstyle)
|
||||
|
||||
def connect(self):
|
||||
def start(self):
|
||||
try:
|
||||
self.socket.connect(self.config.subscriber_address)
|
||||
self.subscribe(self.topic)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
else:
|
||||
self.is_connected = True
|
||||
|
||||
def _subscribe_single_topic(self, topic: str):
|
||||
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
|
||||
def stop(self):
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
def subscribe(self, topic: str | list[str] | tuple[str]):
|
||||
# Accepts single topic or list of topics
|
||||
@@ -85,6 +86,52 @@ class ZmqAsyncSubscriber(AsyncSource):
|
||||
else:
|
||||
self._subscribe_single_topic(topic)
|
||||
|
||||
def _subscribe_single_topic(self, topic: str):
|
||||
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})"
|
||||
|
||||
|
||||
class ZmqAsyncSubscriber(AsyncSource):
|
||||
"""
|
||||
Async source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
unpack : Callable
|
||||
The unpacker function to use for deserializing the data.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
receive() -> tuple[str, dict]:
|
||||
Send the data with the given topic.
|
||||
|
||||
start():
|
||||
Connect to the socket.
|
||||
|
||||
stop():
|
||||
Close the socket.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZmqConf, topic: str | list[str]):
|
||||
"""
|
||||
Constructs new ZmqAsyncSubscriber instance.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
config : ZmqConf
|
||||
The configuration dataclass object for the zmq connection.
|
||||
topic : str
|
||||
The topic or list of topics to subscribe to.
|
||||
"""
|
||||
self.config = config
|
||||
self.topic = topic
|
||||
self.context = zmq.asyncio.Context.instance()
|
||||
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB)
|
||||
self.unpack = get_unpacker(config.packstyle)
|
||||
self.is_connected = True
|
||||
|
||||
async def receive(self) -> tuple[str, dict]:
|
||||
"""
|
||||
reads a message from the zmq bus and returns it
|
||||
@@ -92,10 +139,43 @@ class ZmqAsyncSubscriber(AsyncSource):
|
||||
Returns:
|
||||
tuple(topic: str, message: dict): the message received
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
(topic, payload) = await self.socket.recv_multipart()
|
||||
message = self.unpack(payload.decode())
|
||||
topic = topic.decode()
|
||||
return (topic, message)
|
||||
|
||||
def __del__(self):
|
||||
def start(self):
|
||||
"""Connect to the zmq socket."""
|
||||
try:
|
||||
self.socket.connect(self.config.subscriber_address)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
else:
|
||||
self.is_connected = True
|
||||
self.subscribe(self.topic)
|
||||
|
||||
def stop(self):
|
||||
"""Close the zmq socket."""
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
def subscribe(self, topic: str | list[str] | tuple[str]):
|
||||
"""
|
||||
Subscribes to the given topic(s) on the zmq socket.
|
||||
|
||||
Accepts single topic or list of topics.
|
||||
"""
|
||||
if isinstance(topic, (list, tuple)):
|
||||
for t in topic:
|
||||
self._subscribe_single_topic(t)
|
||||
else:
|
||||
self._subscribe_single_topic(topic)
|
||||
|
||||
def _subscribe_single_topic(self, topic: str):
|
||||
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})"
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from heisskleber.udp import UdpSubscriber, UdpConf
|
||||
from heisskleber.udp import UdpConf, UdpSubscriber
|
||||
|
||||
|
||||
def main() -> None:
|
||||
conf = UdpConf(ip="192.168.137.1", port=6600)
|
||||
conf = UdpConf(host="192.168.137.1", port=6600)
|
||||
subscriber = UdpSubscriber(conf)
|
||||
|
||||
while True:
|
||||
@@ -12,4 +12,3 @@ def main() -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ async def main():
|
||||
topic1 = "topic1"
|
||||
topic2 = "topic2"
|
||||
|
||||
config = MqttConf(broker="localhost", port=1883, user="", password="") # not a real password
|
||||
config = MqttConf(host="localhost", port=1883, user="", password="") # not a real password
|
||||
sub1 = AsyncMqttSubscriber(config, topic1)
|
||||
sub2 = AsyncMqttSubscriber(config, topic2)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from heisskleber.stream import Joint, Resampler, ResamplerConf
|
||||
async def main():
|
||||
topics = ["topic0", "topic1", "topic2", "topic3"]
|
||||
|
||||
config = MqttConf(broker="localhost", port=1883, user="", password="") # not a real password
|
||||
config = MqttConf(host="localhost", port=1883, user="", password="") # not a real password
|
||||
subs = [AsyncMqttSubscriber(config, topic=topic) for topic in topics]
|
||||
|
||||
resampler_config = ResamplerConf(resample_rate=1000)
|
||||
|
||||
@@ -4,7 +4,7 @@ from heisskleber.mqtt import AsyncMqttSubscriber, MqttConf
|
||||
|
||||
|
||||
async def main():
|
||||
conf = MqttConf(broker="localhost", port=1883, user="", password="")
|
||||
conf = MqttConf(host="localhost", port=1883, user="", password="")
|
||||
sub = AsyncMqttSubscriber(conf, topic="#")
|
||||
# async for topic, message in sub:
|
||||
# print(message)
|
||||
|
||||
@@ -20,7 +20,7 @@ async def send_every_n_miliseconds(frequency, value, pub, topic):
|
||||
|
||||
|
||||
async def main2():
|
||||
config = MqttConf(broker="localhost", port=1883, user="", password="")
|
||||
config = MqttConf(host="localhost", port=1883, user="", password="")
|
||||
|
||||
pubs = [AsyncMqttPublisher(config) for i in range(5)]
|
||||
tasks = []
|
||||
|
||||
@@ -5,7 +5,7 @@ from heisskleber.stream import Resampler, ResamplerConf
|
||||
|
||||
|
||||
async def main():
|
||||
conf = MqttConf(broker="localhost", port=1883, user="", password="")
|
||||
conf = MqttConf(host="localhost", port=1883, user="", password="")
|
||||
sub = AsyncMqttSubscriber(conf, topic="#")
|
||||
|
||||
resampler = Resampler(ResamplerConf(), sub)
|
||||
|
||||
@@ -4,7 +4,7 @@ from heisskleber.mqtt import AsyncMqttSubscriber, MqttConf
|
||||
|
||||
|
||||
async def main():
|
||||
config = MqttConf(broker="localhost", port=1883, user="", password="")
|
||||
config = MqttConf(host="localhost", port=1883, user="", password="")
|
||||
|
||||
sub = AsyncMqttSubscriber(config, topic="#")
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ def main():
|
||||
# topic2 = "topic2"
|
||||
|
||||
config = MqttConf(
|
||||
broker="localhost", port=1883, user="", password=""
|
||||
host="localhost", port=1883, user="", password=""
|
||||
) # , not a real password port=1883, user="", password="")
|
||||
sub1 = MqttSubscriber(config, topic1)
|
||||
# sub2 = MqttSubscriber(config, topic2)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
protocol : "tcp" # ipc protocol
|
||||
interface: "127.0.0.1" # the interface to bind to
|
||||
publisher_port : 5555 # port used by primary producers
|
||||
subscriber_port: 5556 # port used by primary consumers
|
||||
protocol: "tcp" # ipc protocol
|
||||
host: "127.0.0.1" # the interface to bind to
|
||||
publisher_port: 5555 # port used by primary producers
|
||||
subscriber_port: 5556 # port used by primary consumers
|
||||
|
||||
@@ -39,6 +39,16 @@ def test_console_sink_pretty_verbose(capsys) -> None:
|
||||
assert captured.out == 'test:\t{\n "key": 3\n}\n'
|
||||
|
||||
|
||||
def test_console_repr() -> None:
|
||||
sink = ConsoleSink()
|
||||
assert repr(sink) == "ConsoleSink(pretty=False, verbose=False)"
|
||||
|
||||
|
||||
def test_async_console_repr() -> None:
|
||||
sink = AsyncConsoleSink()
|
||||
assert repr(sink) == "AsyncConsoleSink(pretty=False, verbose=False)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_console_sink(capsys) -> None:
|
||||
sink = AsyncConsoleSink()
|
||||
|
||||
@@ -14,7 +14,7 @@ from heisskleber.mqtt.subscriber import MqttSubscriber
|
||||
@pytest.fixture
|
||||
def mock_mqtt_conf() -> MqttConf:
|
||||
return MqttConf(
|
||||
broker="localhost",
|
||||
host="localhost",
|
||||
port=1883,
|
||||
user="user",
|
||||
password="passwd", # noqa: S106, this is a test password
|
||||
@@ -40,12 +40,14 @@ def mock_queue():
|
||||
def test_mqtt_base_intialization(mock_mqtt_client, mock_mqtt_conf):
|
||||
"""Test that the intialization of the mqtt client is as expected."""
|
||||
base = MqttBase(config=mock_mqtt_conf)
|
||||
base.start()
|
||||
|
||||
mock_mqtt_client.assert_called_once()
|
||||
mock_mqtt_client.return_value.loop_start.assert_called_once()
|
||||
mock_client_instance = mock_mqtt_client.return_value
|
||||
mock_client_instance.username_pw_set.assert_called_with(mock_mqtt_conf.user, mock_mqtt_conf.password)
|
||||
mock_client_instance.connect.assert_called_with(mock_mqtt_conf.broker, mock_mqtt_conf.port)
|
||||
mock_client_instance.connect.assert_called_with(mock_mqtt_conf.host, mock_mqtt_conf.port)
|
||||
assert base.client
|
||||
assert base.client.on_connect == base._on_connect
|
||||
assert base.client.on_disconnect == base._on_disconnect
|
||||
assert base.client.on_publish == base._on_publish
|
||||
@@ -56,7 +58,7 @@ def test_mqtt_base_on_connect(mock_mqtt_client, mock_mqtt_conf, capsys):
|
||||
base = MqttBase(config=mock_mqtt_conf)
|
||||
base._on_connect(None, None, {}, 0)
|
||||
captured = capsys.readouterr()
|
||||
assert f"MQTT node connected to {mock_mqtt_conf.broker}:{mock_mqtt_conf.port}" in captured.out
|
||||
assert f"MQTT node connected to {mock_mqtt_conf.host}:{mock_mqtt_conf.port}" in captured.out
|
||||
|
||||
|
||||
def test_mqtt_base_on_disconnect_with_error(mock_mqtt_client, mock_mqtt_conf, capsys):
|
||||
@@ -71,7 +73,8 @@ def test_mqtt_base_on_disconnect_with_error(mock_mqtt_client, mock_mqtt_conf, ca
|
||||
|
||||
def test_mqtt_subscribes_single_topic(mock_mqtt_client, mock_mqtt_conf):
|
||||
"""Test that the mqtt client subscribes to a single topic."""
|
||||
_ = MqttSubscriber(topics="singleTopic", config=mock_mqtt_conf)
|
||||
sub = MqttSubscriber(topics="singleTopic", config=mock_mqtt_conf)
|
||||
sub.start()
|
||||
|
||||
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
|
||||
assert actual_calls == [call("singleTopic", mock_mqtt_conf.qos)]
|
||||
@@ -82,7 +85,8 @@ def test_mqtt_subscribes_multiple_topics(mock_mqtt_client, mock_mqtt_conf):
|
||||
|
||||
I would love to do this via parametrization, but the call argument is built differently for single size lists and longer lists.
|
||||
"""
|
||||
_ = MqttSubscriber(topics=["multiple1", "multiple2"], config=mock_mqtt_conf)
|
||||
sub = MqttSubscriber(topics=["multiple1", "multiple2"], config=mock_mqtt_conf)
|
||||
sub.start()
|
||||
|
||||
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
|
||||
assert actual_calls == [
|
||||
@@ -92,7 +96,8 @@ def test_mqtt_subscribes_multiple_topics(mock_mqtt_client, mock_mqtt_conf):
|
||||
|
||||
def test_mqtt_subscribes_multiple_topics_tuple(mock_mqtt_client, mock_mqtt_conf):
|
||||
"""Test that the mqtt client subscribes to multiple topics passed as tuple."""
|
||||
_ = MqttSubscriber(topics=("multiple1", "multiple2"), config=mock_mqtt_conf)
|
||||
sub = MqttSubscriber(topics=("multiple1", "multiple2"), config=mock_mqtt_conf)
|
||||
sub.start()
|
||||
|
||||
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
|
||||
assert actual_calls == [
|
||||
|
||||
@@ -29,10 +29,11 @@ def mock_serial_device_publisher():
|
||||
def test_serial_subscriber_initialization(mock_serial_device_subscriber, serial_conf):
|
||||
"""Test that the SerialSubscriber class initializes correctly.
|
||||
Mocks the serial.Serial class to avoid opening a serial port."""
|
||||
_ = SerialSubscriber(
|
||||
sub = SerialSubscriber(
|
||||
config=serial_conf,
|
||||
topic="",
|
||||
)
|
||||
sub.start()
|
||||
mock_serial_device_subscriber.assert_called_with(
|
||||
port=serial_conf.port,
|
||||
baudrate=serial_conf.baudrate,
|
||||
@@ -45,6 +46,7 @@ def test_serial_subscriber_initialization(mock_serial_device_subscriber, serial_
|
||||
def test_serial_subscriber_receive(mock_serial_device_subscriber, serial_conf):
|
||||
"""Test that the SerialSubscriber class calls readline and unpack as expected."""
|
||||
subscriber = SerialSubscriber(config=serial_conf, topic="")
|
||||
subscriber.start()
|
||||
|
||||
# Set up the readline return value
|
||||
mock_serial_instance = mock_serial_device_subscriber.return_value
|
||||
@@ -69,6 +71,7 @@ def test_serial_subscriber_converts_bytes_to_str():
|
||||
"""Test that the SerialSubscriber class converts bytes to str as expected."""
|
||||
with patch("heisskleber.serial.subscriber.serial.Serial") as mock_serial:
|
||||
subscriber = SerialSubscriber(config=SerialConf(), topic="", custom_unpack=lambda x: x)
|
||||
subscriber.start()
|
||||
|
||||
# Set the readline method to raise UnicodeError
|
||||
mock_serial_instance = mock_serial.return_value
|
||||
@@ -86,6 +89,7 @@ def test_serial_publisher_initialization(mock_serial_device_publisher, serial_co
|
||||
"""Test that the SerialPublisher class initializes correctly.
|
||||
Mocks the serial.Serial class to avoid opening a serial port."""
|
||||
publisher = SerialPublisher(config=serial_conf)
|
||||
publisher.start()
|
||||
mock_serial_device_publisher.assert_called_with(
|
||||
port=serial_conf.port,
|
||||
baudrate=serial_conf.baudrate,
|
||||
@@ -93,7 +97,7 @@ def test_serial_publisher_initialization(mock_serial_device_publisher, serial_co
|
||||
parity=serial.PARITY_NONE,
|
||||
stopbits=serial.STOPBITS_ONE,
|
||||
)
|
||||
assert publisher.serial
|
||||
assert publisher.serial_connection
|
||||
|
||||
|
||||
def test_serial_publisher_send(mock_serial_device_publisher, serial_conf):
|
||||
|
||||
@@ -67,11 +67,12 @@ async def test_resampler_multiple_modes(mock_subscriber):
|
||||
]
|
||||
)
|
||||
|
||||
config = ResamplerConf(resample_rate=1000) # Fill in your MQTT configuration
|
||||
config = ResamplerConf(resample_rate=1000)
|
||||
resampler = Resampler(config, mock_subscriber)
|
||||
|
||||
# Test the resample method
|
||||
resampled_data = [await resampler.receive() for _ in range(3)]
|
||||
resampler.stop()
|
||||
|
||||
assert resampled_data[0] == {"epoch": 0.0, "data": 1.5}
|
||||
assert resampled_data[1] == {"epoch": 1.0, "data": 3.5}
|
||||
@@ -89,11 +90,12 @@ async def test_resampler_upsampling(mock_subscriber):
|
||||
]
|
||||
)
|
||||
|
||||
config = ResamplerConf(resample_rate=250) # Fill in your MQTT configuration
|
||||
config = ResamplerConf(resample_rate=250)
|
||||
resampler = Resampler(config, mock_subscriber)
|
||||
|
||||
# Test the resample method
|
||||
resampled_data = [await resampler.receive() for _ in range(7)]
|
||||
resampler.stop()
|
||||
|
||||
assert resampled_data[0] == {"epoch": 0.0, "data": 1.0}
|
||||
assert resampled_data[1] == {"epoch": 0.25, "data": 1.25}
|
||||
|
||||
@@ -17,19 +17,22 @@ def mock_socket():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conf():
|
||||
return UdpConf(ip="127.0.0.1", port=12345, packer="json")
|
||||
return UdpConf(host="127.0.0.1", port=12345, packer="json")
|
||||
|
||||
|
||||
def test_connects_to_socket(mock_socket, mock_conf) -> None:
|
||||
_ = UdpPublisher(mock_conf)
|
||||
pub = UdpPublisher(mock_conf)
|
||||
pub.start()
|
||||
|
||||
# constructor was called
|
||||
mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
pub.stop()
|
||||
|
||||
|
||||
def test_closes_socket(mock_socket, mock_conf) -> None:
|
||||
pub = UdpPublisher(mock_conf)
|
||||
del pub
|
||||
pub.start()
|
||||
pub.stop()
|
||||
|
||||
# instace was closed
|
||||
mock_socket.return_value.close.assert_called()
|
||||
@@ -45,8 +48,9 @@ def test_packs_and_sends_message(mock_socket, mock_conf) -> None:
|
||||
|
||||
mock_socket.return_value.sendto.assert_called_with(
|
||||
b'{"key": "val", "intkey": 1, "floatkey": 1.0, "topic": "test"}',
|
||||
(str(mock_conf.ip), mock_conf.port),
|
||||
(str(mock_conf.host), mock_conf.port),
|
||||
)
|
||||
pub.stop()
|
||||
|
||||
|
||||
def test_subscriber_receives_message_from_queue(mock_conf) -> None:
|
||||
@@ -59,13 +63,16 @@ def test_subscriber_receives_message_from_queue(mock_conf) -> None:
|
||||
topic, data = sub.receive()
|
||||
assert test_topic == topic
|
||||
assert test_data == data
|
||||
sub.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def udp_sub(mock_conf):
|
||||
sub = UdpSubscriber(mock_conf)
|
||||
sub.start_loop()
|
||||
sub.config.port = 12346 # explicitly set port to avoid conflicts
|
||||
sub.start()
|
||||
yield sub
|
||||
sub.stop()
|
||||
|
||||
|
||||
def test_sends_message_between_pub_and_sub(udp_sub, mock_conf):
|
||||
|
||||
@@ -33,9 +33,9 @@ def start_broker():
|
||||
|
||||
|
||||
def test_config_parses_correctly():
|
||||
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
assert conf.protocol == "tcp"
|
||||
assert conf.interface == "localhost"
|
||||
assert conf.host == "localhost"
|
||||
assert conf.publisher_port == 5555
|
||||
assert conf.subscriber_port == 5556
|
||||
|
||||
@@ -44,13 +44,13 @@ def test_config_parses_correctly():
|
||||
|
||||
|
||||
def test_instantiate_subscriber():
|
||||
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
sub = ZmqSubscriber(conf, "test")
|
||||
assert sub.config == conf
|
||||
|
||||
|
||||
def test_instantiate_publisher():
|
||||
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
pub = ZmqPublisher(conf)
|
||||
assert pub.config == conf
|
||||
|
||||
@@ -59,7 +59,9 @@ def test_send_receive(start_broker):
|
||||
print("test_send_receive")
|
||||
topic = "test"
|
||||
source = get_source("zmq", topic)
|
||||
source.start()
|
||||
sink = get_sink("zmq")
|
||||
sink.start()
|
||||
time.sleep(1) # this is crucial, otherwise the source might hang
|
||||
for i in range(10):
|
||||
message = {"m": i}
|
||||
|
||||
@@ -33,13 +33,13 @@ def start_broker() -> Generator[Process, None, None]:
|
||||
|
||||
|
||||
def test_instantiate_subscriber() -> None:
|
||||
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
sub = ZmqAsyncSubscriber(conf, "test")
|
||||
assert sub.config == conf
|
||||
|
||||
|
||||
def test_instantiate_publisher() -> None:
|
||||
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
pub = ZmqPublisher(conf)
|
||||
assert pub.config == conf
|
||||
|
||||
@@ -48,9 +48,11 @@ def test_instantiate_publisher() -> None:
|
||||
async def test_send_receive(start_broker) -> None:
|
||||
print("test_send_receive")
|
||||
topic = "test"
|
||||
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
source = ZmqAsyncSubscriber(conf, topic)
|
||||
sink = ZmqAsyncPublisher(conf)
|
||||
source.start()
|
||||
sink.start()
|
||||
time.sleep(1) # this is crucial, otherwise the source might hang
|
||||
for i in range(10):
|
||||
message = {"m": i}
|
||||
|
||||
Reference in New Issue
Block a user