diff --git a/config/heisskleber/mqtt.yaml b/config/heisskleber/mqtt.yaml index 72b7bd0..7a97d0c 100644 --- a/config/heisskleber/mqtt.yaml +++ b/config/heisskleber/mqtt.yaml @@ -1,4 +1,4 @@ -broker: "10.47.36.1" +host: "10.47.36.1" user: "" password: "" port: 1883 diff --git a/config/heisskleber/zmq.yaml b/config/heisskleber/zmq.yaml index 5e7299d..a7ef27a 100644 --- a/config/heisskleber/zmq.yaml +++ b/config/heisskleber/zmq.yaml @@ -1,4 +1,4 @@ -protocol : "tcp" # ipc protocol -interface: "127.0.0.1" # the interface to bind to -publisher_port : 5555 # port used by primary producers -subscriber_port: 5556 # port used by primary consumers +protocol: "tcp" # ipc protocol +host: "127.0.0.1" # the interface to bind to +publisher_port: 5555 # port used by primary producers +subscriber_port: 5556 # port used by primary consumers diff --git a/heisskleber/config/__init__.py b/heisskleber/config/__init__.py index 8710d76..aba7128 100644 --- a/heisskleber/config/__init__.py +++ b/heisskleber/config/__init__.py @@ -1,4 +1,4 @@ from .config import BaseConf -from .parse import load_config +from .parse import ConfigType, load_config -__all__ = ["load_config", "BaseConf"] +__all__ = ["load_config", "BaseConf", "ConfigType"] diff --git a/heisskleber/console/sink.py b/heisskleber/console/sink.py index 3246ece..c6d88f7 100644 --- a/heisskleber/console/sink.py +++ b/heisskleber/console/sink.py @@ -16,6 +16,15 @@ class ConsoleSink(Sink): else: print(verbose_topic + str(data)) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})" + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + class AsyncConsoleSink(AsyncSink): def __init__(self, pretty: bool = False, verbose: bool = False) -> None: @@ -29,6 +38,15 @@ class AsyncConsoleSink(AsyncSink): else: print(verbose_topic + str(data)) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})" + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + if __name__ == "__main__": sink = ConsoleSink() diff --git a/heisskleber/console/source.py b/heisskleber/console/source.py index 48f6242..ae3910a 100644 --- a/heisskleber/console/source.py +++ b/heisskleber/console/source.py @@ -1,34 +1,98 @@ +import asyncio import json import sys import time from queue import SimpleQueue from threading import Thread -from heisskleber.core.types import Serializable, Source +from heisskleber.core.types import AsyncSource, Serializable, Source class ConsoleSource(Source): - def __init__(self, topic: str | list[str] | tuple[str] = "console") -> None: - self.topic = "console" + def __init__(self, topic: str = "console") -> None: + self.topic = topic self.queue = SimpleQueue() - self.listener_daemon = Thread(target=self.listener_task, daemon=True) - self.listener_daemon.start() self.pack = json.loads + self.thread: Thread | None = None def listener_task(self): while True: - data = sys.stdin.readline() - payload = self.pack(data) - self.queue.put(payload) + try: + data = sys.stdin.readline() + payload = self.pack(data) + self.queue.put(payload) + except json.decoder.JSONDecodeError: + print("Invalid JSON") + continue + except ValueError: + break + print("listener task finished") def receive(self) -> tuple[str, dict[str, Serializable]]: + if not self.thread: + self.start() + data = self.queue.get() return self.topic, data + def __repr__(self) -> str: + return f"{self.__class__.__name__}(topic={self.topic})" + + def start(self) -> None: + self.thread = Thread(target=self.listener_task, daemon=True) + self.thread.start() + + def stop(self) -> None: + if self.thread: + sys.stdin.close() + self.thread.join() + + +class AsyncConsoleSource(AsyncSource): + def __init__(self, topic: str = "console") -> None: + self.topic = topic + self.queue: asyncio.Queue[dict[str, Serializable]] = asyncio.Queue(maxsize=10) + self.pack = json.loads + self.task: asyncio.Task[None] | None = None + + async def listener_task(self): + while True: + data = sys.stdin.readline() + payload = self.pack(data) + await self.queue.put(payload) + + async def receive(self) -> tuple[str, dict[str, Serializable]]: + if not self.task: + self.start() + + data = await self.queue.get() + return self.topic, data + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(topic={self.topic})" + + def start(self) -> None: + self.task = asyncio.create_task(self.listener_task()) + + def stop(self) -> None: + if self.task: + self.task.cancel() + if __name__ == "__main__": console_source = ConsoleSource() + console_source.start() - while True: - print(console_source.receive()) - time.sleep(1) + print("Listening to console input.") + + count = 0 + + try: + while True: + print(console_source.receive()) + time.sleep(1) + count += 1 + print(count) + except KeyboardInterrupt: + print("Stopped") + sys.exit(0) diff --git a/heisskleber/core/types.py b/heisskleber/core/types.py index 3009dfc..edfc0c8 100644 --- a/heisskleber/core/types.py +++ b/heisskleber/core/types.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Any, Callable, Union -from heisskleber.config.config import BaseConf +from heisskleber.config import BaseConf Serializable = Union[str, int, float] @@ -23,12 +23,30 @@ class Sink(ABC): pass @abstractmethod - def send(self, data: dict[str, Any], topic: str) -> None: + def send(self, data: dict[str, Serializable], topic: str) -> None: """ Send data via the implemented output stream. """ pass + @abstractmethod + def __repr__(self) -> str: + pass + + @abstractmethod + def start(self) -> None: + """ + Start any background processes and tasks. + """ + pass + + @abstractmethod + def stop(self) -> None: + """ + Stop any background processes and tasks. + """ + pass + class Source(ABC): """ @@ -53,25 +71,21 @@ class Source(ABC): """ pass - -class AsyncSubscriber(ABC): - """ - AsyncSubscriber interface - """ + @abstractmethod + def __repr__(self) -> str: + pass @abstractmethod - def __init__(self, config: Any, topic: str | list[str]) -> None: + def start(self) -> None: """ - Initialize the subscriber with a topic and a configuration object. + Start any background processes and tasks. """ pass @abstractmethod - async def receive(self) -> tuple[str, dict[str, Serializable]]: + def stop(self) -> None: """ - Blocking function to receive data from the implemented input stream. - - Data is returned as a tuple of (topic, data). + Stop any background processes and tasks. """ pass @@ -97,6 +111,24 @@ class AsyncSource(ABC): """ pass + @abstractmethod + def __repr__(self) -> str: + pass + + @abstractmethod + def start(self) -> None: + """ + Start any background processes and tasks. + """ + pass + + @abstractmethod + def stop(self) -> None: + """ + Stop any background processes and tasks. + """ + pass + class AsyncSink(ABC): """ @@ -118,3 +150,21 @@ class AsyncSink(ABC): Send data via the implemented output stream. """ pass + + @abstractmethod + def __repr__(self) -> str: + pass + + @abstractmethod + def start(self) -> None: + """ + Start any background processes and tasks. + """ + pass + + @abstractmethod + def stop(self) -> None: + """ + Stop any background processes and tasks. + """ + pass diff --git a/heisskleber/mqtt/config.py b/heisskleber/mqtt/config.py index 004d955..b52e4f0 100644 --- a/heisskleber/mqtt/config.py +++ b/heisskleber/mqtt/config.py @@ -9,7 +9,7 @@ class MqttConf(BaseConf): MQTT configuration class. """ - broker: str = "localhost" + host: str = "localhost" user: str = "" password: str = "" port: int = 1883 diff --git a/heisskleber/mqtt/mqtt_base.py b/heisskleber/mqtt/mqtt_base.py index 696f88d..2699fd9 100644 --- a/heisskleber/mqtt/mqtt_base.py +++ b/heisskleber/mqtt/mqtt_base.py @@ -34,8 +34,14 @@ class MqttBase: def __init__(self, config: MqttConf) -> None: self.config = config + self.client: mqtt_client | None = None + + def start(self) -> None: self.connect() - self.client.loop_start() + + def stop(self) -> None: + if self.client: + self.client.loop_stop() def connect(self) -> None: self.client = mqtt_client() @@ -52,7 +58,8 @@ class MqttBase: # the default certification authority of the system is used. self.client.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT) - self.client.connect(self.config.broker, self.config.port) + self.client.connect(self.config.host, self.config.port) + self.client.loop_start() @staticmethod def _raise_if_thread_died() -> None: @@ -63,7 +70,7 @@ class MqttBase: # MQTT callbacks def _on_connect(self, client, userdata, flags, return_code) -> None: if return_code == 0: - print(f"MQTT node connected to {self.config.broker}:{self.config.port}") + print(f"MQTT node connected to {self.config.host}:{self.config.port}") else: print("Connection failed!") if self.config.verbose: @@ -84,4 +91,4 @@ class MqttBase: print(f"Received message: {message.payload!s}, topic: {message.topic}, qos: {message.qos}") def __del__(self) -> None: - self.client.loop_stop() + self.stop() diff --git a/heisskleber/mqtt/publisher.py b/heisskleber/mqtt/publisher.py index 18c8158..23899fb 100644 --- a/heisskleber/mqtt/publisher.py +++ b/heisskleber/mqtt/publisher.py @@ -1,9 +1,7 @@ from __future__ import annotations -from typing import Any - from heisskleber.core.packer import get_packer -from heisskleber.core.types import Sink +from heisskleber.core.types import Serializable, Sink from .config import MqttConf from .mqtt_base import MqttBase @@ -21,14 +19,26 @@ class MqttPublisher(MqttBase, Sink): super().__init__(config) self.pack = get_packer(config.packstyle) - def send(self, data: dict[str, Any], topic: str) -> None: + def send(self, data: dict[str, Serializable], topic: str) -> None: """ Takes python dictionary, serializes it according to the packstyle and sends it to the broker. Publishing is asynchronous """ + if not self.client.is_connected(): + self.start() + self._raise_if_thread_died() payload = self.pack(data) self.client.publish(topic, payload, qos=self.config.qos, retain=self.config.retain) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})" + + def start(self) -> None: + super().start() + + def stop(self) -> None: + super().stop() diff --git a/heisskleber/mqtt/publisher_async.py b/heisskleber/mqtt/publisher_async.py index 2657d03..bc48e72 100644 --- a/heisskleber/mqtt/publisher_async.py +++ b/heisskleber/mqtt/publisher_async.py @@ -1,7 +1,4 @@ -from __future__ import annotations - -import asyncio -from asyncio import Queue, Task, create_task +from asyncio import Queue, Task, create_task, sleep import aiomqtt @@ -22,8 +19,8 @@ class AsyncMqttPublisher(AsyncSink): def __init__(self, config: MqttConf) -> None: self.config = config self.pack = get_packer(config.packstyle) - self._send_queue: Queue[tuple[dict[str, Serializable], str]] = Queue(maxsize=config.max_saved_messages) - self._sender_task: Task[None] = create_task(self.send_work()) + self._send_queue: Queue[tuple[dict[str, Serializable], str]] = Queue() + self._sender_task: Task[None] | None = None async def send(self, data: dict[str, Serializable], topic: str) -> None: """ @@ -32,6 +29,8 @@ class AsyncMqttPublisher(AsyncSink): Publishing is asynchronous """ + if not self._sender_task: + self.start() await self._send_queue.put((data, topic)) @@ -45,7 +44,7 @@ class AsyncMqttPublisher(AsyncSink): while True: try: async with aiomqtt.Client( - hostname=self.config.broker, + hostname=self.config.host, port=self.config.port, username=self.config.user, password=self.config.password, @@ -57,4 +56,15 @@ class AsyncMqttPublisher(AsyncSink): await client.publish(topic, payload) except aiomqtt.MqttError: print("Connection to MQTT broker failed. Retrying in 5 seconds") - await asyncio.sleep(5) + await sleep(5) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})" + + def start(self) -> None: + self._sender_task = create_task(self.send_work()) + + def stop(self) -> None: + if self._sender_task: + self._sender_task.cancel() + self._sender_task = None diff --git a/heisskleber/mqtt/subscriber.py b/heisskleber/mqtt/subscriber.py index 6a17937..7093b97 100644 --- a/heisskleber/mqtt/subscriber.py +++ b/heisskleber/mqtt/subscriber.py @@ -22,9 +22,8 @@ class MqttSubscriber(MqttBase, Source): def __init__(self, config: MqttConf, topics: str | list[str]) -> None: super().__init__(config) + self.topics = topics self._message_queue: SimpleQueue[MQTTMessage] = SimpleQueue() - self.subscribe(topics) - self.client.on_message = self._on_message self.unpack = get_unpacker(config.packstyle) def subscribe(self, topics: str | list[str] | tuple[str]) -> None: @@ -47,14 +46,29 @@ class MqttSubscriber(MqttBase, Source): Messages are saved in a stack, if no message is available, this function blocks. Returns: - tuple(topic: bytes, message: dict): the message received + tuple(topic: str, message: dict): the message received """ + if not self.client: + self.start() + self._raise_if_thread_died() mqtt_message = self._message_queue.get(block=True, timeout=self.config.timeout_s) message_returned = self.unpack(mqtt_message.payload.decode()) return (mqtt_message.topic, message_returned) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})" + + def start(self) -> None: + super().start() + self.subscribe(self.topics) + self.client.on_message = self._on_message + self.is_connected = True + + def stop(self) -> None: + super().stop() + # callback to add incoming messages onto stack def _on_message(self, client, userdata, message) -> None: self._message_queue.put(message) diff --git a/heisskleber/mqtt/subscriber_async.py b/heisskleber/mqtt/subscriber_async.py index 901db09..5bb8da5 100644 --- a/heisskleber/mqtt/subscriber_async.py +++ b/heisskleber/mqtt/subscriber_async.py @@ -16,7 +16,7 @@ class AsyncMqttSubscriber(AsyncSource): def __init__(self, config: MqttConf, topic: str | list[str]) -> None: self.config: MqttConf = config self.client = Client( - hostname=self.config.broker, + hostname=self.config.host, port=self.config.port, username=self.config.user, password=self.config.password, @@ -24,17 +24,32 @@ class AsyncMqttSubscriber(AsyncSource): self.topics = topic self.unpack = get_unpacker(self.config.packstyle) self.message_queue: Queue[Message] = Queue(self.config.max_saved_messages) - self._listener_task: Task = create_task(self.create_listener()) + self._listener_task: Task[None] | None = None - """ - Await the newest message in the queue and return Tuple - """ + def __repr__(self) -> str: + return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})" + + def start(self) -> None: + self._listener_task = create_task(self.run()) + + def stop(self) -> None: + if self._listener_task: + self._listener_task.cancel() + self._listener_task = None async def receive(self) -> tuple[str, dict[str, Serializable]]: - mqtt_message: Message = await self.message_queue.get() + """ + Await the newest message in the queue and return Tuple + """ + if not self._listener_task: + self.start() + mqtt_message = await self.message_queue.get() return self._handle_message(mqtt_message) - async def create_listener(self): + async def run(self): + """ + Handle the connection to MQTT broker and run the message loop. + """ while True: try: async with self.client: @@ -45,11 +60,10 @@ class AsyncMqttSubscriber(AsyncSource): print("Connection to MQTT failed. Retrying...") await sleep(1) - """ - Listen to incoming messages asynchronously and put them into a queue - """ - async def _listen_mqtt_loop(self) -> None: + """ + Listen to incoming messages asynchronously and put them into a queue + """ async with self.client.messages() as messages: # async with self.client.filtered_messages(self.topics) as messages: async for message in messages: diff --git a/heisskleber/run/cli.py b/heisskleber/run/cli.py index 3fa9d8c..9b97fa1 100644 --- a/heisskleber/run/cli.py +++ b/heisskleber/run/cli.py @@ -1,12 +1,13 @@ import argparse import sys -from typing import Union +from typing import Callable, Union from heisskleber.config import load_config from heisskleber.console.sink import ConsoleSink from heisskleber.core.factories import _registered_sources - -TopicType = Union[str, list[str]] +from heisskleber.mqtt import MqttSubscriber +from heisskleber.udp import UdpSubscriber +from heisskleber.zmq import ZmqSubscriber def parse_args() -> argparse.Namespace: @@ -33,14 +34,12 @@ def parse_args() -> argparse.Namespace: "-H", "--host", type=str, - default="localhost", help="Host or broker for MQTT, zmq and UDP.", ) parser.add_argument( "-P", "--port", type=int, - default=1883, help="Port or serial interface for MQTT, zmq and UDP.", ) parser.add_argument("-v", "--verbose", action="store_true") @@ -49,7 +48,19 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def run() -> None: +def keyboardexit(func) -> Callable: + def wrapper(*args, **kwargs) -> Union[None, int]: + try: + return func(*args, **kwargs) + except KeyboardInterrupt: + print("Exiting...") + sys.exit(0) + + return wrapper + + +@keyboardexit +def main() -> None: args = parse_args() sink = ConsoleSink(pretty=args.pretty, verbose=args.verbose) @@ -58,36 +69,27 @@ def run() -> None: try: config = load_config(conf_cls(), args.type, read_commandline=False) except FileNotFoundError: - print(f"Using default config for {args.type}.") + print(f"No config file found for {args.type}, using default values and user input.") config = conf_cls() - if args.port: - config.port = args.port - - if args.host: - if args.type == "mqtt": - config.broker = args.host - elif args.type == "zmq": - config.interface = args.host - elif args.type == "udp": - config.ip = args.host - - if args.type == "zmq" and args.topic == "#": - args.topic = "" - source = sub_cls(config, args.topic) + if isinstance(source, (MqttSubscriber, UdpSubscriber)): + source.config.host = args.host or source.config.host + source.config.port = args.port or source.config.port + elif isinstance(source, ZmqSubscriber): + source.config.host = args.host or source.config.host + source.config.subscriber_port = args.port or source.config.subscriber_port + args.topic = "" if args.topic == "#" else args.topic + elif isinstance(source, UdpSubscriber): + source.config.port = args.port or source.config.port + + source.start() + sink.start() + while True: topic, data = source.receive() sink.send(data, topic) -def main() -> None: - try: - run() - except KeyboardInterrupt: - print("Exiting...") - sys.exit(0) - - if __name__ == "__main__": main() diff --git a/heisskleber/serial/publisher.py b/heisskleber/serial/publisher.py index bb6bfb9..52c6f38 100644 --- a/heisskleber/serial/publisher.py +++ b/heisskleber/serial/publisher.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from typing import Callable, Optional import serial @@ -11,6 +12,7 @@ from .config import SerialConf class SerialPublisher(Sink): + serial_connection: serial.Serial """ Publisher for serial devices. Can be used everywhere that a flucto style publishing connection is required. @@ -30,19 +32,36 @@ class SerialPublisher(Sink): ): self.config = config self.pack = pack_func if pack_func else get_packer("serial") - self._connect() + self.is_connected = False + + def start(self) -> None: + """ + Start the serial connection. + """ + try: + self.serial_connection = serial.Serial( + port=self.config.port, + baudrate=self.config.baudrate, + bytesize=self.config.bytesize, + parity=serial.PARITY_NONE, + stopbits=serial.STOPBITS_ONE, + ) + except serial.SerialException: + print(f"Failed to connect to serial device at port {self.config.port}") + sys.exit(1) - def _connect(self) -> None: - self.serial: serial.Serial = serial.Serial( - port=self.config.port, - baudrate=self.config.baudrate, - bytesize=self.config.bytesize, - parity=serial.PARITY_NONE, - stopbits=serial.STOPBITS_ONE, - ) print(f"Successfully connected to serial device at port {self.config.port}") + self.is_connected = True - def send(self, message: dict[str, Serializable], topic: str) -> None: + def stop(self) -> None: + """ + Stop the serial connection. + """ + if hasattr(self, "serial_connection") and self.serial_connection.is_open: + self.serial_connection.flush() + self.serial_connection.close() + + def send(self, data: dict[str, Serializable], topic: str) -> None: """ Takes python dictionary, serializes it according to the packstyle and sends it to the broker. @@ -52,16 +71,17 @@ class SerialPublisher(Sink): message : dict object to be serialized and sent via the serial connection. Usually a dict. """ - payload = self.pack(message) - self.serial.write(payload.encode(self.config.encoding)) - self.serial.flush() + if not self.is_connected: + self.start() + + payload = self.pack(data) + self.serial_connection.write(payload.encode(self.config.encoding)) + self.serial_connection.flush() if self.config.verbose: print(f"{topic}: {payload}") + def __repr__(self) -> str: + return f"SerialPublisher(port={self.config.port}, baudrate={self.config.baudrate}, bytezize={self.config.bytesize}, encoding={self.config.encoding})" + def __del__(self) -> None: - if not hasattr(self, "serial"): - return - if not self.serial.is_open: - return - self.serial.flush() - self.serial.close() + self.stop() diff --git a/heisskleber/serial/subscriber.py b/heisskleber/serial/subscriber.py index 976a4b5..db3d962 100644 --- a/heisskleber/serial/subscriber.py +++ b/heisskleber/serial/subscriber.py @@ -1,7 +1,6 @@ -from __future__ import annotations - +import sys from collections.abc import Generator -from typing import Callable, Optional +from typing import Callable import serial @@ -11,6 +10,7 @@ from .config import SerialConf class SerialSubscriber(Source): + serial_connection: serial.Serial """ Subscriber for serial devices. Connects to a serial port and reads from it. @@ -30,22 +30,38 @@ class SerialSubscriber(Source): self, config: SerialConf, topic: str | None = None, - custom_unpack: Optional[Callable] = None, # noqa: UP007 + custom_unpack: Callable | None = None, ): self.config = config self.topic = topic self.unpack = custom_unpack if custom_unpack else lambda x: x # types: ignore - self._connect() + self.is_connected = False - def _connect(self): - self.serial: serial.Serial = serial.Serial( - port=self.config.port, - baudrate=self.config.baudrate, - bytesize=self.config.bytesize, - parity=serial.PARITY_NONE, - stopbits=serial.STOPBITS_ONE, - ) + def start(self) -> None: + """ + Start the serial connection. + """ + try: + self.serial_connection = serial.Serial( + port=self.config.port, + baudrate=self.config.baudrate, + bytesize=self.config.bytesize, + parity=serial.PARITY_NONE, + stopbits=serial.STOPBITS_ONE, + ) + except serial.SerialException: + print(f"Failed to connect to serial device at port {self.config.port}") + sys.exit(1) print(f"Successfully connected to serial device at port {self.config.port}") + self.is_connected = True + + def stop(self) -> None: + """ + Stop the serial connection. + """ + if hasattr(self, "serial_connection") and self.serial_connection.is_open: + self.serial_connection.flush() + self.serial_connection.close() def receive(self) -> tuple[str, dict]: """ @@ -57,6 +73,9 @@ class SerialSubscriber(Source): topic is a placeholder to adhere to the Subscriber interface payload is a dictionary containing the data from the serial port """ + if not self.is_connected: + self.start() + # message is a string message = next(self.read_serial_port()) # payload is a dictionary @@ -76,7 +95,7 @@ class SerialSubscriber(Source): buffer = "" while True: try: - buffer = self.serial.readline().decode(self.config.encoding, "ignore") + buffer = self.serial_connection.readline().decode(self.config.encoding, "ignore") yield buffer except UnicodeError as e: if self.config.verbose: @@ -84,10 +103,8 @@ class SerialSubscriber(Source): print(e) continue + def __repr__(self) -> str: + return f"SerialPublisher(port={self.config.port}, baudrate={self.config.baudrate}, bytezize={self.config.bytesize}, encoding={self.config.encoding})" + def __del__(self) -> None: - if not hasattr(self, "serial"): - return - if not self.serial.is_open: - return - self.serial.flush() - self.serial.close() + self.stop() diff --git a/heisskleber/stream/joint.py b/heisskleber/stream/joint.py index a7bacac..ed8102b 100644 --- a/heisskleber/stream/joint.py +++ b/heisskleber/stream/joint.py @@ -25,18 +25,31 @@ class Joint: self.output_queue: asyncio.Queue[dict[str, Serializable]] = asyncio.Queue() self.initialized = asyncio.Event() self.initalize_task = asyncio.create_task(self.sync()) - self.output_task = asyncio.create_task(self.output_work()) self.combined_dict: dict[str, Serializable] = {} + self.task: asyncio.Task[None] | None = None - """ - Main interaction coroutine: Get next value out of the queue. - """ + def __repr__(self) -> str: + return f"""Joint(resample_rate={self.conf.resample_rate}, + sources={len(self.resamplers)} of type(s): {{r.__class__.__name__ for r in self.resamplers}})""" + + def start(self) -> None: + self.task = asyncio.create_task(self.output_work()) + + def stop(self) -> None: + if self.task: + self.task.cancel() async def receive(self) -> dict[str, Any]: + """ + Main interaction coroutine: Get next value out of the queue. + """ + if not self.task: + self.start() output = await self.output_queue.get() return output async def sync(self) -> None: + """Synchronize the resamplers by pulling data from each until the timestamp is aligned. Retains first matching data.""" print("Starting sync") datas = await asyncio.gather(*[source.receive() for source in self.resamplers]) print("Got data") diff --git a/heisskleber/udp/config.py b/heisskleber/udp/config.py index f3e7210..4cd1b09 100644 --- a/heisskleber/udp/config.py +++ b/heisskleber/udp/config.py @@ -10,5 +10,6 @@ class UdpConf(BaseConf): """ port: int = 1234 - ip: str = "127.0.0.1" + host: str = "127.0.0.1" packer: str = "json" + max_queue_size: int = 1000 diff --git a/heisskleber/udp/publisher.py b/heisskleber/udp/publisher.py index 6332e3b..2f46e79 100644 --- a/heisskleber/udp/publisher.py +++ b/heisskleber/udp/publisher.py @@ -1,4 +1,5 @@ import socket +import sys from heisskleber.core.packer import get_packer from heisskleber.core.types import Serializable, Sink @@ -8,15 +9,30 @@ from heisskleber.udp.config import UdpConf class UdpPublisher(Sink): def __init__(self, config: UdpConf) -> None: self.config = config - self.ip = self.config.ip + self.ip = self.config.host self.port = self.config.port - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.pack = get_packer(self.config.packer) + self.is_connected = False - def send(self, message: dict[str, Serializable], topic: str) -> None: - message["topic"] = topic - payload = self.pack(message).encode("utf-8") + def start(self) -> None: + try: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + except OSError as e: + print(f"failed to create socket: {e}") + sys.exit(-1) + else: + self.is_connected = True + + def stop(self) -> None: + self.socket.close() + self.is_connected = True + + def send(self, data: dict[str, Serializable], topic: str) -> None: + if not self.is_connected: + self.start() + data["topic"] = topic + payload = self.pack(data).encode("utf-8") self.socket.sendto(payload, (self.ip, self.port)) - def __del__(self) -> None: - self.socket.close() + def __repr__(self) -> str: + return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})" diff --git a/heisskleber/udp/subscriber.py b/heisskleber/udp/subscriber.py index 840824e..82ab1d3 100644 --- a/heisskleber/udp/subscriber.py +++ b/heisskleber/udp/subscriber.py @@ -1,6 +1,7 @@ import socket +import sys import threading -from queue import SimpleQueue +from queue import Queue from heisskleber.core.packer import get_unpacker from heisskleber.core.types import Serializable, Source @@ -11,15 +12,30 @@ class UdpSubscriber(Source): def __init__(self, config: UdpConf, topic: str | None = None): self.config = config self.topic = topic - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.socket.bind((self.config.ip, self.config.port)) self.unpacker = get_unpacker(self.config.packer) - self._queue: SimpleQueue[tuple[str, dict[str, Serializable]]] = SimpleQueue() + self._queue: Queue[tuple[str, dict[str, Serializable]]] = Queue(maxsize=self.config.max_queue_size) self._running = threading.Event() + + def start(self) -> None: + try: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + except OSError as e: + print(f"failed to create socket: {e}") + sys.exit(-1) + self.socket.bind((self.config.host, self.config.port)) self._running.set() - self._thread: threading.Thread | None = None + self._thread = threading.Thread(target=self._loop, daemon=True) + self._thread.start() + + def stop(self) -> None: + self._running.clear() + # if self._thread is not None: + # self._thread.join() + self.socket.close() def receive(self) -> tuple[str, dict[str, Serializable]]: + if not self._running.is_set(): + self.start() return self._queue.get() def _loop(self) -> None: @@ -33,15 +49,5 @@ class UdpSubscriber(Source): error_message = f"Error in UDP listener loop: {e}" print(error_message) - def start_loop(self) -> None: - self._thread = threading.Thread(target=self._loop, daemon=True) - self._thread.start() - - def stop_loop(self) -> None: - self._running.clear() - if self._thread is not None: - self._thread.join() - self.socket.close() - - def __del__(self) -> None: - self.stop_loop() + def __repr__(self) -> str: + return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})" diff --git a/heisskleber/zmq/config.py b/heisskleber/zmq/config.py index 77a46bf..a93a164 100644 --- a/heisskleber/zmq/config.py +++ b/heisskleber/zmq/config.py @@ -6,15 +6,15 @@ from heisskleber.config import BaseConf @dataclass class ZmqConf(BaseConf): protocol: str = "tcp" - interface: str = "127.0.0.1" + host: str = "127.0.0.1" publisher_port: int = 5555 subscriber_port: int = 5556 packstyle: str = "json" @property def publisher_address(self) -> str: - return f"{self.protocol}://{self.interface}:{self.publisher_port}" + return f"{self.protocol}://{self.host}:{self.publisher_port}" @property def subscriber_address(self) -> str: - return f"{self.protocol}://{self.interface}:{self.subscriber_port}" + return f"{self.protocol}://{self.host}:{self.subscriber_port}" diff --git a/heisskleber/zmq/publisher.py b/heisskleber/zmq/publisher.py index 1a754ae..22abffc 100644 --- a/heisskleber/zmq/publisher.py +++ b/heisskleber/zmq/publisher.py @@ -1,4 +1,5 @@ import sys +from typing import Callable import zmq import zmq.asyncio @@ -10,45 +11,46 @@ from .config import ZmqConf class ZmqPublisher(Sink): + """ + Publisher that sends messages to a ZMQ PUB socket. + + Attributes: + ----------- + pack : Callable + The packer function to use for serializing the data. + + Methods: + -------- + send(data : dict, topic : str): + Send the data with the given topic. + + start(): + Connect to the socket. + + stop(): + Close the socket. + """ + def __init__(self, config: ZmqConf): self.config = config - self.context = zmq.Context.instance() self.socket = self.context.socket(zmq.PUB) - self.pack = get_packer(config.packstyle) - self.connect() - - def connect(self) -> None: - try: - if self.config.verbose: - print(f"connecting to {self.config.publisher_address}") - self.socket.connect(self.config.publisher_address) - except Exception as e: - print(f"failed to bind to zeromq socket: {e}") - sys.exit(-1) + self.is_connected = False def send(self, data: dict[str, Serializable], topic: str) -> None: + """ + Take the data as a dict, serialize it with the given packer and send it to the zmq socket. + """ + if not self.is_connected: + self.start() payload = self.pack(data) if self.config.verbose: print(f"sending message {payload} to topic {topic}") self.socket.send_multipart([topic.encode(), payload.encode()]) - def __del__(self): - self.socket.close() - - -class ZmqAsyncPublisher(AsyncSink): - def __init__(self, config: ZmqConf): - self.config = config - - self.context = zmq.asyncio.Context.instance() - self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB) - - self.pack = get_packer(config.packstyle) - self.connect() - - def connect(self) -> None: + def start(self) -> None: + """Connect to the zmq socket.""" try: if self.config.verbose: print(f"connecting to {self.config.publisher_address}") @@ -56,12 +58,72 @@ class ZmqAsyncPublisher(AsyncSink): except Exception as e: print(f"failed to bind to zeromq socket: {e}") sys.exit(-1) + else: + self.is_connected = True + + def stop(self) -> None: + self.socket.close() + self.is_connected = False + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})" + + +class ZmqAsyncPublisher(AsyncSink): + """ + Async publisher that sends messages to a ZMQ PUB socket. + + Attributes: + ----------- + pack : Callable + The packer function to use for serializing the data. + + Methods: + -------- + send(data : dict, topic : str): + Send the data with the given topic. + + start(): + Connect to the socket. + + stop(): + Close the socket. + """ + + def __init__(self, config: ZmqConf): + self.config = config + self.context = zmq.asyncio.Context.instance() + self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB) + self.pack: Callable = get_packer(config.packstyle) + self.is_connected = False async def send(self, data: dict[str, Serializable], topic: str) -> None: + """ + Take the data as a dict, serialize it with the given packer and send it to the zmq socket. + """ + if not self.is_connected: + self.start() payload = self.pack(data) if self.config.verbose: print(f"sending message {payload} to topic {topic}") await self.socket.send_multipart([topic.encode(), payload.encode()]) - def __del__(self): + def start(self) -> None: + """Connect to the zmq socket.""" + try: + if self.config.verbose: + print(f"connecting to {self.config.publisher_address}") + self.socket.connect(self.config.publisher_address) + except Exception as e: + print(f"failed to bind to zeromq socket: {e}") + sys.exit(-1) + else: + self.is_connected = True + + def stop(self) -> None: + """Close the zmq socket.""" self.socket.close() + self.is_connected = False + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})" diff --git a/heisskleber/zmq/subscriber.py b/heisskleber/zmq/subscriber.py index 0297f41..fb53a5c 100644 --- a/heisskleber/zmq/subscriber.py +++ b/heisskleber/zmq/subscriber.py @@ -12,34 +12,43 @@ from .config import ZmqConf class ZmqSubscriber(Source): - def __init__(self, config: ZmqConf, topic: str): - self.config = config + """ + Source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function. + Attributes: + ----------- + unpack : Callable + The unpacker function to use for deserializing the data. + + Methods: + -------- + receive() -> tuple[str, dict]: + Send the data with the given topic. + + start(): + Connect to the socket. + + stop(): + Close the socket. + """ + + def __init__(self, config: ZmqConf, topic: str | list[str]): + """ + Constructs new ZmqAsyncSubscriber instance. + + Parameters: + ----------- + config : ZmqConf + The configuration dataclass object for the zmq connection. + topic : str + The topic or list of topics to subscribe to. + """ + self.config = config + self.topic = topic self.context = zmq.Context.instance() self.socket = self.context.socket(zmq.SUB) - self.connect() - self.subscribe(topic) - self.unpack = get_unpacker(config.packstyle) - - def connect(self): - try: - # print(f"Connecting to { self.config.consumer_connection }") - self.socket.connect(self.config.subscriber_address) - except Exception as e: - print(f"failed to bind to zeromq socket: {e}") - sys.exit(-1) - - def _subscribe_single_topic(self, topic: str): - self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode()) - - def subscribe(self, topic: str | list[str] | tuple[str]): - # Accepts single topic or list of topics - if isinstance(topic, (list, tuple)): - for t in topic: - self._subscribe_single_topic(t) - else: - self._subscribe_single_topic(topic) + self.is_connected = False def receive(self) -> tuple[str, dict]: """ @@ -48,34 +57,26 @@ class ZmqSubscriber(Source): Returns: tuple(topic: str, message: dict): the message received """ + if not self.is_connected: + self.start() (topic, payload) = self.socket.recv_multipart() message = self.unpack(payload.decode()) topic = topic.decode() return (topic, message) - def __del__(self): - self.socket.close() - - -class ZmqAsyncSubscriber(AsyncSource): - def __init__(self, config: ZmqConf, topic: str): - self.config = config - self.context = zmq.asyncio.Context.instance() - self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB) - self.connect() - self.subscribe(topic) - - self.unpack = get_unpacker(config.packstyle) - - def connect(self): + def start(self): try: self.socket.connect(self.config.subscriber_address) + self.subscribe(self.topic) except Exception as e: print(f"failed to bind to zeromq socket: {e}") sys.exit(-1) + else: + self.is_connected = True - def _subscribe_single_topic(self, topic: str): - self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode()) + def stop(self): + self.socket.close() + self.is_connected = False def subscribe(self, topic: str | list[str] | tuple[str]): # Accepts single topic or list of topics @@ -85,6 +86,52 @@ class ZmqAsyncSubscriber(AsyncSource): else: self._subscribe_single_topic(topic) + def _subscribe_single_topic(self, topic: str): + self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode()) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})" + + +class ZmqAsyncSubscriber(AsyncSource): + """ + Async source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function. + + Attributes: + ----------- + unpack : Callable + The unpacker function to use for deserializing the data. + + Methods: + -------- + receive() -> tuple[str, dict]: + Send the data with the given topic. + + start(): + Connect to the socket. + + stop(): + Close the socket. + """ + + def __init__(self, config: ZmqConf, topic: str | list[str]): + """ + Constructs new ZmqAsyncSubscriber instance. + + Parameters: + ----------- + config : ZmqConf + The configuration dataclass object for the zmq connection. + topic : str + The topic or list of topics to subscribe to. + """ + self.config = config + self.topic = topic + self.context = zmq.asyncio.Context.instance() + self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB) + self.unpack = get_unpacker(config.packstyle) + self.is_connected = True + async def receive(self) -> tuple[str, dict]: """ reads a message from the zmq bus and returns it @@ -92,10 +139,43 @@ class ZmqAsyncSubscriber(AsyncSource): Returns: tuple(topic: str, message: dict): the message received """ + if not self.is_connected: + self.start() (topic, payload) = await self.socket.recv_multipart() message = self.unpack(payload.decode()) topic = topic.decode() return (topic, message) - def __del__(self): + def start(self): + """Connect to the zmq socket.""" + try: + self.socket.connect(self.config.subscriber_address) + except Exception as e: + print(f"failed to bind to zeromq socket: {e}") + sys.exit(-1) + else: + self.is_connected = True + self.subscribe(self.topic) + + def stop(self): + """Close the zmq socket.""" self.socket.close() + self.is_connected = False + + def subscribe(self, topic: str | list[str] | tuple[str]): + """ + Subscribes to the given topic(s) on the zmq socket. + + Accepts single topic or list of topics. + """ + if isinstance(topic, (list, tuple)): + for t in topic: + self._subscribe_single_topic(t) + else: + self._subscribe_single_topic(topic) + + def _subscribe_single_topic(self, topic: str): + self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode()) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})" diff --git a/run/udp-listener.py b/run/udp-listener.py index 30754a6..9e5e0fe 100644 --- a/run/udp-listener.py +++ b/run/udp-listener.py @@ -1,8 +1,8 @@ -from heisskleber.udp import UdpSubscriber, UdpConf +from heisskleber.udp import UdpConf, UdpSubscriber def main() -> None: - conf = UdpConf(ip="192.168.137.1", port=6600) + conf = UdpConf(host="192.168.137.1", port=6600) subscriber = UdpSubscriber(conf) while True: @@ -12,4 +12,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/tests/integration/async_streamer.py b/tests/integration/async_streamer.py index 25f5227..cde1761 100644 --- a/tests/integration/async_streamer.py +++ b/tests/integration/async_streamer.py @@ -10,7 +10,7 @@ async def main(): topic1 = "topic1" topic2 = "topic2" - config = MqttConf(broker="localhost", port=1883, user="", password="") # not a real password + config = MqttConf(host="localhost", port=1883, user="", password="") # not a real password sub1 = AsyncMqttSubscriber(config, topic1) sub2 = AsyncMqttSubscriber(config, topic2) diff --git a/tests/integration/integration_joint.py b/tests/integration/integration_joint.py index 8d2dbb3..84caa9d 100644 --- a/tests/integration/integration_joint.py +++ b/tests/integration/integration_joint.py @@ -7,7 +7,7 @@ from heisskleber.stream import Joint, Resampler, ResamplerConf async def main(): topics = ["topic0", "topic1", "topic2", "topic3"] - config = MqttConf(broker="localhost", port=1883, user="", password="") # not a real password + config = MqttConf(host="localhost", port=1883, user="", password="") # not a real password subs = [AsyncMqttSubscriber(config, topic=topic) for topic in topics] resampler_config = ResamplerConf(resample_rate=1000) diff --git a/tests/integration/mqtt_async.py b/tests/integration/mqtt_async.py index 64b7988..c847a43 100644 --- a/tests/integration/mqtt_async.py +++ b/tests/integration/mqtt_async.py @@ -4,7 +4,7 @@ from heisskleber.mqtt import AsyncMqttSubscriber, MqttConf async def main(): - conf = MqttConf(broker="localhost", port=1883, user="", password="") + conf = MqttConf(host="localhost", port=1883, user="", password="") sub = AsyncMqttSubscriber(conf, topic="#") # async for topic, message in sub: # print(message) diff --git a/tests/integration/mqtt_pub.py b/tests/integration/mqtt_pub.py index 32d63b8..678e8e2 100644 --- a/tests/integration/mqtt_pub.py +++ b/tests/integration/mqtt_pub.py @@ -20,7 +20,7 @@ async def send_every_n_miliseconds(frequency, value, pub, topic): async def main2(): - config = MqttConf(broker="localhost", port=1883, user="", password="") + config = MqttConf(host="localhost", port=1883, user="", password="") pubs = [AsyncMqttPublisher(config) for i in range(5)] tasks = [] diff --git a/tests/integration/mqtt_stream.py b/tests/integration/mqtt_stream.py index 6d59dec..5ac2308 100644 --- a/tests/integration/mqtt_stream.py +++ b/tests/integration/mqtt_stream.py @@ -5,7 +5,7 @@ from heisskleber.stream import Resampler, ResamplerConf async def main(): - conf = MqttConf(broker="localhost", port=1883, user="", password="") + conf = MqttConf(host="localhost", port=1883, user="", password="") sub = AsyncMqttSubscriber(conf, topic="#") resampler = Resampler(ResamplerConf(), sub) diff --git a/tests/integration/mqtt_sub.py b/tests/integration/mqtt_sub.py index f3a58d5..55bce37 100644 --- a/tests/integration/mqtt_sub.py +++ b/tests/integration/mqtt_sub.py @@ -4,7 +4,7 @@ from heisskleber.mqtt import AsyncMqttSubscriber, MqttConf async def main(): - config = MqttConf(broker="localhost", port=1883, user="", password="") + config = MqttConf(host="localhost", port=1883, user="", password="") sub = AsyncMqttSubscriber(config, topic="#") diff --git a/tests/integration/sync_streamer.py b/tests/integration/sync_streamer.py index a1677a8..eb39f24 100644 --- a/tests/integration/sync_streamer.py +++ b/tests/integration/sync_streamer.py @@ -13,7 +13,7 @@ def main(): # topic2 = "topic2" config = MqttConf( - broker="localhost", port=1883, user="", password="" + host="localhost", port=1883, user="", password="" ) # , not a real password port=1883, user="", password="") sub1 = MqttSubscriber(config, topic1) # sub2 = MqttSubscriber(config, topic2) diff --git a/tests/resources/zmq.yaml b/tests/resources/zmq.yaml index 5e7299d..a7ef27a 100644 --- a/tests/resources/zmq.yaml +++ b/tests/resources/zmq.yaml @@ -1,4 +1,4 @@ -protocol : "tcp" # ipc protocol -interface: "127.0.0.1" # the interface to bind to -publisher_port : 5555 # port used by primary producers -subscriber_port: 5556 # port used by primary consumers +protocol: "tcp" # ipc protocol +host: "127.0.0.1" # the interface to bind to +publisher_port: 5555 # port used by primary producers +subscriber_port: 5556 # port used by primary consumers diff --git a/tests/test_console_sink.py b/tests/test_console_sink.py index fc83874..f1a2e43 100644 --- a/tests/test_console_sink.py +++ b/tests/test_console_sink.py @@ -39,6 +39,16 @@ def test_console_sink_pretty_verbose(capsys) -> None: assert captured.out == 'test:\t{\n "key": 3\n}\n' +def test_console_repr() -> None: + sink = ConsoleSink() + assert repr(sink) == "ConsoleSink(pretty=False, verbose=False)" + + +def test_async_console_repr() -> None: + sink = AsyncConsoleSink() + assert repr(sink) == "AsyncConsoleSink(pretty=False, verbose=False)" + + @pytest.mark.asyncio async def test_async_console_sink(capsys) -> None: sink = AsyncConsoleSink() diff --git a/tests/test_mqtt.py b/tests/test_mqtt.py index 8edb070..10f31cd 100644 --- a/tests/test_mqtt.py +++ b/tests/test_mqtt.py @@ -14,7 +14,7 @@ from heisskleber.mqtt.subscriber import MqttSubscriber @pytest.fixture def mock_mqtt_conf() -> MqttConf: return MqttConf( - broker="localhost", + host="localhost", port=1883, user="user", password="passwd", # noqa: S106, this is a test password @@ -40,12 +40,14 @@ def mock_queue(): def test_mqtt_base_intialization(mock_mqtt_client, mock_mqtt_conf): """Test that the intialization of the mqtt client is as expected.""" base = MqttBase(config=mock_mqtt_conf) + base.start() mock_mqtt_client.assert_called_once() mock_mqtt_client.return_value.loop_start.assert_called_once() mock_client_instance = mock_mqtt_client.return_value mock_client_instance.username_pw_set.assert_called_with(mock_mqtt_conf.user, mock_mqtt_conf.password) - mock_client_instance.connect.assert_called_with(mock_mqtt_conf.broker, mock_mqtt_conf.port) + mock_client_instance.connect.assert_called_with(mock_mqtt_conf.host, mock_mqtt_conf.port) + assert base.client assert base.client.on_connect == base._on_connect assert base.client.on_disconnect == base._on_disconnect assert base.client.on_publish == base._on_publish @@ -56,7 +58,7 @@ def test_mqtt_base_on_connect(mock_mqtt_client, mock_mqtt_conf, capsys): base = MqttBase(config=mock_mqtt_conf) base._on_connect(None, None, {}, 0) captured = capsys.readouterr() - assert f"MQTT node connected to {mock_mqtt_conf.broker}:{mock_mqtt_conf.port}" in captured.out + assert f"MQTT node connected to {mock_mqtt_conf.host}:{mock_mqtt_conf.port}" in captured.out def test_mqtt_base_on_disconnect_with_error(mock_mqtt_client, mock_mqtt_conf, capsys): @@ -71,7 +73,8 @@ def test_mqtt_base_on_disconnect_with_error(mock_mqtt_client, mock_mqtt_conf, ca def test_mqtt_subscribes_single_topic(mock_mqtt_client, mock_mqtt_conf): """Test that the mqtt client subscribes to a single topic.""" - _ = MqttSubscriber(topics="singleTopic", config=mock_mqtt_conf) + sub = MqttSubscriber(topics="singleTopic", config=mock_mqtt_conf) + sub.start() actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list assert actual_calls == [call("singleTopic", mock_mqtt_conf.qos)] @@ -82,7 +85,8 @@ def test_mqtt_subscribes_multiple_topics(mock_mqtt_client, mock_mqtt_conf): I would love to do this via parametrization, but the call argument is built differently for single size lists and longer lists. """ - _ = MqttSubscriber(topics=["multiple1", "multiple2"], config=mock_mqtt_conf) + sub = MqttSubscriber(topics=["multiple1", "multiple2"], config=mock_mqtt_conf) + sub.start() actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list assert actual_calls == [ @@ -92,7 +96,8 @@ def test_mqtt_subscribes_multiple_topics(mock_mqtt_client, mock_mqtt_conf): def test_mqtt_subscribes_multiple_topics_tuple(mock_mqtt_client, mock_mqtt_conf): """Test that the mqtt client subscribes to multiple topics passed as tuple.""" - _ = MqttSubscriber(topics=("multiple1", "multiple2"), config=mock_mqtt_conf) + sub = MqttSubscriber(topics=("multiple1", "multiple2"), config=mock_mqtt_conf) + sub.start() actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list assert actual_calls == [ diff --git a/tests/test_serial.py b/tests/test_serial.py index 76042ee..1536ac5 100644 --- a/tests/test_serial.py +++ b/tests/test_serial.py @@ -29,10 +29,11 @@ def mock_serial_device_publisher(): def test_serial_subscriber_initialization(mock_serial_device_subscriber, serial_conf): """Test that the SerialSubscriber class initializes correctly. Mocks the serial.Serial class to avoid opening a serial port.""" - _ = SerialSubscriber( + sub = SerialSubscriber( config=serial_conf, topic="", ) + sub.start() mock_serial_device_subscriber.assert_called_with( port=serial_conf.port, baudrate=serial_conf.baudrate, @@ -45,6 +46,7 @@ def test_serial_subscriber_initialization(mock_serial_device_subscriber, serial_ def test_serial_subscriber_receive(mock_serial_device_subscriber, serial_conf): """Test that the SerialSubscriber class calls readline and unpack as expected.""" subscriber = SerialSubscriber(config=serial_conf, topic="") + subscriber.start() # Set up the readline return value mock_serial_instance = mock_serial_device_subscriber.return_value @@ -69,6 +71,7 @@ def test_serial_subscriber_converts_bytes_to_str(): """Test that the SerialSubscriber class converts bytes to str as expected.""" with patch("heisskleber.serial.subscriber.serial.Serial") as mock_serial: subscriber = SerialSubscriber(config=SerialConf(), topic="", custom_unpack=lambda x: x) + subscriber.start() # Set the readline method to raise UnicodeError mock_serial_instance = mock_serial.return_value @@ -86,6 +89,7 @@ def test_serial_publisher_initialization(mock_serial_device_publisher, serial_co """Test that the SerialPublisher class initializes correctly. Mocks the serial.Serial class to avoid opening a serial port.""" publisher = SerialPublisher(config=serial_conf) + publisher.start() mock_serial_device_publisher.assert_called_with( port=serial_conf.port, baudrate=serial_conf.baudrate, @@ -93,7 +97,7 @@ def test_serial_publisher_initialization(mock_serial_device_publisher, serial_co parity=serial.PARITY_NONE, stopbits=serial.STOPBITS_ONE, ) - assert publisher.serial + assert publisher.serial_connection def test_serial_publisher_send(mock_serial_device_publisher, serial_conf): diff --git a/tests/test_streamer.py b/tests/test_streamer.py index e7b350b..9049aca 100644 --- a/tests/test_streamer.py +++ b/tests/test_streamer.py @@ -67,11 +67,12 @@ async def test_resampler_multiple_modes(mock_subscriber): ] ) - config = ResamplerConf(resample_rate=1000) # Fill in your MQTT configuration + config = ResamplerConf(resample_rate=1000) resampler = Resampler(config, mock_subscriber) # Test the resample method resampled_data = [await resampler.receive() for _ in range(3)] + resampler.stop() assert resampled_data[0] == {"epoch": 0.0, "data": 1.5} assert resampled_data[1] == {"epoch": 1.0, "data": 3.5} @@ -89,11 +90,12 @@ async def test_resampler_upsampling(mock_subscriber): ] ) - config = ResamplerConf(resample_rate=250) # Fill in your MQTT configuration + config = ResamplerConf(resample_rate=250) resampler = Resampler(config, mock_subscriber) # Test the resample method resampled_data = [await resampler.receive() for _ in range(7)] + resampler.stop() assert resampled_data[0] == {"epoch": 0.0, "data": 1.0} assert resampled_data[1] == {"epoch": 0.25, "data": 1.25} diff --git a/tests/test_udp.py b/tests/test_udp.py index 09133a0..6060629 100644 --- a/tests/test_udp.py +++ b/tests/test_udp.py @@ -17,19 +17,22 @@ def mock_socket(): @pytest.fixture def mock_conf(): - return UdpConf(ip="127.0.0.1", port=12345, packer="json") + return UdpConf(host="127.0.0.1", port=12345, packer="json") def test_connects_to_socket(mock_socket, mock_conf) -> None: - _ = UdpPublisher(mock_conf) + pub = UdpPublisher(mock_conf) + pub.start() # constructor was called mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_DGRAM) + pub.stop() def test_closes_socket(mock_socket, mock_conf) -> None: pub = UdpPublisher(mock_conf) - del pub + pub.start() + pub.stop() # instace was closed mock_socket.return_value.close.assert_called() @@ -45,8 +48,9 @@ def test_packs_and_sends_message(mock_socket, mock_conf) -> None: mock_socket.return_value.sendto.assert_called_with( b'{"key": "val", "intkey": 1, "floatkey": 1.0, "topic": "test"}', - (str(mock_conf.ip), mock_conf.port), + (str(mock_conf.host), mock_conf.port), ) + pub.stop() def test_subscriber_receives_message_from_queue(mock_conf) -> None: @@ -59,13 +63,16 @@ def test_subscriber_receives_message_from_queue(mock_conf) -> None: topic, data = sub.receive() assert test_topic == topic assert test_data == data + sub.stop() @pytest.fixture def udp_sub(mock_conf): sub = UdpSubscriber(mock_conf) - sub.start_loop() + sub.config.port = 12346 # explicitly set port to avoid conflicts + sub.start() yield sub + sub.stop() def test_sends_message_between_pub_and_sub(udp_sub, mock_conf): diff --git a/tests/test_zmq.py b/tests/test_zmq.py index c6e3c7b..7323bfa 100644 --- a/tests/test_zmq.py +++ b/tests/test_zmq.py @@ -33,9 +33,9 @@ def start_broker(): def test_config_parses_correctly(): - conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556) + conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556) assert conf.protocol == "tcp" - assert conf.interface == "localhost" + assert conf.host == "localhost" assert conf.publisher_port == 5555 assert conf.subscriber_port == 5556 @@ -44,13 +44,13 @@ def test_config_parses_correctly(): def test_instantiate_subscriber(): - conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556) + conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556) sub = ZmqSubscriber(conf, "test") assert sub.config == conf def test_instantiate_publisher(): - conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556) + conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556) pub = ZmqPublisher(conf) assert pub.config == conf @@ -59,7 +59,9 @@ def test_send_receive(start_broker): print("test_send_receive") topic = "test" source = get_source("zmq", topic) + source.start() sink = get_sink("zmq") + sink.start() time.sleep(1) # this is crucial, otherwise the source might hang for i in range(10): message = {"m": i} diff --git a/tests/zmq/test_zmq_asyncio.py b/tests/zmq/test_zmq_asyncio.py index 1b9761b..f8948fc 100644 --- a/tests/zmq/test_zmq_asyncio.py +++ b/tests/zmq/test_zmq_asyncio.py @@ -33,13 +33,13 @@ def start_broker() -> Generator[Process, None, None]: def test_instantiate_subscriber() -> None: - conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556) + conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556) sub = ZmqAsyncSubscriber(conf, "test") assert sub.config == conf def test_instantiate_publisher() -> None: - conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556) + conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556) pub = ZmqPublisher(conf) assert pub.config == conf @@ -48,9 +48,11 @@ def test_instantiate_publisher() -> None: async def test_send_receive(start_broker) -> None: print("test_send_receive") topic = "test" - conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556) + conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556) source = ZmqAsyncSubscriber(conf, topic) sink = ZmqAsyncPublisher(conf) + source.start() + sink.start() time.sleep(1) # this is crucial, otherwise the source might hang for i in range(10): message = {"m": i}