diff --git a/Makefile b/Makefile index 78a4553..32b7eb7 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,7 @@ docs-test: ## Test if documentation can be built without warnings or errors .PHONY: docs docs: ## Build and serve the documentation - @poetry run mkdocs serve + @poetry run python3 -m sphinx docs docs/_build -b html .PHONY: help help: diff --git a/docs/reference.md b/docs/reference.md index 20ff891..c2e4fe4 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -3,22 +3,16 @@ ## Network ```{eval-rst} -.. automodule:: heisskleber.network +.. automodule:: heisskleber :members: -.. automodule:: heisskleber.network.mqtt -.. autoclass:: MqttPublisher -.. autoclass:: MqttSubscriber -.. automodule:: heisskleber.network.zmq +.. automodule:: heisskleber.mqtt + :members: +.. automodule:: heisskleber.zmq + :members: +.. automodule:: heisskleber.serial + :members: +.. automodule:: heisskleber.udp :members: -.. autoclass:: ZmqPublisher -.. autoclass:: ZmqSubscriber -``` - -### Broker - -```{eval-rst} -.. automodule:: heisskleber.broker - :members: ``` ## Config @@ -26,15 +20,5 @@ ### Loading configs ```{eval-rst} .. automodule:: heisskleber.config - :members: load_config -``` - -### Config types - -Configs are extended dataclasses, which inherit from the BaseConf class. -```{eval-rst} -.. autoclass:: heisskleber.config.BaseConf -.. autoclass:: heisskleber.network.mqtt.config.MqttConf -.. autoclass:: heisskleber.network.zmq.config.ZmqConf -.. autoclass:: heisskleber.network.serial.config.SerialConf + :members: ``` diff --git a/heisskleber/__init__.py b/heisskleber/__init__.py index 2cb7032..0379734 100644 --- a/heisskleber/__init__.py +++ b/heisskleber/__init__.py @@ -1,5 +1,6 @@ """Heisskleber.""" -from .network import get_publisher, get_subscriber +from .core.factories import get_publisher, get_subscriber +from .core.types import Publisher, Subscriber -__all__ = ["get_publisher", "get_subscriber"] -__version__ = "0.1.0" +__all__ = ["get_publisher", "get_subscriber", "Publisher", "Subscriber"] +__version__ = "0.2.0" diff --git a/heisskleber/broker/__init__.py b/heisskleber/broker/__init__.py index 961abcd..17f19ef 100644 --- a/heisskleber/broker/__init__.py +++ b/heisskleber/broker/__init__.py @@ -1,3 +1,3 @@ -from .msb_broker import msb_broker as start_zmq_broker +from .msb_broker import zmq_broker as start_zmq_broker __all__ = ["start_zmq_broker"] diff --git a/heisskleber/broker/msb_broker.py b/heisskleber/broker/zmq_broker.py similarity index 71% rename from heisskleber/broker/msb_broker.py rename to heisskleber/broker/zmq_broker.py index db4394e..ff271b3 100644 --- a/heisskleber/broker/msb_broker.py +++ b/heisskleber/broker/zmq_broker.py @@ -2,9 +2,10 @@ import signal import sys import zmq +from zmq import Socket from heisskleber.config import load_config -from heisskleber.network.zmq.config import ZmqConf as BrokerConf +from heisskleber.zmq.config import ZmqConf as BrokerConf def signal_handler(sig, frame): @@ -16,29 +17,31 @@ class BrokerBindingError(Exception): pass -def bind_socket(socket, address, socket_type, verbose=False): +def bind_socket(socket: Socket, address: str, socket_type: str, verbose=False) -> None: """Bind a ZMQ socket and handle errors.""" if verbose: print(f"creating {socket_type} socket") try: socket.bind(address) except Exception as err: - raise BrokerBindingError(f"failed to bind to {socket_type}: {err}") from err + error_message = f"failed to bind to {socket_type}: {err}" + raise BrokerBindingError(error_message) from err if verbose: print(f"successfully bound to {socket_type} socket: {address}") -def create_proxy(xpub, xsub, verbose=False): +def create_proxy(xpub: Socket, xsub: Socket, verbose=False) -> None: """Create a ZMQ proxy to connect XPUB and XSUB sockets.""" if verbose: print("creating proxy") try: zmq.proxy(xpub, xsub) except Exception as err: - raise BrokerBindingError(f"failed to create proxy: {err}") from err + error_message = f"failed to create proxy: {err}" + raise BrokerBindingError(error_message) from err -def msb_broker(config: BrokerConf) -> None: +def zmq_broker(config: BrokerConf) -> None: """Start a zmq broker. Binds to a publisher and subscriber port, allowing many to many connections.""" @@ -60,4 +63,4 @@ def main() -> None: """Start a zmq broker, with a user specified configuration.""" signal.signal(signal.SIGINT, signal_handler) broker_config = load_config(BrokerConf(), "zmq") - msb_broker(broker_config) + zmq_broker(broker_config) diff --git a/heisskleber/network/config.py b/heisskleber/config.py similarity index 100% rename from heisskleber/network/config.py rename to heisskleber/config.py diff --git a/heisskleber/config/__init__.py b/heisskleber/config/__init__.py index c554645..8710d76 100644 --- a/heisskleber/config/__init__.py +++ b/heisskleber/config/__init__.py @@ -1,4 +1,4 @@ -from heisskleber.config.MSBConfig import BaseConf -from heisskleber.config.parse import load_config +from .config import BaseConf +from .parse import load_config __all__ = ["load_config", "BaseConf"] diff --git a/heisskleber/config/cmdline.py b/heisskleber/config/cmdline.py index 5a39f58..b71588c 100644 --- a/heisskleber/config/cmdline.py +++ b/heisskleber/config/cmdline.py @@ -2,18 +2,7 @@ import argparse class KeyValue(argparse.Action): - # Constructor calling - """ - def __call__( self , parser, namespace, values : list, option_string = None): - setattr(namespace, self.dest, dict()) - for value in values: - # split it into key and value - key, value = value.split('=') - # assign into dictionary - getattr(namespace, self.dest)[key] = value - """ - - def __call__(self, parser, args, values, option_string=None): + def __call__(self, parser, args, values, option_string=None) -> None: try: params = dict(x.split("=") for x in values) except ValueError as ex: diff --git a/heisskleber/config/MSBConfig.py b/heisskleber/config/config.py similarity index 79% rename from heisskleber/config/MSBConfig.py rename to heisskleber/config/config.py index 7d05478..69cc2d3 100644 --- a/heisskleber/config/MSBConfig.py +++ b/heisskleber/config/config.py @@ -1,6 +1,7 @@ import socket import warnings from dataclasses import dataclass +from typing import Any @dataclass @@ -12,18 +13,18 @@ class BaseConf: verbose: bool = False print_stdout: bool = False - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: if hasattr(self, key): self.__setattr__(key, value) else: warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: if hasattr(self, key): return getattr(self, key) else: warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2) @property - def serial_number(self): + def serial_number(self) -> str: return socket.gethostname().upper() diff --git a/heisskleber/config/parse.py b/heisskleber/config/parse.py index f6cfd2b..0d776df 100644 --- a/heisskleber/config/parse.py +++ b/heisskleber/config/parse.py @@ -8,7 +8,9 @@ import yaml from heisskleber.config import BaseConf from heisskleber.config.cmdline import get_cmdline -ConfigType = TypeVar("ConfigType", bound=BaseConf) +ConfigType = TypeVar( + "ConfigType", bound=BaseConf +) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar def get_msb_config_filepath(config_filename: str = "heisskleber.conf") -> str: @@ -17,10 +19,10 @@ def get_msb_config_filepath(config_filename: str = "heisskleber.conf") -> str: config_filepath = os.path.join(os.environ["MSB_CONFIG_DIR"], config_subpath) except Exception as e: print(f"could no get MSB_CONFIG from PATH: {e}") - sys.exit() # TODO use 1 or the error str as exit value + sys.exit(1) if not os.path.isfile(config_filepath): print(f"not a file: {config_filepath}!") - sys.exit() + sys.exit(1) return config_filepath @@ -33,16 +35,12 @@ def update_config(config: ConfigType, config_dict: dict) -> ConfigType: for config_key, config_value in config_dict.items(): # get expected type of element from config_object: if not hasattr(config, config_key): - warnings.warn( - f"no such configuration parameter: {config_key}, skipping", stacklevel=2 - ) + error_msg = f"no such configuration parameter: {config_key}, skipping" + warnings.warn(error_msg, stacklevel=2) continue cast_func = type(config[config_key]) try: - if config_key == "topic": - config[config_key] = config_value.encode("utf-8") - else: - config[config_key] = cast_func(config_value) + config[config_key] = cast_func(config_value) except Exception as e: print( f"failed to cast {config_value} to {type(config[config_key])}: {e}. skipping" @@ -51,11 +49,6 @@ def update_config(config: ConfigType, config_dict: dict) -> ConfigType: return config -ConfigType = TypeVar( - "ConfigType", bound=BaseConf -) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar - - def load_config( config: ConfigType, config_filename: str, read_commandline: bool = True ) -> ConfigType: diff --git a/heisskleber/config/zeromq.py b/heisskleber/config/zeromq.py deleted file mode 100644 index 571440a..0000000 --- a/heisskleber/config/zeromq.py +++ /dev/null @@ -1,40 +0,0 @@ -import sys - -import zmq - - -def open_zmq_sub_socket(connect_to: str, topic=b""): - ctx = zmq.Context() - zmq_socket = ctx.socket(zmq.SUB) - try: - zmq_socket.connect(connect_to) - except Exception as e: - print(f"failed to bind to zeromq socket: {e}") - sys.exit(-1) - zmq_socket.setsockopt(zmq.SUBSCRIBE, topic) - return zmq_socket - - -def open_zmq_pub_socket(connect_to: str): - ctx = zmq.Context() - zmq_socket = ctx.socket(zmq.PUB) - try: - zmq_socket.connect(connect_to) - except Exception as e: - print(f"failed to bind to zeromq socket: {e}") - sys.exit(-1) - return zmq_socket - - -def get_zmq_xpub_socketstring(msb_config: dict) -> str: - zmq_config = msb_config["zeromq"] - return ( - f"{zmq_config['protocol']}://{zmq_config['address']}:{zmq_config['xpub_port']}" - ) - - -def get_zmq_xsub_socketstring(msb_config: dict) -> str: - zmq_config = msb_config["zeromq"] - return ( - f"{zmq_config['protocol']}://{zmq_config['address']}:{zmq_config['xsub_port']}" - ) diff --git a/heisskleber/network/influxdb/__init__.py b/heisskleber/core/__init__.py similarity index 100% rename from heisskleber/network/influxdb/__init__.py rename to heisskleber/core/__init__.py diff --git a/heisskleber/network/pubsub/factories.py b/heisskleber/core/factories.py similarity index 62% rename from heisskleber/network/pubsub/factories.py rename to heisskleber/core/factories.py index a7dfdf9..195ba5f 100644 --- a/heisskleber/network/pubsub/factories.py +++ b/heisskleber/core/factories.py @@ -1,24 +1,28 @@ import os -from heisskleber.config import load_config -from heisskleber.network.mqtt import MqttConf, MqttPublisher, MqttSubscriber -from heisskleber.network.serial import SerialConf, SerialPublisher, SerialSubscriber -from heisskleber.network.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber +from heisskleber.config import BaseConf, load_config +from heisskleber.core.types import Publisher, Subscriber +from heisskleber.mqtt import MqttConf, MqttPublisher, MqttSubscriber +from heisskleber.serial import SerialConf, SerialPublisher, SerialSubscriber +from heisskleber.udp import UdpConf, UdpPublisher, UdpSubscriber +from heisskleber.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber -_registered_publishers = { +_registered_publishers: dict[str, tuple[type[Publisher], type[BaseConf]]] = { "zmq": (ZmqPublisher, ZmqConf), "mqtt": (MqttPublisher, MqttConf), "serial": (SerialPublisher, SerialConf), + "udp": (UdpPublisher, UdpConf), } -_registered_subscribers = { +_registered_subscribers: dict[str, tuple[type[Subscriber], type[BaseConf]]] = { "zmq": (ZmqSubscriber, ZmqConf), "mqtt": (MqttSubscriber, MqttConf), "serial": (SerialSubscriber, SerialConf), + "udp": (UdpSubscriber, UdpConf), } -def get_publisher(name: str): +def get_publisher(name: str) -> Publisher: if name not in _registered_publishers: error_message = f"{name} is not a registered Publisher." raise KeyError(error_message) @@ -35,7 +39,7 @@ def get_publisher(name: str): return pub_cls(config) -def get_subscriber(name: str, topic): +def get_subscriber(name: str, topic: str | list[str]) -> Subscriber: if name not in _registered_publishers: error_message = f"{name} is not a registered Subscriber." raise KeyError(error_message) diff --git a/heisskleber/network/packer.py b/heisskleber/core/packer.py similarity index 58% rename from heisskleber/network/packer.py rename to heisskleber/core/packer.py index 1924ba7..fcf5a73 100644 --- a/heisskleber/network/packer.py +++ b/heisskleber/core/packer.py @@ -1,10 +1,12 @@ """Packer and unpacker for network data.""" import json import pickle -from typing import Callable +from typing import Any, Callable + +from .types import Serializable -def get_packer(style) -> Callable[[dict], str]: +def get_packer(style: str) -> Callable[[dict[str, Serializable]], str]: """Return a packer function for the given style. Packer func serializes a given dict.""" @@ -14,7 +16,7 @@ def get_packer(style) -> Callable[[dict], str]: return _packstyles["default"] -def get_unpacker(style) -> Callable[[str], dict]: +def get_unpacker(style: str) -> Callable[[str], dict[str, Serializable]]: """Return an unpacker function for the given style. Unpacker func deserializes a string.""" @@ -24,21 +26,19 @@ def get_unpacker(style) -> Callable[[str], dict]: return _unpackstyles["default"] -def serialpacker(data: dict) -> str: +def serialpacker(data: dict[str, Any]) -> str: return ",".join([str(v) for v in data.values()]) -_packstyles = { +_packstyles: dict[str, Callable[[dict[str, Serializable]], str]] = { "default": json.dumps, "json": json.dumps, - "pickle": pickle.dumps, + "pickle": pickle.dumps, # type: ignore "serial": serialpacker, - "raw": lambda x: x, } -_unpackstyles = { +_unpackstyles: dict[str, Callable[[str], dict[str, Serializable]]] = { "default": json.loads, "json": json.loads, - "pickle": pickle.loads, - "raw": lambda x: x, + "pickle": pickle.loads, # type: ignore } diff --git a/heisskleber/core/types.py b/heisskleber/core/types.py new file mode 100644 index 0000000..28c5883 --- /dev/null +++ b/heisskleber/core/types.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Callable, Union + +Serializable = Union[str, int, float] + + +class Publisher(ABC): + """ + Publisher interface. + """ + + pack: Callable[[dict[str, Serializable]], str] + + @abstractmethod + def __init__(self, config: Any) -> None: + """ + Initialize the publisher with a configuration object. + """ + pass + + @abstractmethod + def send(self, data: dict[str, Any], topic: str) -> None: + """ + Send data via the implemented output stream. + """ + pass + + +class Subscriber(ABC): + """ + Subscriber interface + """ + + unpack: Callable[[str], dict[str, Serializable]] + + @abstractmethod + def __init__(self, config: Any, topic: str | list[str]) -> None: + """ + Initialize the subscriber with a topic and a configuration object. + """ + pass + + @abstractmethod + def receive(self) -> tuple[str, dict[str, Serializable]]: + """ + Blocking function to receive data from the implemented input stream. + + Data is returned as a tuple of (topic, data). + """ + pass diff --git a/heisskleber/network/pubsub/__init__.py b/heisskleber/influxdb/__init__.py similarity index 100% rename from heisskleber/network/pubsub/__init__.py rename to heisskleber/influxdb/__init__.py diff --git a/heisskleber/network/influxdb/config.py b/heisskleber/influxdb/config.py similarity index 100% rename from heisskleber/network/influxdb/config.py rename to heisskleber/influxdb/config.py diff --git a/heisskleber/network/influxdb/subscriber.py b/heisskleber/influxdb/subscriber.py similarity index 89% rename from heisskleber/network/influxdb/subscriber.py rename to heisskleber/influxdb/subscriber.py index a49433a..748b39e 100644 --- a/heisskleber/network/influxdb/subscriber.py +++ b/heisskleber/influxdb/subscriber.py @@ -1,6 +1,8 @@ import pandas as pd from influxdb_client import InfluxDBClient +from heisskleber.core.types import Subscriber + from .config import InfluxDBConf @@ -28,11 +30,10 @@ def build_query(options: dict) -> str: return query -class Influx_Subscriber: +class Influx_Subscriber(Subscriber): def __init__(self, config: InfluxDBConf, query: str): self.config = config self.query = query - self.df: pd.DataFrame = None self.client: InfluxDBClient = InfluxDBClient( url=self.config.url, @@ -45,13 +46,15 @@ class Influx_Subscriber: self._run_query() self.index = 0 - def receive(self) -> dict: + def receive(self) -> tuple[str, dict]: row = self.df.iloc[self.index].to_dict() self.index += 1 return "influx", row def _run_query(self): - self.df = self.reader.query_data_frame(self.query, org=self.config.org) + self.df: pd.DataFrame = self.reader.query_data_frame( + self.query, org=self.config.org + ) self.df["epoch"] = pd.to_numeric(self.df["_time"]) / 1e9 self.df.drop( columns=[ diff --git a/heisskleber/network/influxdb/writer.py b/heisskleber/influxdb/writer.py similarity index 100% rename from heisskleber/network/influxdb/writer.py rename to heisskleber/influxdb/writer.py diff --git a/heisskleber/mqtt/__init__.py b/heisskleber/mqtt/__init__.py new file mode 100644 index 0000000..9e08600 --- /dev/null +++ b/heisskleber/mqtt/__init__.py @@ -0,0 +1,5 @@ +from .config import MqttConf +from .publisher import MqttPublisher +from .subscriber import MqttSubscriber + +__all__ = ["MqttConf", "MqttPublisher", "MqttSubscriber"] diff --git a/heisskleber/network/mqtt/config.py b/heisskleber/mqtt/config.py similarity index 89% rename from heisskleber/network/mqtt/config.py rename to heisskleber/mqtt/config.py index c716cb3..004d955 100644 --- a/heisskleber/network/mqtt/config.py +++ b/heisskleber/mqtt/config.py @@ -16,7 +16,7 @@ class MqttConf(BaseConf): ssl: bool = False qos: int = 0 retain: bool = False - topics: list[bytes] = field(default_factory=list) + topics: list[str] = field(default_factory=list) mapping: str = "/msb/" packstyle: str = "json" max_saved_messages: int = 100 diff --git a/heisskleber/network/mqtt/forwarder.py b/heisskleber/mqtt/forwarder.py similarity index 54% rename from heisskleber/network/mqtt/forwarder.py rename to heisskleber/mqtt/forwarder.py index e3a959a..370afd3 100644 --- a/heisskleber/network/mqtt/forwarder.py +++ b/heisskleber/mqtt/forwarder.py @@ -1,22 +1,23 @@ +from heisskleber import get_publisher, get_subscriber from heisskleber.config import load_config -from heisskleber.network import get_publisher, get_subscriber from .config import MqttConf -def map_topic(zmq_topic, mapping): - return mapping + zmq_topic.decode() +def map_topic(zmq_topic: str, mapping: str) -> str: + return mapping + zmq_topic -def main(): +def main() -> None: config: MqttConf = load_config(MqttConf(), "mqtt") sub = get_subscriber("zmq", config.topics) pub = get_publisher("mqtt") - sub.unpack = pub.pack = lambda x: x + pub.pack = lambda x: x # type: ignore + sub.unpack = lambda x: x # type: ignore while True: (zmq_topic, data) = sub.receive() mqtt_topic = map_topic(zmq_topic, config.mapping) - pub.send(mqtt_topic, data) + pub.send(data, mqtt_topic) diff --git a/heisskleber/network/mqtt/mqtt_base.py b/heisskleber/mqtt/mqtt_base.py similarity index 83% rename from heisskleber/network/mqtt/mqtt_base.py rename to heisskleber/mqtt/mqtt_base.py index ed9fb2a..c5d5dcc 100644 --- a/heisskleber/network/mqtt/mqtt_base.py +++ b/heisskleber/mqtt/mqtt_base.py @@ -25,19 +25,19 @@ def _set_thread_died_excepthook(args, /): threading.excepthook = _set_thread_died_excepthook -class MQTT_Base: +class MqttBase: """ Wrapper around eclipse paho mqtt client. Handles connection and callbacks. Callbacks may be overwritten in subclasses. """ - def __init__(self, config: MqttConf): + def __init__(self, config: MqttConf) -> None: self.config = config self.connect() self.client.loop_start() - def connect(self): + def connect(self) -> None: self.client = mqtt_client() self.client.username_pw_set(self.config.user, self.config.password) @@ -55,13 +55,13 @@ class MQTT_Base: self.client.connect(self.config.broker, self.config.port) @staticmethod - def _raise_if_thread_died(): + def _raise_if_thread_died() -> None: global _thread_died if _thread_died.is_set(): raise ThreadDiedError() # MQTT callbacks - def _on_connect(self, client, userdata, flags, return_code): + def _on_connect(self, client, userdata, flags, return_code) -> None: if return_code == 0: print(f"MQTT node connected to {self.config.broker}:{self.config.port}") else: @@ -69,21 +69,21 @@ class MQTT_Base: if self.config.verbose: print(flags) - def _on_disconnect(self, client, userdata, return_code): + def _on_disconnect(self, client, userdata, return_code) -> None: print(f"Disconnected from broker with return code {return_code}") if return_code != 0: print("Killing this service") sys.exit(-1) - def _on_publish(self, client, userdata, message_id): + def _on_publish(self, client, userdata, message_id) -> None: if self.config.verbose: print(f"Published message with id {message_id}, qos={self.config.qos}") - def _on_message(self, client, userdata, message): + def _on_message(self, client, userdata, message) -> None: if self.config.verbose: print( f"Received message: {message.payload!s}, topic: {message.topic}, qos: {message.qos}" ) - def __del__(self): + def __del__(self) -> None: self.client.loop_stop() diff --git a/heisskleber/mqtt/publisher.py b/heisskleber/mqtt/publisher.py new file mode 100644 index 0000000..3a533f3 --- /dev/null +++ b/heisskleber/mqtt/publisher.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Any + +from heisskleber.core.packer import get_packer +from heisskleber.core.types import Publisher + +from .config import MqttConf +from .mqtt_base import MqttBase + + +class MqttPublisher(MqttBase, Publisher): + """ + MQTT publisher class. + Can be used everywhere that a flucto style publishing connection is required. + + Network message loop is handled in a separated thread. + """ + + def __init__(self, config: MqttConf) -> None: + super().__init__(config) + self.pack = get_packer(config.packstyle) + + def send(self, data: dict[str, Any], topic: str) -> None: + """ + Takes python dictionary, serializes it according to the packstyle + and sends it to the broker. + + Publishing is asynchronous + """ + self._raise_if_thread_died() + + payload = self.pack(data) + self.client.publish( + topic, payload, qos=self.config.qos, retain=self.config.retain + ) diff --git a/heisskleber/mqtt/subscriber.py b/heisskleber/mqtt/subscriber.py new file mode 100644 index 0000000..88328ab --- /dev/null +++ b/heisskleber/mqtt/subscriber.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from queue import SimpleQueue +from typing import Any + +from paho.mqtt.client import MQTTMessage + +from heisskleber.core.packer import get_unpacker +from heisskleber.core.types import Subscriber + +from .config import MqttConf +from .mqtt_base import MqttBase + + +class MqttSubscriber(MqttBase, Subscriber): + """ + MQTT subscriber, wraps around ecplipse's paho mqtt client. + Network message loop is handled in a separated thread. + + Incoming messages are saved as a stack when not processed via the receive() function. + """ + + def __init__(self, config: MqttConf, topics: str | list[str]) -> None: + super().__init__(config) + self._message_queue: SimpleQueue[MQTTMessage] = SimpleQueue() + self.subscribe(topics) + self.client.on_message = self._on_message + self.unpack = get_unpacker(config.packstyle) + + def subscribe(self, topics: str | list[str] | tuple[str]) -> None: + """ + Subscribe to one or multiple topics + """ + if isinstance(topics, (list, tuple)): + # if subscribing to multiple topics, use a list of tuples + subscription_list = [(topic, self.config.qos) for topic in topics] + self.client.subscribe(subscription_list) + else: + self.client.subscribe(topics, self.config.qos) + if self.config.verbose: + print(f"Subscribed to: {topics}") + + def receive(self) -> tuple[str, dict[str, Any]]: + """ + Reads a message from mqtt and returns it + + Messages are saved in a stack, if no message is available, this function blocks. + + Returns: + tuple(topic: bytes, message: dict): the message received + """ + self._raise_if_thread_died() + mqtt_message = self._message_queue.get( + block=True, timeout=self.config.timeout_s + ) + + message_returned = self.unpack(mqtt_message.payload.decode()) + return (mqtt_message.topic, message_returned) + + # callback to add incoming messages onto stack + def _on_message(self, client, userdata, message) -> None: + self._message_queue.put(message) + + if self.config.verbose: + print(f"Topic: {message.topic}") + print(f"MQTT message: {message.payload.decode()}") diff --git a/heisskleber/network/__init__.py b/heisskleber/network/__init__.py deleted file mode 100644 index 31256a0..0000000 --- a/heisskleber/network/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pubsub.factories import get_publisher, get_subscriber # noqa: F401 diff --git a/heisskleber/network/mqtt/__init__.py b/heisskleber/network/mqtt/__init__.py deleted file mode 100644 index e71a954..0000000 --- a/heisskleber/network/mqtt/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .config import MqttConf # noqa: F401 -from .publisher import MqttPublisher # noqa: F401 -from .subscriber import MqttSubscriber # noqa: F401 diff --git a/heisskleber/network/mqtt/msb_mqtt.py b/heisskleber/network/mqtt/msb_mqtt.py deleted file mode 100644 index 344add1..0000000 --- a/heisskleber/network/mqtt/msb_mqtt.py +++ /dev/null @@ -1,18 +0,0 @@ -from heisskleber.config import load_config -from heisskleber.network import get_publisher, get_subscriber - -from .config import MqttConf -from .forwarder import ZMQ_to_MQTT_Forwarder - - -def main(): - config = load_config(MqttConf(), "mqtt") - for topic in config.topics: - print(f"Subscribing to {topic}") - - zmq_sub = get_subscriber("zmq", list(config.topics)) - mqtt_pub = get_publisher("mqtt") - forwarder = ZMQ_to_MQTT_Forwarder(config, subscriber=zmq_sub, publisher=mqtt_pub) - - # Wait for zmq messages, publish as mqtt message - forwarder.zmq_to_mqtt_loop() diff --git a/heisskleber/network/mqtt/publisher.py b/heisskleber/network/mqtt/publisher.py deleted file mode 100644 index 4acf4a9..0000000 --- a/heisskleber/network/mqtt/publisher.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations - -from heisskleber.config import load_config -from heisskleber.network.packer import get_packer -from heisskleber.network.pubsub.types import Publisher - -from .config import MqttConf -from .mqtt_base import MQTT_Base - - -class MqttPublisher(MQTT_Base, Publisher): - """ - MQTT publisher class. - Can be used everywhere that a flucto style publishing connection is required. - - Network message loop is handled in a separated thread. - """ - - def __init__(self, config: MqttConf): - super().__init__(config) - self.pack = get_packer(config.packstyle) - - def send(self, topic: str | bytes, data: dict): - """ - Takes python dictionary, serializes it according to the packstyle - and sends it to the broker. - - Publishing is asynchronous - """ - self._raise_if_thread_died() - if isinstance(topic, bytes): - topic = topic.decode() - - payload = self.pack(data) - self.client.publish( - topic, payload, qos=self.config.qos, retain=self.config.retain - ) - - -def get_mqtt_publisher() -> MqttPublisher: - """ - Generate mqtt publisher with configuration from yaml file, - falls back to default values if no config is found - """ - import os - - if "MSB_CONFIG_DIR" in os.environ: - print("loading mqtt config") - config = load_config(MqttConf(), "mqtt", read_commandline=False) - else: - print("using default mqtt config") - config = MqttConf() - return MqttPublisher(config) - - -def get_default_publisher() -> MqttPublisher: - """ - Generate mqtt publisher with configuration from yaml file, - falls back to default values if no config is found - - Deprecated, use get_mqtt_publisher() instead - """ - return get_mqtt_publisher() diff --git a/heisskleber/network/mqtt/subscriber.py b/heisskleber/network/mqtt/subscriber.py deleted file mode 100644 index e684bff..0000000 --- a/heisskleber/network/mqtt/subscriber.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import annotations - -from queue import SimpleQueue - -from heisskleber.config import load_config -from heisskleber.network.packer import get_unpacker -from heisskleber.network.pubsub.types import Subscriber - -from .config import MqttConf -from .mqtt_base import MQTT_Base - - -class MqttSubscriber(MQTT_Base, Subscriber): - """ - MQTT subscriber, wraps around ecplipse's paho mqtt client. - Network message loop is handled in a separated thread. - - Incoming messages are saved as a stack when not processed via the receive() function. - """ - - def __init__(self, topics, config: MqttConf): - super().__init__(config) - self._message_queue = SimpleQueue() - self.subscribe(topics) - self.client.on_message = self._on_message - self.unpack = get_unpacker(config.packstyle) - - def _subscribe_single_topic(self, topic: bytes | str): - if isinstance(topic, bytes): - topic = topic.decode() - if self.config.verbose: - print(f"Subscribed to: {topic}") - self.client.subscribe(topic, self.config.qos) - - def _subscribe_multiple_topics(self, topics: list[bytes] | list[str]): - topics = [ - topic.decode() if isinstance(topic, bytes) else topic for topic in topics - ] - subscription_list = [(topic, self.config.qos) for topic in topics] - if self.config.verbose: - print(f"Subscribed to: {topics}") - self.client.subscribe(subscription_list) - - def subscribe(self, topics): - """ - Subscribe to one or multiple topics - """ - # if subscribing to multiple topics, use a list of tuples - if isinstance(topics, (list, tuple)): - self._subscribe_multiple_topics(topics) - else: - self.client.subscribe(topics, self.config.qos) - - def receive(self) -> tuple[bytes, dict]: - """ - Reads a message from mqtt and returns it - - Messages are saved in a stack, if no message is available, this function blocks. - - Returns: - tuple(topic: bytes, message: dict): the message received - """ - self._raise_if_thread_died() - mqtt_message = self._message_queue.get( - block=True, timeout=self.config.timeout_s - ) - - topic = mqtt_message.topic.encode("utf-8") - message_returned = self.unpack(mqtt_message.payload.decode()) - return (topic, message_returned) - - # callback to add incoming messages onto stack - def _on_message(self, client, userdata, message): - self._message_queue.put(message) - - if self.config.verbose: - print(f"Topic: {message.topic}") - print(f"MQTT message: {message.payload.decode()}") - - -def get_mqtt_subscriber(topic: bytes | str) -> MqttSubscriber: - """ - Generate mqtt subscriber with configuration from yaml file, - falls back to default values if no config is found - """ - import os - - if "MSB_CONFIG_DIR" in os.environ: - print("loading mqtt config") - config = load_config(MqttConf(), "mqtt", read_commandline=False) - else: - print("using default mqtt config") - config = MqttConf() - return MqttSubscriber(topic, config) - - -def get_default_subscriber(topic: bytes | str) -> MqttSubscriber: - """ - Generate mqtt subscriber with configuration from yaml file, - falls back to default values if no config is found - - Deprecated, use get_mqtt_subscriber(topic) instead. - """ - return get_mqtt_subscriber(topic) diff --git a/heisskleber/network/pubsub/types.py b/heisskleber/network/pubsub/types.py deleted file mode 100644 index 8fe5184..0000000 --- a/heisskleber/network/pubsub/types.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod - - -class Publisher(ABC): - """ - Publisher interface. - """ - - @abstractmethod - def send(self, topic: str | bytes, data: dict): - """ - Send data via the implemented output stream. - """ - pass - - -class Subscriber(ABC): - """ - Subscriber interface - """ - - @abstractmethod - def receive(self) -> tuple[bytes, dict]: - """ - Blocking function to receive data from the implemented input stream. - - Data is returned as a tuple of (topic, data). - """ - pass diff --git a/heisskleber/network/serial/__init__.py b/heisskleber/network/serial/__init__.py deleted file mode 100644 index 12cd737..0000000 --- a/heisskleber/network/serial/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .config import SerialConf # noqa: F401 -from .publisher import SerialPublisher # noqa: F401 -from .subscriber import SerialSubscriber # noqa: F401 diff --git a/heisskleber/network/types.py b/heisskleber/network/types.py deleted file mode 100644 index d701017..0000000 --- a/heisskleber/network/types.py +++ /dev/null @@ -1 +0,0 @@ -from .pubsub.types import Publisher, Subscriber # noqa: F401 diff --git a/heisskleber/network/udp/__init__.py b/heisskleber/network/udp/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/heisskleber/network/udp/publisher.py b/heisskleber/network/udp/publisher.py deleted file mode 100644 index 5499a69..0000000 --- a/heisskleber/network/udp/publisher.py +++ /dev/null @@ -1,49 +0,0 @@ -import socket - -from heisskleber.network.packer import get_packer -from heisskleber.network.pubsub.types import Publisher -from heisskleber.network.udp.config import UDPConf - - -class UDP_Publisher(Publisher): - def __init__(self, config): - self.config = config - self.ip = self.config.ip - self.port = self.config.port - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.packer = get_packer(self.config.packer) - - def send(self, topic, message): - payload = self.packer(message) - payload = payload.encode("utf-8") - self.socket.sendto(payload, (self.ip, self.port)) - - def __del__(self): - self.socket.close() - - -def udp_sender(): - target_ip = "127.0.0.1" # Replace this with the receiver's IP address - target_port = 12345 # Replace this with the receiver's port number - - message = "Hello, UDP Receiver!" - - # Create a UDP socket - udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - - try: - # Send the message to the receiver - udp_socket.sendto(message.encode("utf-8"), (target_ip, target_port)) - print("Message sent successfully!") - except Exception as e: - print("Error occurred while sending the message:", str(e)) - finally: - udp_socket.close() - - -if __name__ == "__main__": - conf = UDPConf(ip="192.168.1.122", port=12345) - pub = UDP_Publisher(conf) - - pub.send("test", {"test": "test"}) - # pub.send("test", "Hi from pub") diff --git a/heisskleber/network/udp/subscriber.py b/heisskleber/network/udp/subscriber.py deleted file mode 100644 index ffaab08..0000000 --- a/heisskleber/network/udp/subscriber.py +++ /dev/null @@ -1,34 +0,0 @@ -import socket - -from heisskleber.network.packer import get_unpacker -from heisskleber.network.pubsub.types import Subscriber -from heisskleber.network.udp.config import UDPConf - - -class UDP_Subscriber(Subscriber): - def __init__(self, config, topic=None): - self.config = config - self.ip = self.config.ip - self.port = self.config.port - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.socket.bind((self.ip, self.port)) - self.unpacker = get_unpacker(self.config.packer) - - def receive(self): - payload, addr = self.socket.recvfrom(1024) - return addr, self.unpacker(payload.decode("utf-8")) - - def listen_loop(self): - while True: - addr, data = self.receive() - print(type(data)) - print(data) - - def __del__(self): - self.socket.close() - - -if __name__ == "__main__": - conf = UDPConf(ip="192.168.1.122", port=12345) - sub = UDP_Subscriber(conf) - sub.listen_loop() diff --git a/heisskleber/network/zmq/__init__.py b/heisskleber/network/zmq/__init__.py deleted file mode 100644 index e66b53b..0000000 --- a/heisskleber/network/zmq/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .config import ZmqConf # noqa: F401 -from .publisher import ZmqPublisher # noqa: F401 -from .subscriber import ZmqSubscriber # noqa: F401 diff --git a/heisskleber/serial/__init__.py b/heisskleber/serial/__init__.py new file mode 100644 index 0000000..7921657 --- /dev/null +++ b/heisskleber/serial/__init__.py @@ -0,0 +1,5 @@ +from .config import SerialConf +from .publisher import SerialPublisher +from .subscriber import SerialSubscriber + +__all__ = ["SerialConf", "SerialPublisher", "SerialSubscriber"] diff --git a/heisskleber/network/serial/config.py b/heisskleber/serial/config.py similarity index 100% rename from heisskleber/network/serial/config.py rename to heisskleber/serial/config.py diff --git a/heisskleber/network/serial/forwarder.py b/heisskleber/serial/forwarder.py similarity index 62% rename from heisskleber/network/serial/forwarder.py rename to heisskleber/serial/forwarder.py index 098b661..0f385e0 100644 --- a/heisskleber/network/serial/forwarder.py +++ b/heisskleber/serial/forwarder.py @@ -1,10 +1,10 @@ -from heisskleber.network.types import Subscriber +from heisskleber.core.types import Subscriber from .publisher import SerialPublisher class SerialForwarder: - def __init__(self, subscriber: Subscriber, publisher: SerialPublisher): + def __init__(self, subscriber: Subscriber, publisher: SerialPublisher) -> None: self.sub = subscriber self.pub = publisher @@ -12,19 +12,20 @@ class SerialForwarder: Wait for message and forward """ - def forward_message(self): + def forward_message(self) -> None: # collected = {} # for sub in self.sub: # topic, data = sub.receive() # collected.update(data) - _, collected = self.sub.receive() + topic, data = self.sub.receive() - self.pub.send(collected) + # We send the topic and let the publisher decide what to do with it + self.pub.send(data, topic) """ Enter loop and continuously forward messages """ - def sub_pub_loop(self): + def sub_pub_loop(self) -> None: while True: self.forward_message() diff --git a/heisskleber/network/serial/publisher.py b/heisskleber/serial/publisher.py similarity index 70% rename from heisskleber/network/serial/publisher.py rename to heisskleber/serial/publisher.py index 7f67138..16b058f 100644 --- a/heisskleber/network/serial/publisher.py +++ b/heisskleber/serial/publisher.py @@ -1,11 +1,11 @@ from __future__ import annotations -from types import FunctionType +from typing import Callable, Optional import serial -from heisskleber.network.packer import get_packer -from heisskleber.network.pubsub.types import Publisher +from heisskleber.core.packer import get_packer +from heisskleber.core.types import Publisher, Serializable from .config import SerialConf @@ -23,12 +23,16 @@ class SerialPublisher(Publisher): Function to translate from a dict to a serialized string. """ - def __init__(self, config: SerialConf, pack_func: FunctionType | None = None): + def __init__( + self, + config: SerialConf, + pack_func: Optional[Callable] = None, # noqa: UP007 + ): self.config = config - self.packer = pack_func if pack_func else get_packer("serial") + self.pack = pack_func if pack_func else get_packer("serial") self._connect() - def _connect(self): + def _connect(self) -> None: self.serial: serial.Serial = serial.Serial( port=self.config.port, baudrate=self.config.baudrate, @@ -38,25 +42,23 @@ class SerialPublisher(Publisher): ) print(f"Successfully connected to serial device at port {self.config.port}") - def send(self, message: object): + def send(self, message: dict[str, Serializable], topic: str) -> None: """ Takes python dictionary, serializes it according to the packstyle and sends it to the broker. - Please note that this does not adhere to the interface, as there is no topic. - Parameters ---------- - message : object + message : dict object to be serialized and sent via the serial connection. Usually a dict. """ - payload = self.packer(message) + payload = self.pack(message) self.serial.write(payload.encode(self.config.encoding)) self.serial.flush() if self.config.verbose: - print(payload) + print(f"{topic}: {payload}") - def __del__(self): + def __del__(self) -> None: if not hasattr(self, "serial"): return if not self.serial.is_open: diff --git a/heisskleber/network/serial/subscriber.py b/heisskleber/serial/subscriber.py similarity index 70% rename from heisskleber/network/serial/subscriber.py rename to heisskleber/serial/subscriber.py index d6c203f..c441356 100644 --- a/heisskleber/network/serial/subscriber.py +++ b/heisskleber/serial/subscriber.py @@ -1,10 +1,11 @@ from __future__ import annotations +from collections.abc import Generator from typing import Callable, Optional import serial -from heisskleber.network.pubsub.types import Subscriber +from heisskleber.core.types import Subscriber from .config import SerialConf @@ -27,12 +28,13 @@ class SerialSubscriber(Subscriber): def __init__( self, - topics, config: SerialConf, - unpack_func: Optional[Callable] = None, # noqa: UP007 + topic: str | None = None, + custom_unpack: Optional[Callable] = None, # noqa: UP007 ): self.config = config - self.unpack = unpack_func if unpack_func else lambda x: x + self.topic = topic + self.unpack = custom_unpack if custom_unpack else lambda x: x # types: ignore self._connect() def _connect(self): @@ -45,7 +47,7 @@ class SerialSubscriber(Subscriber): ) print(f"Successfully connected to serial device at port {self.config.port}") - def receive(self) -> dict: + def receive(self) -> tuple[str, dict]: """ Wait for data to arrive on the serial port and return it. @@ -62,29 +64,30 @@ class SerialSubscriber(Subscriber): # port is a placeholder for topic return self.config.port, payload - def read_serial_port(self) -> str: + def read_serial_port(self) -> Generator[str, None, None]: + """ + Generator function reading from the serial port. + + Returns + ------- + :return: Generator[str, None, None] + Generator yielding strings read from the serial port + """ buffer = "" while True: try: - buffer = self.serial.readline().decode() + buffer = self.serial.readline().decode(self.config.encoding, "ignore") yield buffer except UnicodeError as e: if self.config.verbose: - print(f"Could not decode: {message}") + print(f"Could not decode: {buffer!r}") print(e) continue - def __del__(self): + def __del__(self) -> None: if not hasattr(self, "serial"): return if not self.serial.is_open: return self.serial.flush() self.serial.close() - - -if __name__ == "__main__": - config = SerialConf() - serial_reader = SerialSubscriber(config) - for message in serial_reader.receive(): - print(message) diff --git a/heisskleber/udp/__init__.py b/heisskleber/udp/__init__.py new file mode 100644 index 0000000..74c187d --- /dev/null +++ b/heisskleber/udp/__init__.py @@ -0,0 +1,5 @@ +from .config import UdpConf +from .publisher import UdpPublisher +from .subscriber import UdpSubscriber + +__all__ = ["UdpSubscriber", "UdpPublisher", "UdpConf"] diff --git a/heisskleber/network/udp/config.py b/heisskleber/udp/config.py similarity index 88% rename from heisskleber/network/udp/config.py rename to heisskleber/udp/config.py index 51c9ad9..f3e7210 100644 --- a/heisskleber/network/udp/config.py +++ b/heisskleber/udp/config.py @@ -4,7 +4,7 @@ from heisskleber.config import BaseConf @dataclass -class UDPConf(BaseConf): +class UdpConf(BaseConf): """ UDP configuration. """ diff --git a/heisskleber/udp/publisher.py b/heisskleber/udp/publisher.py new file mode 100644 index 0000000..161c766 --- /dev/null +++ b/heisskleber/udp/publisher.py @@ -0,0 +1,22 @@ +import socket + +from heisskleber.core.packer import get_packer +from heisskleber.core.types import Publisher, Serializable +from heisskleber.udp.config import UdpConf + + +class UdpPublisher(Publisher): + def __init__(self, config: UdpConf) -> None: + self.config = config + self.ip = self.config.ip + self.port = self.config.port + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.pack = get_packer(self.config.packer) + + def send(self, message: dict[str, Serializable], topic: str) -> None: + message["topic"] = topic + payload = self.pack(message).encode("utf-8") + self.socket.sendto(payload, (self.ip, self.port)) + + def __del__(self) -> None: + self.socket.close() diff --git a/heisskleber/udp/subscriber.py b/heisskleber/udp/subscriber.py new file mode 100644 index 0000000..7a8a60f --- /dev/null +++ b/heisskleber/udp/subscriber.py @@ -0,0 +1,47 @@ +import socket +import threading +from queue import SimpleQueue + +from heisskleber.core.packer import get_unpacker +from heisskleber.core.types import Serializable, Subscriber +from heisskleber.udp.config import UdpConf + + +class UdpSubscriber(Subscriber): + def __init__(self, config: UdpConf, topic: str | None = None): + self.config = config + self.topic = topic + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.bind((self.config.ip, self.config.port)) + self.unpacker = get_unpacker(self.config.packer) + self._queue: SimpleQueue[tuple[str, dict[str, Serializable]]] = SimpleQueue() + self._running = threading.Event() + self._running.set() + self._thread: threading.Thread | None = None + + def receive(self) -> tuple[str, dict[str, Serializable]]: + return self._queue.get() + + def _loop(self) -> None: + while self._running.is_set(): + try: + payload, _ = self.socket.recvfrom(1024) + data = self.unpacker(payload.decode("utf-8")) + topic: str = str(data.pop("topic")) if "topic" in data else "" + self._queue.put((topic, data)) + except Exception as e: + error_message = f"Error in UDP listener loop: {e}" + print(error_message) + + def start_loop(self) -> None: + self._thread = threading.Thread(target=self._loop, daemon=True) + self._thread.start() + + def stop_loop(self) -> None: + self._running.clear() + if self._thread is not None: + self._thread.join() + self.socket.close() + + def __del__(self) -> None: + self.stop_loop() diff --git a/heisskleber/zmq/__init__.py b/heisskleber/zmq/__init__.py new file mode 100644 index 0000000..bbdfd1e --- /dev/null +++ b/heisskleber/zmq/__init__.py @@ -0,0 +1,5 @@ +from .config import ZmqConf +from .publisher import ZmqPublisher +from .subscriber import ZmqSubscriber + +__all__ = ["ZmqConf", "ZmqPublisher", "ZmqSubscriber"] diff --git a/heisskleber/network/zmq/config.py b/heisskleber/zmq/config.py similarity index 84% rename from heisskleber/network/zmq/config.py rename to heisskleber/zmq/config.py index 2243d2d..77a46bf 100644 --- a/heisskleber/network/zmq/config.py +++ b/heisskleber/zmq/config.py @@ -12,9 +12,9 @@ class ZmqConf(BaseConf): packstyle: str = "json" @property - def publisher_address(self): + def publisher_address(self) -> str: return f"{self.protocol}://{self.interface}:{self.publisher_port}" @property - def subscriber_address(self): + def subscriber_address(self) -> str: return f"{self.protocol}://{self.interface}:{self.subscriber_port}" diff --git a/heisskleber/network/zmq/publisher.py b/heisskleber/zmq/publisher.py similarity index 68% rename from heisskleber/network/zmq/publisher.py rename to heisskleber/zmq/publisher.py index 3d50841..164a4fd 100644 --- a/heisskleber/network/zmq/publisher.py +++ b/heisskleber/zmq/publisher.py @@ -2,8 +2,8 @@ import sys import zmq -from heisskleber.network.packer import get_packer -from heisskleber.network.pubsub.types import Publisher +from heisskleber.core.packer import get_packer +from heisskleber.core.types import Publisher, Serializable from .config import ZmqConf @@ -25,9 +25,9 @@ class ZmqPublisher(Publisher): print(f"failed to bind to zeromq socket: {e}") sys.exit(-1) - def send(self, topic: str, data: dict) -> None: - data = self.pack(data) - self.socket.send_multipart([topic, data.encode("utf-8")]) + def send(self, data: dict[str, Serializable], topic: str) -> None: + payload = self.pack(data) + self.socket.send_multipart([topic.encode(), payload.encode()]) def __del__(self): self.socket.close() diff --git a/heisskleber/network/zmq/subscriber.py b/heisskleber/zmq/subscriber.py similarity index 61% rename from heisskleber/network/zmq/subscriber.py rename to heisskleber/zmq/subscriber.py index 245a9cb..47cec13 100644 --- a/heisskleber/network/zmq/subscriber.py +++ b/heisskleber/zmq/subscriber.py @@ -4,14 +4,14 @@ import sys import zmq -from heisskleber.network.packer import get_unpacker -from heisskleber.network.pubsub.types import Subscriber +from heisskleber.core.packer import get_unpacker +from heisskleber.core.types import Subscriber from .config import ZmqConf class ZmqSubscriber(Subscriber): - def __init__(self, topic: bytes | str | list[bytes] | list[str], config: ZmqConf): + def __init__(self, config: ZmqConf, topic: str): self.config = config self.context = zmq.Context.instance() @@ -29,12 +29,10 @@ class ZmqSubscriber(Subscriber): print(f"failed to bind to zeromq socket: {e}") sys.exit(-1) - def _subscribe_single_topic(self, topic: bytes | str): - if isinstance(topic, str): - topic = topic.encode() - self.socket.setsockopt(zmq.SUBSCRIBE, topic) + def _subscribe_single_topic(self, topic: str): + self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode()) - def subscribe(self, topic: bytes | str | list[bytes] | list[str]): + def subscribe(self, topic: str | list[str] | tuple[str]): # Accepts single topic or list of topics if isinstance(topic, (list, tuple)): for t in topic: @@ -42,15 +40,16 @@ class ZmqSubscriber(Subscriber): else: self._subscribe_single_topic(topic) - def receive(self) -> tuple[bytes, dict]: + def receive(self) -> tuple[str, dict]: """ reads a message from the zmq bus and returns it Returns: - tuple(topic: bytes, message: dict): the message received + tuple(topic: str, message: dict): the message received """ - (topic, message) = self.socket.recv_multipart() - message = self.unpack(message.decode()) + (topic, payload) = self.socket.recv_multipart() + message = self.unpack(payload.decode()) + topic = topic.decode() return (topic, message) def __del__(self): diff --git a/poetry.lock b/poetry.lock index 82ea010..9eaa0c5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -772,6 +772,47 @@ files = [ [package.dependencies] setuptools = "*" +[[package]] +name = "numpy" +version = "1.26.1" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = "<3.13,>=3.9" +files = [ + {file = "numpy-1.26.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82e871307a6331b5f09efda3c22e03c095d957f04bf6bc1804f30048d0e5e7af"}, + {file = "numpy-1.26.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cdd9ec98f0063d93baeb01aad472a1a0840dee302842a2746a7a8e92968f9575"}, + {file = "numpy-1.26.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d78f269e0c4fd365fc2992c00353e4530d274ba68f15e968d8bc3c69ce5f5244"}, + {file = "numpy-1.26.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ab9163ca8aeb7fd32fe93866490654d2f7dda4e61bc6297bf72ce07fdc02f67"}, + {file = "numpy-1.26.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:78ca54b2f9daffa5f323f34cdf21e1d9779a54073f0018a3094ab907938331a2"}, + {file = "numpy-1.26.1-cp310-cp310-win32.whl", hash = "sha256:d1cfc92db6af1fd37a7bb58e55c8383b4aa1ba23d012bdbba26b4bcca45ac297"}, + {file = "numpy-1.26.1-cp310-cp310-win_amd64.whl", hash = "sha256:d2984cb6caaf05294b8466966627e80bf6c7afd273279077679cb010acb0e5ab"}, + {file = "numpy-1.26.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cd7837b2b734ca72959a1caf3309457a318c934abef7a43a14bb984e574bbb9a"}, + {file = "numpy-1.26.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1c59c046c31a43310ad0199d6299e59f57a289e22f0f36951ced1c9eac3665b9"}, + {file = "numpy-1.26.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d58e8c51a7cf43090d124d5073bc29ab2755822181fcad978b12e144e5e5a4b3"}, + {file = "numpy-1.26.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6081aed64714a18c72b168a9276095ef9155dd7888b9e74b5987808f0dd0a974"}, + {file = "numpy-1.26.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:97e5d6a9f0702c2863aaabf19f0d1b6c2628fbe476438ce0b5ce06e83085064c"}, + {file = "numpy-1.26.1-cp311-cp311-win32.whl", hash = "sha256:b9d45d1dbb9de84894cc50efece5b09939752a2d75aab3a8b0cef6f3a35ecd6b"}, + {file = "numpy-1.26.1-cp311-cp311-win_amd64.whl", hash = "sha256:3649d566e2fc067597125428db15d60eb42a4e0897fc48d28cb75dc2e0454e53"}, + {file = "numpy-1.26.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1d1bd82d539607951cac963388534da3b7ea0e18b149a53cf883d8f699178c0f"}, + {file = "numpy-1.26.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:afd5ced4e5a96dac6725daeb5242a35494243f2239244fad10a90ce58b071d24"}, + {file = "numpy-1.26.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a03fb25610ef560a6201ff06df4f8105292ba56e7cdd196ea350d123fc32e24e"}, + {file = "numpy-1.26.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcfaf015b79d1f9f9c9fd0731a907407dc3e45769262d657d754c3a028586124"}, + {file = "numpy-1.26.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e509cbc488c735b43b5ffea175235cec24bbc57b227ef1acc691725beb230d1c"}, + {file = "numpy-1.26.1-cp312-cp312-win32.whl", hash = "sha256:af22f3d8e228d84d1c0c44c1fbdeb80f97a15a0abe4f080960393a00db733b66"}, + {file = "numpy-1.26.1-cp312-cp312-win_amd64.whl", hash = "sha256:9f42284ebf91bdf32fafac29d29d4c07e5e9d1af862ea73686581773ef9e73a7"}, + {file = "numpy-1.26.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bb894accfd16b867d8643fc2ba6c8617c78ba2828051e9a69511644ce86ce83e"}, + {file = "numpy-1.26.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e44ccb93f30c75dfc0c3aa3ce38f33486a75ec9abadabd4e59f114994a9c4617"}, + {file = "numpy-1.26.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9696aa2e35cc41e398a6d42d147cf326f8f9d81befcb399bc1ed7ffea339b64e"}, + {file = "numpy-1.26.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5b411040beead47a228bde3b2241100454a6abde9df139ed087bd73fc0a4908"}, + {file = "numpy-1.26.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1e11668d6f756ca5ef534b5be8653d16c5352cbb210a5c2a79ff288e937010d5"}, + {file = "numpy-1.26.1-cp39-cp39-win32.whl", hash = "sha256:d1d2c6b7dd618c41e202c59c1413ef9b2c8e8a15f5039e344af64195459e3104"}, + {file = "numpy-1.26.1-cp39-cp39-win_amd64.whl", hash = "sha256:59227c981d43425ca5e5c01094d59eb14e8772ce6975d4b2fc1e106a833d5ae2"}, + {file = "numpy-1.26.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:06934e1a22c54636a059215d6da99e23286424f316fddd979f5071093b648668"}, + {file = "numpy-1.26.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76ff661a867d9272cd2a99eed002470f46dbe0943a5ffd140f49be84f68ffc42"}, + {file = "numpy-1.26.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:6965888d65d2848e8768824ca8288db0a81263c1efccec881cb35a0d805fcd2f"}, + {file = "numpy-1.26.1.tar.gz", hash = "sha256:c8c6c72d4a9f831f328efb1312642a1cafafaa88981d9ab76368d50d07d93cbe"}, +] + [[package]] name = "packaging" version = "23.2" @@ -796,6 +837,21 @@ files = [ [package.extras] proxy = ["PySocks"] +[[package]] +name = "pandas-stubs" +version = "2.1.1.230928" +description = "Type annotations for pandas" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pandas_stubs-2.1.1.230928-py3-none-any.whl", hash = "sha256:992d97159e054ca3175ebe8321ac5616cf6502dd8218b03bb0eaf3c4f6939037"}, + {file = "pandas_stubs-2.1.1.230928.tar.gz", hash = "sha256:ce1691c71c5d67b8f332da87763f7f54650f46895d99964d588c3a5d79e2cacc"}, +] + +[package.dependencies] +numpy = {version = ">=1.26.0", markers = "python_version < \"3.13\""} +types-pytz = ">=2022.1.1" + [[package]] name = "pathspec" version = "0.11.2" @@ -1628,6 +1684,39 @@ typing-extensions = {version = ">=4.7.0", markers = "python_version < \"3.12\""} doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"] test = ["coverage[toml] (>=7)", "mypy (>=1.2.0)", "pytest (>=7)"] +[[package]] +name = "types-paho-mqtt" +version = "1.6.0.7" +description = "Typing stubs for paho-mqtt" +optional = false +python-versions = "*" +files = [ + {file = "types-paho-mqtt-1.6.0.7.tar.gz", hash = "sha256:fe34c68abc849cd96e1482138bbdf5f465de59629dd367cb3a2423dd9ca3220b"}, + {file = "types_paho_mqtt-1.6.0.7-py3-none-any.whl", hash = "sha256:50313d93f63d777da391acaac0278d346cf9e4a2576d814989d6500bd0ca4a35"}, +] + +[[package]] +name = "types-pytz" +version = "2023.3.1.1" +description = "Typing stubs for pytz" +optional = false +python-versions = "*" +files = [ + {file = "types-pytz-2023.3.1.1.tar.gz", hash = "sha256:cc23d0192cd49c8f6bba44ee0c81e4586a8f30204970fc0894d209a6b08dab9a"}, + {file = "types_pytz-2023.3.1.1-py3-none-any.whl", hash = "sha256:1999a123a3dc0e39a2ef6d19f3f8584211de9e6a77fe7a0259f04a524e90a5cf"}, +] + +[[package]] +name = "types-pyyaml" +version = "6.0.12.12" +description = "Typing stubs for PyYAML" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, +] + [[package]] name = "typing-extensions" version = "4.8.0" @@ -1723,4 +1812,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "db06eb442b1bc7543792d552a222eb76bc2a8eafbda5dea863e57b2fae7e3857" +content-hash = "eb180d289a329d0441fa35ef046cc807294f7ff8bdba0080b7a6f95f02d41d93" diff --git a/pyproject.toml b/pyproject.toml index 7bb5834..ba11b84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "heisskleber" -version = "0.1.1" +version = "0.2.0" description = "Heisskleber" authors = ["Felix Weiler "] license = "MIT" @@ -39,6 +39,12 @@ typeguard = ">=2.13.3" xdoctest = { extras = ["colors"], version = ">=0.15.10" } myst-parser = { version = ">=0.16.1" } + +[tool.poetry.group.types.dependencies] +pandas-stubs = "^2.1.1.230928" +types-pyyaml = "^6.0.12.12" +types-paho-mqtt = "^1.6.0.7" + [tool.poetry.scripts] heisskleber = "heisskleber.__main__:main" @@ -60,6 +66,7 @@ warn_unreachable = true pretty = true show_column_numbers = true show_error_context = true +exclude = ["tests/*", "^test_*\\.py"] [tool.ruff] ignore-init-module-imports = true @@ -85,9 +92,10 @@ select = [ "TRY", # tryceratops ] ignore = [ - "E501", # LineTooLong - "E731", # DoNotAssignLambda - "A001", # + "E501", # LineTooLong + "E731", # DoNotAssignLambda + "A001", # + "PGH003", # Use specific rules when ignoring type issues ] [tool.ruff.per-file-ignores] diff --git a/tests/test_import.py b/tests/test_import.py index c440a99..438fdf2 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -1,31 +1,26 @@ def test_import_mqtt(): - from heisskleber.network.mqtt import ( - MqttConf, - MqttPublisher, - MqttSubscriber, - ) + import heisskleber + from heisskleber.mqtt import MqttConf, MqttPublisher, MqttSubscriber + + assert heisskleber.__all__ == [ + "get_publisher", + "get_subscriber", + "Publisher", + "Subscriber", + ] def test_import_zmq(): - from heisskleber.network.zmq import ( - ZmqConf, - ZmqPublisher, - ZmqSubscriber, - ) + from heisskleber.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber def test_import_serial(): - from heisskleber.network.serial import ( - SerialConf, - SerialPublisher, - SerialSubscriber, - ) + from heisskleber.serial import SerialConf, SerialPublisher, SerialSubscriber def test_import_utils(): - from heisskleber.network import get_publisher, get_subscriber - from heisskleber.network.types import Publisher, Subscriber + from heisskleber import Publisher, Subscriber, get_publisher, get_subscriber def test_import_config(): - pass + from heisskleber.config import BaseConf, load_config diff --git a/tests/test_mqtt.py b/tests/test_mqtt.py new file mode 100644 index 0000000..006ee04 --- /dev/null +++ b/tests/test_mqtt.py @@ -0,0 +1,162 @@ +import json +from queue import SimpleQueue +from unittest.mock import call, patch + +import pytest +from paho.mqtt.client import MQTTMessage + +from heisskleber.mqtt.config import MqttConf +from heisskleber.mqtt.mqtt_base import MqttBase +from heisskleber.mqtt.subscriber import MqttSubscriber + + +# Mock configuration for MQTT_Base +@pytest.fixture +def mock_mqtt_conf() -> MqttConf: + return MqttConf( + broker="localhost", + port=1883, + user="user", + password="passwd", # noqa: S106, this is a test password + ssl=False, + verbose=False, + qos=1, + ) + + +# Mock the paho mqtt client +@pytest.fixture +def mock_mqtt_client(): + with patch("heisskleber.mqtt.mqtt_base.mqtt_client", autospec=True) as mock: + yield mock + + +@pytest.fixture +def mock_queue(): + with patch("heisskleber.mqtt.subscriber.SimpleQueue", spec=SimpleQueue) as mock: + yield mock + + +def test_mqtt_base_intialization(mock_mqtt_client, mock_mqtt_conf): + """Test that the intialization of the mqtt client is as expected.""" + base = MqttBase(config=mock_mqtt_conf) + + mock_mqtt_client.assert_called_once() + mock_mqtt_client.return_value.loop_start.assert_called_once() + mock_client_instance = mock_mqtt_client.return_value + mock_client_instance.username_pw_set.assert_called_with( + mock_mqtt_conf.user, mock_mqtt_conf.password + ) + mock_client_instance.connect.assert_called_with( + mock_mqtt_conf.broker, mock_mqtt_conf.port + ) + assert base.client.on_connect == base._on_connect + assert base.client.on_disconnect == base._on_disconnect + assert base.client.on_publish == base._on_publish + assert base.client.on_message == base._on_message + + +def test_mqtt_base_on_connect(mock_mqtt_client, mock_mqtt_conf, capsys): + base = MqttBase(config=mock_mqtt_conf) + base._on_connect(None, None, {}, 0) + captured = capsys.readouterr() + assert ( + f"MQTT node connected to {mock_mqtt_conf.broker}:{mock_mqtt_conf.port}" + in captured.out + ) + + +def test_mqtt_base_on_disconnect_with_error(mock_mqtt_client, mock_mqtt_conf, capsys): + """Assert that the mqtt client shuts down when disconnect callback is received.""" + base = MqttBase(config=mock_mqtt_conf) + with pytest.raises(SystemExit): + base._on_disconnect(None, None, 1) + captured = capsys.readouterr() + assert "Killing this service" in captured.out + print(captured.out) + + +def test_mqtt_subscribes_single_topic(mock_mqtt_client, mock_mqtt_conf): + """Test that the mqtt client subscribes to a single topic.""" + _ = MqttSubscriber(topics="singleTopic", config=mock_mqtt_conf) + + actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list + assert actual_calls == [call("singleTopic", mock_mqtt_conf.qos)] + + +def test_mqtt_subscribes_multiple_topics(mock_mqtt_client, mock_mqtt_conf): + """Test that the mqtt client subscribes to multiple topics passed as list. + + I would love to do this via parametrization, but the call argument is built differently for single size lists and longer lists. + """ + _ = MqttSubscriber(topics=["multiple1", "multiple2"], config=mock_mqtt_conf) + + actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list + assert actual_calls == [ + call([("multiple1", mock_mqtt_conf.qos), ("multiple2", mock_mqtt_conf.qos)]), + ] + + +def test_mqtt_subscribes_multiple_topics_tuple(mock_mqtt_client, mock_mqtt_conf): + """Test that the mqtt client subscribes to multiple topics passed as tuple.""" + _ = MqttSubscriber(topics=("multiple1", "multiple2"), config=mock_mqtt_conf) + + actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list + assert actual_calls == [ + call([("multiple1", mock_mqtt_conf.qos), ("multiple2", mock_mqtt_conf.qos)]), + ] + + +def create_fake_mqtt_message(topic: bytes, payload: bytes) -> MQTTMessage: + msg = MQTTMessage() + msg.topic = topic + msg.payload = payload + return msg + + +def test_receive_with_message(mock_mqtt_conf: MqttConf, mock_mqtt_client, mock_queue): + """Test the mqtt receive function with fake MQTT messages.""" + topic = b"test/topic" + payload = json.dumps({"key": "value"}).encode() + fake_message = create_fake_mqtt_message(topic, payload) + + mock_queue.return_value.get.side_effect = [fake_message] + subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf) + + received_topic, received_payload = subscriber.receive() + + assert received_topic == "test/topic" + assert received_payload == {"key": "value"} + + +def test_message_is_put_into_queue( + mock_mqtt_conf: MqttConf, mock_mqtt_client, mock_queue +): + """Test that values a put into a queue when on_message callback is called.""" + topic = b"test/topic" + payload = json.dumps({"key": "value"}).encode() + fake_message = create_fake_mqtt_message(topic, payload) + + mock_queue.return_value.get.side_effect = [fake_message] + subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf) + + subscriber._on_message(None, None, fake_message) + + mock_queue.return_value.put.assert_called_once_with(fake_message) + + +def test_message_is_put_into_queue_with_actual_queue(mock_mqtt_conf, mock_mqtt_client): + """Test that the buffering via queue works as expected.""" + topic = b"test/topic" + payload = json.dumps({"key": "value"}).encode() + fake_message = create_fake_mqtt_message(topic, payload) + + # mock_queue.return_value.get.side_effect = [fake_message] + subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf) + + subscriber._on_message(None, None, fake_message) + + topic, return_dict = subscriber.receive() + + assert topic == "test/topic" + assert return_dict == {"key": "value"} diff --git a/tests/test_packer.py b/tests/test_packer.py new file mode 100644 index 0000000..62d5dcc --- /dev/null +++ b/tests/test_packer.py @@ -0,0 +1,35 @@ +import json +import pickle +from typing import Any + +import pytest + +from heisskleber.core.packer import get_packer, get_unpacker, serialpacker + + +def test_get_packer() -> None: + assert get_packer("json") == json.dumps + assert get_packer("pickle") == pickle.dumps + assert get_packer("default") == json.dumps + assert get_packer("foobar") == json.dumps + assert get_packer("serial") == serialpacker + + +def test_get_unpacker() -> None: + assert get_unpacker("json") == json.loads + assert get_unpacker("pickle") == pickle.loads + assert get_unpacker("default") == json.loads + assert get_unpacker("foobar") == json.loads + + +@pytest.mark.parametrize( + "message,expected", + [ + ({"hi": 1, "da": 2, "nei": 3}, "1,2,3"), + ({"er": 1, "ma": "ga", "gerd": 3, "jo": 4}, "1,ga,3,4"), + ({"": 1, "ho": 0.0, "lee": 0.1, "shit": 1_000}, "1,0.0,0.1,1000"), + ({"be": 1e6, "li": 1_000}, "1000000.0,1000"), + ], +) +def test_serial_packer_functionality(message: dict[str, Any], expected: str) -> None: + assert serialpacker(message) == expected diff --git a/tests/test_serial.py b/tests/test_serial.py new file mode 100644 index 0000000..0e083fd --- /dev/null +++ b/tests/test_serial.py @@ -0,0 +1,117 @@ +from unittest.mock import Mock, patch + +import pytest +import serial + +from heisskleber.core.packer import serialpacker +from heisskleber.serial.config import SerialConf +from heisskleber.serial.publisher import SerialPublisher +from heisskleber.serial.subscriber import SerialSubscriber + + +@pytest.fixture +def serial_conf(): + return SerialConf(port="/dev/test", baudrate=9600, bytesize=8, verbose=False) + + +@pytest.fixture +def mock_serial_device_subscriber(): + with patch("heisskleber.serial.subscriber.serial.Serial") as mock: + yield mock + + +@pytest.fixture +def mock_serial_device_publisher(): + with patch("heisskleber.serial.publisher.serial.Serial") as mock: + yield mock + + +def test_serial_subscriber_initialization(mock_serial_device_subscriber, serial_conf): + """Test that the SerialSubscriber class initializes correctly. + Mocks the serial.Serial class to avoid opening a serial port.""" + _ = SerialSubscriber( + config=serial_conf, + topic="", + ) + mock_serial_device_subscriber.assert_called_with( + port=serial_conf.port, + baudrate=serial_conf.baudrate, + bytesize=serial_conf.bytesize, + parity=serial.PARITY_NONE, + stopbits=serial.STOPBITS_ONE, + ) + + +def test_serial_subscriber_receive(mock_serial_device_subscriber, serial_conf): + """Test that the SerialSubscriber class calls readline and unpack as expected.""" + subscriber = SerialSubscriber(config=serial_conf, topic="") + + # Set up the readline return value + mock_serial_instance = mock_serial_device_subscriber.return_value + mock_serial_instance.readline.return_value = b"test message\n" + + # Set up the unpack function to convert message to dict + unpack_func = Mock(return_value={"data": "test message"}) + subscriber.unpack = unpack_func + + # Call the receive method and assert it behaves as expected + _, payload = subscriber.receive() + + # Was readline called? + mock_serial_instance.readline.assert_called_once() + + # Was unpack called? + assert payload == {"data": "test message"} + unpack_func.assert_called_once_with("test message\n") + + +def test_serial_subscriber_converts_bytes_to_str(): + """Test that the SerialSubscriber class converts bytes to str as expected.""" + with patch("heisskleber.serial.subscriber.serial.Serial") as mock_serial: + subscriber = SerialSubscriber( + config=SerialConf(), topic="", custom_unpack=lambda x: x + ) + + # Set the readline method to raise UnicodeError + mock_serial_instance = mock_serial.return_value + mock_serial_instance.readline.side_effect = [b"test message", b"test\x86more"] + + _, payload = subscriber.receive() + assert payload == "test message" + + # Assert that none-unicode is skipped + _, payload = subscriber.receive() + assert payload == "testmore" + + +def test_serial_publisher_initialization(mock_serial_device_publisher, serial_conf): + """Test that the SerialPublisher class initializes correctly. + Mocks the serial.Serial class to avoid opening a serial port.""" + publisher = SerialPublisher(config=serial_conf) + mock_serial_device_publisher.assert_called_with( + port=serial_conf.port, + baudrate=serial_conf.baudrate, + bytesize=serial_conf.bytesize, + parity=serial.PARITY_NONE, + stopbits=serial.STOPBITS_ONE, + ) + assert publisher.serial + + +def test_serial_publisher_send(mock_serial_device_publisher, serial_conf): + """Test that the SerialPublisher class calls write and pack as expected.""" + publisher = SerialPublisher(config=serial_conf) + + # Set up the readline return value + mock_serial_instance = mock_serial_device_publisher.return_value + mock_serial_instance.readline.return_value = b"test message\n" + + # Set up the pack function to convert dict to comma separated string of values + publisher.pack = serialpacker + + # Call the receive method and assert it behaves as expected + publisher.send({"data": "test message", "more_data": "more message"}, "test") + + # Was write called with encoded payload? + mock_serial_instance.write.assert_called_once_with(b"test message,more message") + mock_serial_instance.flush.assert_called_once() diff --git a/tests/test_udp.py b/tests/test_udp.py new file mode 100644 index 0000000..09133a0 --- /dev/null +++ b/tests/test_udp.py @@ -0,0 +1,80 @@ +import json +import socket +from unittest.mock import patch + +import pytest + +from heisskleber.udp.config import UdpConf +from heisskleber.udp.publisher import UdpPublisher +from heisskleber.udp.subscriber import UdpSubscriber + + +@pytest.fixture +def mock_socket(): + with patch("heisskleber.udp.publisher.socket.socket") as mock_socket: + yield mock_socket + + +@pytest.fixture +def mock_conf(): + return UdpConf(ip="127.0.0.1", port=12345, packer="json") + + +def test_connects_to_socket(mock_socket, mock_conf) -> None: + _ = UdpPublisher(mock_conf) + + # constructor was called + mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_DGRAM) + + +def test_closes_socket(mock_socket, mock_conf) -> None: + pub = UdpPublisher(mock_conf) + del pub + + # instace was closed + mock_socket.return_value.close.assert_called() + + +def test_packs_and_sends_message(mock_socket, mock_conf) -> None: + pub = UdpPublisher(mock_conf) + + # explicitly define packer to be json.dumps + assert pub.pack == json.dumps + + pub.send({"key": "val", "intkey": 1, "floatkey": 1.0}, "test") + + mock_socket.return_value.sendto.assert_called_with( + b'{"key": "val", "intkey": 1, "floatkey": 1.0, "topic": "test"}', + (str(mock_conf.ip), mock_conf.port), + ) + + +def test_subscriber_receives_message_from_queue(mock_conf) -> None: + sub = UdpSubscriber(mock_conf) + + test_topic, test_data = ("test", {"key": "val", "intkey": 1, "floatkey": 1.0}) + + sub._queue.put((test_topic, test_data)) + + topic, data = sub.receive() + assert test_topic == topic + assert test_data == data + + +@pytest.fixture +def udp_sub(mock_conf): + sub = UdpSubscriber(mock_conf) + sub.start_loop() + yield sub + + +def test_sends_message_between_pub_and_sub(udp_sub, mock_conf): + pub = UdpPublisher(mock_conf) + test_data = {"key": "val", "intkey": 1, "floatkey": 1.0} + test_topic = "test_topic" + + # Need to copy the dict, because the publisher will mutate it + pub.send(test_data.copy(), test_topic) + topic, data = udp_sub.receive() + assert test_topic == topic + assert test_data == data diff --git a/tests/test_version.py b/tests/test_version.py index 411bd58..f08109c 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -4,4 +4,4 @@ import heisskleber def test_heisskleber_version() -> None: """Test that the glue version is correct.""" - assert heisskleber.__version__ == "0.1.0" + assert heisskleber.__version__ == "0.2.0" diff --git a/tests/test_zmq.py b/tests/test_zmq.py new file mode 100644 index 0000000..cec2631 --- /dev/null +++ b/tests/test_zmq.py @@ -0,0 +1,14 @@ +from heisskleber.zmq.config import ZmqConf + + +def test_config_parses_correctly(): + conf = ZmqConf( + protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556 + ) + assert conf.protocol == "tcp" + assert conf.interface == "localhost" + assert conf.publisher_port == 5555 + assert conf.subscriber_port == 5556 + + assert conf.publisher_address == "tcp://localhost:5555" + assert conf.subscriber_address == "tcp://localhost:5556"