From 8cfc66ac45cdde06859ac04adffdb0e77079a47a Mon Sep 17 00:00:00 2001 From: Felix Weiler Date: Mon, 6 Nov 2023 22:52:48 +0100 Subject: [PATCH] WIP: Test, refactor --- Makefile | 2 +- docs/reference.md | 23 ++- heisskleber/__init__.py | 7 +- heisskleber/broker/__init__.py | 2 +- .../broker/{msb_broker.py => zmq_broker.py} | 17 +- heisskleber/{network => }/config.py | 0 heisskleber/config/__init__.py | 4 +- heisskleber/config/cmdline.py | 13 +- .../config/{MSBConfig.py => config.py} | 7 +- heisskleber/config/parse.py | 23 +-- heisskleber/config/zeromq.py | 40 ----- .../{network/influxdb => core}/__init__.py | 0 .../{network/pubsub => core}/factories.py | 17 +- heisskleber/{network => core}/packer.py | 16 +- heisskleber/core/types.py | 52 ++++++ .../{network/pubsub => influxdb}/__init__.py | 0 heisskleber/{network => }/influxdb/config.py | 0 .../{network => }/influxdb/subscriber.py | 6 +- heisskleber/{network => }/influxdb/writer.py | 0 heisskleber/mqtt/__init__.py | 5 + heisskleber/{network => }/mqtt/config.py | 2 +- heisskleber/{network => }/mqtt/forwarder.py | 8 +- heisskleber/{network => }/mqtt/mqtt_base.py | 18 +- heisskleber/mqtt/publisher.py | 36 ++++ heisskleber/mqtt/subscriber.py | 66 +++++++ heisskleber/network/__init__.py | 1 - heisskleber/network/mqtt/__init__.py | 3 - heisskleber/network/mqtt/msb_mqtt.py | 18 -- heisskleber/network/mqtt/publisher.py | 63 ------- heisskleber/network/mqtt/subscriber.py | 104 ----------- heisskleber/network/pubsub/types.py | 31 ---- heisskleber/network/serial/__init__.py | 3 - heisskleber/network/types.py | 1 - heisskleber/network/zmq/__init__.py | 3 - heisskleber/serial/__init__.py | 5 + heisskleber/{network => }/serial/config.py | 0 heisskleber/{network => }/serial/forwarder.py | 10 +- heisskleber/{network => }/serial/publisher.py | 28 +-- .../{network => }/serial/subscriber.py | 28 +-- heisskleber/{network => }/udp/__init__.py | 0 heisskleber/{network => }/udp/config.py | 0 heisskleber/{network => }/udp/publisher.py | 6 +- heisskleber/{network => }/udp/subscriber.py | 6 +- heisskleber/zmq/__init__.py | 5 + heisskleber/{network => }/zmq/config.py | 0 heisskleber/{network => }/zmq/publisher.py | 11 +- heisskleber/{network => }/zmq/subscriber.py | 4 +- pyproject.toml | 7 +- tests/test_import.py | 31 ++-- tests/test_mqtt.py | 162 ++++++++++++++++++ tests/test_packer.py | 35 ++++ tests/test_serial.py | 115 +++++++++++++ tests/test_version.py | 2 +- 53 files changed, 625 insertions(+), 421 deletions(-) rename heisskleber/broker/{msb_broker.py => zmq_broker.py} (71%) rename heisskleber/{network => }/config.py (100%) rename heisskleber/config/{MSBConfig.py => config.py} (79%) delete mode 100644 heisskleber/config/zeromq.py rename heisskleber/{network/influxdb => core}/__init__.py (100%) rename heisskleber/{network/pubsub => core}/factories.py (66%) rename heisskleber/{network => core}/packer.py (61%) create mode 100644 heisskleber/core/types.py rename heisskleber/{network/pubsub => influxdb}/__init__.py (100%) rename heisskleber/{network => }/influxdb/config.py (100%) rename heisskleber/{network => }/influxdb/subscriber.py (94%) rename heisskleber/{network => }/influxdb/writer.py (100%) create mode 100644 heisskleber/mqtt/__init__.py rename heisskleber/{network => }/mqtt/config.py (89%) rename heisskleber/{network => }/mqtt/forwarder.py (72%) rename heisskleber/{network => }/mqtt/mqtt_base.py (83%) create mode 100644 heisskleber/mqtt/publisher.py create mode 100644 heisskleber/mqtt/subscriber.py delete mode 100644 heisskleber/network/__init__.py delete mode 100644 heisskleber/network/mqtt/__init__.py delete mode 100644 heisskleber/network/mqtt/msb_mqtt.py delete mode 100644 heisskleber/network/mqtt/publisher.py delete mode 100644 heisskleber/network/mqtt/subscriber.py delete mode 100644 heisskleber/network/pubsub/types.py delete mode 100644 heisskleber/network/serial/__init__.py delete mode 100644 heisskleber/network/types.py delete mode 100644 heisskleber/network/zmq/__init__.py create mode 100644 heisskleber/serial/__init__.py rename heisskleber/{network => }/serial/config.py (100%) rename heisskleber/{network => }/serial/forwarder.py (75%) rename heisskleber/{network => }/serial/publisher.py (71%) rename heisskleber/{network => }/serial/subscriber.py (78%) rename heisskleber/{network => }/udp/__init__.py (100%) rename heisskleber/{network => }/udp/config.py (100%) rename heisskleber/{network => }/udp/publisher.py (89%) rename heisskleber/{network => }/udp/subscriber.py (83%) create mode 100644 heisskleber/zmq/__init__.py rename heisskleber/{network => }/zmq/config.py (100%) rename heisskleber/{network => }/zmq/publisher.py (68%) rename heisskleber/{network => }/zmq/subscriber.py (93%) create mode 100644 tests/test_mqtt.py create mode 100644 tests/test_packer.py create mode 100644 tests/test_serial.py diff --git a/Makefile b/Makefile index 78a4553..32b7eb7 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,7 @@ docs-test: ## Test if documentation can be built without warnings or errors .PHONY: docs docs: ## Build and serve the documentation - @poetry run mkdocs serve + @poetry run python3 -m sphinx docs docs/_build -b html .PHONY: help help: diff --git a/docs/reference.md b/docs/reference.md index 20ff891..16622bd 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -3,22 +3,21 @@ ## Network ```{eval-rst} -.. automodule:: heisskleber.network +.. automodule:: heisskleber :members: -.. automodule:: heisskleber.network.mqtt -.. autoclass:: MqttPublisher -.. autoclass:: MqttSubscriber -.. automodule:: heisskleber.network.zmq +.. automodule:: heisskleber.mqtt + :members: +.. automodule:: heisskleber.zmq + :members: +.. automodule:: heisskleber.serial + :members: +.. automodule:: heisskleber.config :members: -.. autoclass:: ZmqPublisher -.. autoclass:: ZmqSubscriber ``` ### Broker ```{eval-rst} -.. automodule:: heisskleber.broker - :members: ``` ## Config @@ -34,7 +33,7 @@ Configs are extended dataclasses, which inherit from the BaseConf class. ```{eval-rst} .. autoclass:: heisskleber.config.BaseConf -.. autoclass:: heisskleber.network.mqtt.config.MqttConf -.. autoclass:: heisskleber.network.zmq.config.ZmqConf -.. autoclass:: heisskleber.network.serial.config.SerialConf +.. autoclass:: heisskleber.mqtt.config.MqttConf +.. autoclass:: heisskleber.zmq.config.ZmqConf +.. autoclass:: heisskleber.serial.config.SerialConf ``` diff --git a/heisskleber/__init__.py b/heisskleber/__init__.py index 2cb7032..0379734 100644 --- a/heisskleber/__init__.py +++ b/heisskleber/__init__.py @@ -1,5 +1,6 @@ """Heisskleber.""" -from .network import get_publisher, get_subscriber +from .core.factories import get_publisher, get_subscriber +from .core.types import Publisher, Subscriber -__all__ = ["get_publisher", "get_subscriber"] -__version__ = "0.1.0" +__all__ = ["get_publisher", "get_subscriber", "Publisher", "Subscriber"] +__version__ = "0.2.0" diff --git a/heisskleber/broker/__init__.py b/heisskleber/broker/__init__.py index 961abcd..17f19ef 100644 --- a/heisskleber/broker/__init__.py +++ b/heisskleber/broker/__init__.py @@ -1,3 +1,3 @@ -from .msb_broker import msb_broker as start_zmq_broker +from .msb_broker import zmq_broker as start_zmq_broker __all__ = ["start_zmq_broker"] diff --git a/heisskleber/broker/msb_broker.py b/heisskleber/broker/zmq_broker.py similarity index 71% rename from heisskleber/broker/msb_broker.py rename to heisskleber/broker/zmq_broker.py index db4394e..ff271b3 100644 --- a/heisskleber/broker/msb_broker.py +++ b/heisskleber/broker/zmq_broker.py @@ -2,9 +2,10 @@ import signal import sys import zmq +from zmq import Socket from heisskleber.config import load_config -from heisskleber.network.zmq.config import ZmqConf as BrokerConf +from heisskleber.zmq.config import ZmqConf as BrokerConf def signal_handler(sig, frame): @@ -16,29 +17,31 @@ class BrokerBindingError(Exception): pass -def bind_socket(socket, address, socket_type, verbose=False): +def bind_socket(socket: Socket, address: str, socket_type: str, verbose=False) -> None: """Bind a ZMQ socket and handle errors.""" if verbose: print(f"creating {socket_type} socket") try: socket.bind(address) except Exception as err: - raise BrokerBindingError(f"failed to bind to {socket_type}: {err}") from err + error_message = f"failed to bind to {socket_type}: {err}" + raise BrokerBindingError(error_message) from err if verbose: print(f"successfully bound to {socket_type} socket: {address}") -def create_proxy(xpub, xsub, verbose=False): +def create_proxy(xpub: Socket, xsub: Socket, verbose=False) -> None: """Create a ZMQ proxy to connect XPUB and XSUB sockets.""" if verbose: print("creating proxy") try: zmq.proxy(xpub, xsub) except Exception as err: - raise BrokerBindingError(f"failed to create proxy: {err}") from err + error_message = f"failed to create proxy: {err}" + raise BrokerBindingError(error_message) from err -def msb_broker(config: BrokerConf) -> None: +def zmq_broker(config: BrokerConf) -> None: """Start a zmq broker. Binds to a publisher and subscriber port, allowing many to many connections.""" @@ -60,4 +63,4 @@ def main() -> None: """Start a zmq broker, with a user specified configuration.""" signal.signal(signal.SIGINT, signal_handler) broker_config = load_config(BrokerConf(), "zmq") - msb_broker(broker_config) + zmq_broker(broker_config) diff --git a/heisskleber/network/config.py b/heisskleber/config.py similarity index 100% rename from heisskleber/network/config.py rename to heisskleber/config.py diff --git a/heisskleber/config/__init__.py b/heisskleber/config/__init__.py index c554645..8710d76 100644 --- a/heisskleber/config/__init__.py +++ b/heisskleber/config/__init__.py @@ -1,4 +1,4 @@ -from heisskleber.config.MSBConfig import BaseConf -from heisskleber.config.parse import load_config +from .config import BaseConf +from .parse import load_config __all__ = ["load_config", "BaseConf"] diff --git a/heisskleber/config/cmdline.py b/heisskleber/config/cmdline.py index 5a39f58..b71588c 100644 --- a/heisskleber/config/cmdline.py +++ b/heisskleber/config/cmdline.py @@ -2,18 +2,7 @@ import argparse class KeyValue(argparse.Action): - # Constructor calling - """ - def __call__( self , parser, namespace, values : list, option_string = None): - setattr(namespace, self.dest, dict()) - for value in values: - # split it into key and value - key, value = value.split('=') - # assign into dictionary - getattr(namespace, self.dest)[key] = value - """ - - def __call__(self, parser, args, values, option_string=None): + def __call__(self, parser, args, values, option_string=None) -> None: try: params = dict(x.split("=") for x in values) except ValueError as ex: diff --git a/heisskleber/config/MSBConfig.py b/heisskleber/config/config.py similarity index 79% rename from heisskleber/config/MSBConfig.py rename to heisskleber/config/config.py index 7d05478..69cc2d3 100644 --- a/heisskleber/config/MSBConfig.py +++ b/heisskleber/config/config.py @@ -1,6 +1,7 @@ import socket import warnings from dataclasses import dataclass +from typing import Any @dataclass @@ -12,18 +13,18 @@ class BaseConf: verbose: bool = False print_stdout: bool = False - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: if hasattr(self, key): self.__setattr__(key, value) else: warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: if hasattr(self, key): return getattr(self, key) else: warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2) @property - def serial_number(self): + def serial_number(self) -> str: return socket.gethostname().upper() diff --git a/heisskleber/config/parse.py b/heisskleber/config/parse.py index f6cfd2b..0d776df 100644 --- a/heisskleber/config/parse.py +++ b/heisskleber/config/parse.py @@ -8,7 +8,9 @@ import yaml from heisskleber.config import BaseConf from heisskleber.config.cmdline import get_cmdline -ConfigType = TypeVar("ConfigType", bound=BaseConf) +ConfigType = TypeVar( + "ConfigType", bound=BaseConf +) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar def get_msb_config_filepath(config_filename: str = "heisskleber.conf") -> str: @@ -17,10 +19,10 @@ def get_msb_config_filepath(config_filename: str = "heisskleber.conf") -> str: config_filepath = os.path.join(os.environ["MSB_CONFIG_DIR"], config_subpath) except Exception as e: print(f"could no get MSB_CONFIG from PATH: {e}") - sys.exit() # TODO use 1 or the error str as exit value + sys.exit(1) if not os.path.isfile(config_filepath): print(f"not a file: {config_filepath}!") - sys.exit() + sys.exit(1) return config_filepath @@ -33,16 +35,12 @@ def update_config(config: ConfigType, config_dict: dict) -> ConfigType: for config_key, config_value in config_dict.items(): # get expected type of element from config_object: if not hasattr(config, config_key): - warnings.warn( - f"no such configuration parameter: {config_key}, skipping", stacklevel=2 - ) + error_msg = f"no such configuration parameter: {config_key}, skipping" + warnings.warn(error_msg, stacklevel=2) continue cast_func = type(config[config_key]) try: - if config_key == "topic": - config[config_key] = config_value.encode("utf-8") - else: - config[config_key] = cast_func(config_value) + config[config_key] = cast_func(config_value) except Exception as e: print( f"failed to cast {config_value} to {type(config[config_key])}: {e}. skipping" @@ -51,11 +49,6 @@ def update_config(config: ConfigType, config_dict: dict) -> ConfigType: return config -ConfigType = TypeVar( - "ConfigType", bound=BaseConf -) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar - - def load_config( config: ConfigType, config_filename: str, read_commandline: bool = True ) -> ConfigType: diff --git a/heisskleber/config/zeromq.py b/heisskleber/config/zeromq.py deleted file mode 100644 index 571440a..0000000 --- a/heisskleber/config/zeromq.py +++ /dev/null @@ -1,40 +0,0 @@ -import sys - -import zmq - - -def open_zmq_sub_socket(connect_to: str, topic=b""): - ctx = zmq.Context() - zmq_socket = ctx.socket(zmq.SUB) - try: - zmq_socket.connect(connect_to) - except Exception as e: - print(f"failed to bind to zeromq socket: {e}") - sys.exit(-1) - zmq_socket.setsockopt(zmq.SUBSCRIBE, topic) - return zmq_socket - - -def open_zmq_pub_socket(connect_to: str): - ctx = zmq.Context() - zmq_socket = ctx.socket(zmq.PUB) - try: - zmq_socket.connect(connect_to) - except Exception as e: - print(f"failed to bind to zeromq socket: {e}") - sys.exit(-1) - return zmq_socket - - -def get_zmq_xpub_socketstring(msb_config: dict) -> str: - zmq_config = msb_config["zeromq"] - return ( - f"{zmq_config['protocol']}://{zmq_config['address']}:{zmq_config['xpub_port']}" - ) - - -def get_zmq_xsub_socketstring(msb_config: dict) -> str: - zmq_config = msb_config["zeromq"] - return ( - f"{zmq_config['protocol']}://{zmq_config['address']}:{zmq_config['xsub_port']}" - ) diff --git a/heisskleber/network/influxdb/__init__.py b/heisskleber/core/__init__.py similarity index 100% rename from heisskleber/network/influxdb/__init__.py rename to heisskleber/core/__init__.py diff --git a/heisskleber/network/pubsub/factories.py b/heisskleber/core/factories.py similarity index 66% rename from heisskleber/network/pubsub/factories.py rename to heisskleber/core/factories.py index a7dfdf9..158040b 100644 --- a/heisskleber/network/pubsub/factories.py +++ b/heisskleber/core/factories.py @@ -1,24 +1,25 @@ import os -from heisskleber.config import load_config -from heisskleber.network.mqtt import MqttConf, MqttPublisher, MqttSubscriber -from heisskleber.network.serial import SerialConf, SerialPublisher, SerialSubscriber -from heisskleber.network.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber +from heisskleber.config import BaseConf, load_config +from heisskleber.core.types import Publisher, Subscriber +from heisskleber.mqtt import MqttConf, MqttPublisher, MqttSubscriber +from heisskleber.serial import SerialConf, SerialPublisher, SerialSubscriber +from heisskleber.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber -_registered_publishers = { +_registered_publishers: dict[str, tuple[type[Publisher], type[BaseConf]]] = { "zmq": (ZmqPublisher, ZmqConf), "mqtt": (MqttPublisher, MqttConf), "serial": (SerialPublisher, SerialConf), } -_registered_subscribers = { +_registered_subscribers: dict[str, tuple[type[Subscriber], type[BaseConf]]] = { "zmq": (ZmqSubscriber, ZmqConf), "mqtt": (MqttSubscriber, MqttConf), "serial": (SerialSubscriber, SerialConf), } -def get_publisher(name: str): +def get_publisher(name: str) -> Publisher: if name not in _registered_publishers: error_message = f"{name} is not a registered Publisher." raise KeyError(error_message) @@ -35,7 +36,7 @@ def get_publisher(name: str): return pub_cls(config) -def get_subscriber(name: str, topic): +def get_subscriber(name: str, topic: str | list[str]) -> Subscriber: if name not in _registered_publishers: error_message = f"{name} is not a registered Subscriber." raise KeyError(error_message) diff --git a/heisskleber/network/packer.py b/heisskleber/core/packer.py similarity index 61% rename from heisskleber/network/packer.py rename to heisskleber/core/packer.py index 1924ba7..e0643b4 100644 --- a/heisskleber/network/packer.py +++ b/heisskleber/core/packer.py @@ -1,10 +1,10 @@ """Packer and unpacker for network data.""" import json import pickle -from typing import Callable +from typing import Any, Callable -def get_packer(style) -> Callable[[dict], str]: +def get_packer(style: str) -> Callable[[dict[str, Any]], Any]: """Return a packer function for the given style. Packer func serializes a given dict.""" @@ -14,7 +14,7 @@ def get_packer(style) -> Callable[[dict], str]: return _packstyles["default"] -def get_unpacker(style) -> Callable[[str], dict]: +def get_unpacker(style: str) -> Callable[[str | bytes], dict[str, Any] | str | bytes]: """Return an unpacker function for the given style. Unpacker func deserializes a string.""" @@ -24,21 +24,21 @@ def get_unpacker(style) -> Callable[[str], dict]: return _unpackstyles["default"] -def serialpacker(data: dict) -> str: +def serialpacker(data: dict[str, Any]) -> str: return ",".join([str(v) for v in data.values()]) -_packstyles = { +_packstyles: dict[str, Callable[[dict[str, Any]], str | bytes]] = { "default": json.dumps, "json": json.dumps, "pickle": pickle.dumps, "serial": serialpacker, - "raw": lambda x: x, + "raw": lambda x: x, # type: ignore } -_unpackstyles = { +_unpackstyles: dict[str, Callable[[str | bytes], dict[str, Any] | Any]] = { "default": json.loads, "json": json.loads, - "pickle": pickle.loads, + "pickle": pickle.loads, # type: ignore "raw": lambda x: x, } diff --git a/heisskleber/core/types.py b/heisskleber/core/types.py new file mode 100644 index 0000000..31bbba4 --- /dev/null +++ b/heisskleber/core/types.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Callable, Union + +PayloadType = Union[str, int, float] + + +class Publisher(ABC): + """ + Publisher interface. + """ + + pack: Callable[[dict[str, Any] | Any], str] + + @abstractmethod + def __init__(self, config: Any) -> None: + """ + Initialize the publisher with a configuration object. + """ + pass + + @abstractmethod + def send(self, topic: str, data: dict[str, Any]) -> None: + """ + Send data via the implemented output stream. + """ + pass + + +class Subscriber(ABC): + """ + Subscriber interface + """ + + unpack: Callable[[bytes], dict[str, PayloadType] | Any] + + @abstractmethod + def __init__(self, topic: str | list[str], config: Any) -> None: + """ + Initialize the subscriber with a topic and a configuration object. + """ + pass + + @abstractmethod + def receive(self) -> tuple[str, dict[str, PayloadType]]: + """ + Blocking function to receive data from the implemented input stream. + + Data is returned as a tuple of (topic, data). + """ + pass diff --git a/heisskleber/network/pubsub/__init__.py b/heisskleber/influxdb/__init__.py similarity index 100% rename from heisskleber/network/pubsub/__init__.py rename to heisskleber/influxdb/__init__.py diff --git a/heisskleber/network/influxdb/config.py b/heisskleber/influxdb/config.py similarity index 100% rename from heisskleber/network/influxdb/config.py rename to heisskleber/influxdb/config.py diff --git a/heisskleber/network/influxdb/subscriber.py b/heisskleber/influxdb/subscriber.py similarity index 94% rename from heisskleber/network/influxdb/subscriber.py rename to heisskleber/influxdb/subscriber.py index a49433a..a51e927 100644 --- a/heisskleber/network/influxdb/subscriber.py +++ b/heisskleber/influxdb/subscriber.py @@ -1,6 +1,8 @@ import pandas as pd from influxdb_client import InfluxDBClient +from heisskleber.core.types import Subscriber + from .config import InfluxDBConf @@ -28,7 +30,7 @@ def build_query(options: dict) -> str: return query -class Influx_Subscriber: +class Influx_Subscriber(Subscriber): def __init__(self, config: InfluxDBConf, query: str): self.config = config self.query = query @@ -45,7 +47,7 @@ class Influx_Subscriber: self._run_query() self.index = 0 - def receive(self) -> dict: + def receive(self) -> tuple[str, dict]: row = self.df.iloc[self.index].to_dict() self.index += 1 return "influx", row diff --git a/heisskleber/network/influxdb/writer.py b/heisskleber/influxdb/writer.py similarity index 100% rename from heisskleber/network/influxdb/writer.py rename to heisskleber/influxdb/writer.py diff --git a/heisskleber/mqtt/__init__.py b/heisskleber/mqtt/__init__.py new file mode 100644 index 0000000..9e08600 --- /dev/null +++ b/heisskleber/mqtt/__init__.py @@ -0,0 +1,5 @@ +from .config import MqttConf +from .publisher import MqttPublisher +from .subscriber import MqttSubscriber + +__all__ = ["MqttConf", "MqttPublisher", "MqttSubscriber"] diff --git a/heisskleber/network/mqtt/config.py b/heisskleber/mqtt/config.py similarity index 89% rename from heisskleber/network/mqtt/config.py rename to heisskleber/mqtt/config.py index c716cb3..004d955 100644 --- a/heisskleber/network/mqtt/config.py +++ b/heisskleber/mqtt/config.py @@ -16,7 +16,7 @@ class MqttConf(BaseConf): ssl: bool = False qos: int = 0 retain: bool = False - topics: list[bytes] = field(default_factory=list) + topics: list[str] = field(default_factory=list) mapping: str = "/msb/" packstyle: str = "json" max_saved_messages: int = 100 diff --git a/heisskleber/network/mqtt/forwarder.py b/heisskleber/mqtt/forwarder.py similarity index 72% rename from heisskleber/network/mqtt/forwarder.py rename to heisskleber/mqtt/forwarder.py index e3a959a..c6125de 100644 --- a/heisskleber/network/mqtt/forwarder.py +++ b/heisskleber/mqtt/forwarder.py @@ -1,14 +1,14 @@ +from heisskleber import get_publisher, get_subscriber from heisskleber.config import load_config -from heisskleber.network import get_publisher, get_subscriber from .config import MqttConf -def map_topic(zmq_topic, mapping): - return mapping + zmq_topic.decode() +def map_topic(zmq_topic: str, mapping: str) -> str: + return mapping + zmq_topic -def main(): +def main() -> None: config: MqttConf = load_config(MqttConf(), "mqtt") sub = get_subscriber("zmq", config.topics) pub = get_publisher("mqtt") diff --git a/heisskleber/network/mqtt/mqtt_base.py b/heisskleber/mqtt/mqtt_base.py similarity index 83% rename from heisskleber/network/mqtt/mqtt_base.py rename to heisskleber/mqtt/mqtt_base.py index ed9fb2a..c5d5dcc 100644 --- a/heisskleber/network/mqtt/mqtt_base.py +++ b/heisskleber/mqtt/mqtt_base.py @@ -25,19 +25,19 @@ def _set_thread_died_excepthook(args, /): threading.excepthook = _set_thread_died_excepthook -class MQTT_Base: +class MqttBase: """ Wrapper around eclipse paho mqtt client. Handles connection and callbacks. Callbacks may be overwritten in subclasses. """ - def __init__(self, config: MqttConf): + def __init__(self, config: MqttConf) -> None: self.config = config self.connect() self.client.loop_start() - def connect(self): + def connect(self) -> None: self.client = mqtt_client() self.client.username_pw_set(self.config.user, self.config.password) @@ -55,13 +55,13 @@ class MQTT_Base: self.client.connect(self.config.broker, self.config.port) @staticmethod - def _raise_if_thread_died(): + def _raise_if_thread_died() -> None: global _thread_died if _thread_died.is_set(): raise ThreadDiedError() # MQTT callbacks - def _on_connect(self, client, userdata, flags, return_code): + def _on_connect(self, client, userdata, flags, return_code) -> None: if return_code == 0: print(f"MQTT node connected to {self.config.broker}:{self.config.port}") else: @@ -69,21 +69,21 @@ class MQTT_Base: if self.config.verbose: print(flags) - def _on_disconnect(self, client, userdata, return_code): + def _on_disconnect(self, client, userdata, return_code) -> None: print(f"Disconnected from broker with return code {return_code}") if return_code != 0: print("Killing this service") sys.exit(-1) - def _on_publish(self, client, userdata, message_id): + def _on_publish(self, client, userdata, message_id) -> None: if self.config.verbose: print(f"Published message with id {message_id}, qos={self.config.qos}") - def _on_message(self, client, userdata, message): + def _on_message(self, client, userdata, message) -> None: if self.config.verbose: print( f"Received message: {message.payload!s}, topic: {message.topic}, qos: {message.qos}" ) - def __del__(self): + def __del__(self) -> None: self.client.loop_stop() diff --git a/heisskleber/mqtt/publisher.py b/heisskleber/mqtt/publisher.py new file mode 100644 index 0000000..d58e89a --- /dev/null +++ b/heisskleber/mqtt/publisher.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import Any + +from heisskleber.core.packer import get_packer +from heisskleber.core.types import Publisher + +from .config import MqttConf +from .mqtt_base import MqttBase + + +class MqttPublisher(MqttBase, Publisher): + """ + MQTT publisher class. + Can be used everywhere that a flucto style publishing connection is required. + + Network message loop is handled in a separated thread. + """ + + def __init__(self, config: MqttConf) -> None: + super().__init__(config) + self.pack = get_packer(config.packstyle) + + def send(self, topic: str, data: dict[str, Any]) -> None: + """ + Takes python dictionary, serializes it according to the packstyle + and sends it to the broker. + + Publishing is asynchronous + """ + self._raise_if_thread_died() + + payload = self.pack(data) + self.client.publish( + topic, payload, qos=self.config.qos, retain=self.config.retain + ) diff --git a/heisskleber/mqtt/subscriber.py b/heisskleber/mqtt/subscriber.py new file mode 100644 index 0000000..4bd8860 --- /dev/null +++ b/heisskleber/mqtt/subscriber.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from queue import SimpleQueue +from typing import Any + +from paho.mqtt.client import MQTTMessage + +from heisskleber.core.packer import get_unpacker +from heisskleber.core.types import Subscriber + +from .config import MqttConf +from .mqtt_base import MqttBase + + +class MqttSubscriber(MqttBase, Subscriber): + """ + MQTT subscriber, wraps around ecplipse's paho mqtt client. + Network message loop is handled in a separated thread. + + Incoming messages are saved as a stack when not processed via the receive() function. + """ + + def __init__(self, topics: str | list[str], config: MqttConf): + super().__init__(config) + self._message_queue: SimpleQueue[MQTTMessage] = SimpleQueue() + self.subscribe(topics) + self.client.on_message = self._on_message + self.unpack = get_unpacker(config.packstyle) + + def subscribe(self, topics: str | list[str] | tuple[str]) -> None: + """ + Subscribe to one or multiple topics + """ + if isinstance(topics, (list, tuple)): + # if subscribing to multiple topics, use a list of tuples + subscription_list = [(topic, self.config.qos) for topic in topics] + self.client.subscribe(subscription_list) + else: + self.client.subscribe(topics, self.config.qos) + if self.config.verbose: + print(f"Subscribed to: {topics}") + + def receive(self) -> tuple[str, dict[str, Any]]: + """ + Reads a message from mqtt and returns it + + Messages are saved in a stack, if no message is available, this function blocks. + + Returns: + tuple(topic: bytes, message: dict): the message received + """ + self._raise_if_thread_died() + mqtt_message = self._message_queue.get( + block=True, timeout=self.config.timeout_s + ) + + message_returned = self.unpack(mqtt_message.payload.decode()) + return (mqtt_message.topic, message_returned) + + # callback to add incoming messages onto stack + def _on_message(self, client, userdata, message) -> None: + self._message_queue.put(message) + + if self.config.verbose: + print(f"Topic: {message.topic}") + print(f"MQTT message: {message.payload.decode()}") diff --git a/heisskleber/network/__init__.py b/heisskleber/network/__init__.py deleted file mode 100644 index 31256a0..0000000 --- a/heisskleber/network/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pubsub.factories import get_publisher, get_subscriber # noqa: F401 diff --git a/heisskleber/network/mqtt/__init__.py b/heisskleber/network/mqtt/__init__.py deleted file mode 100644 index e71a954..0000000 --- a/heisskleber/network/mqtt/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .config import MqttConf # noqa: F401 -from .publisher import MqttPublisher # noqa: F401 -from .subscriber import MqttSubscriber # noqa: F401 diff --git a/heisskleber/network/mqtt/msb_mqtt.py b/heisskleber/network/mqtt/msb_mqtt.py deleted file mode 100644 index 344add1..0000000 --- a/heisskleber/network/mqtt/msb_mqtt.py +++ /dev/null @@ -1,18 +0,0 @@ -from heisskleber.config import load_config -from heisskleber.network import get_publisher, get_subscriber - -from .config import MqttConf -from .forwarder import ZMQ_to_MQTT_Forwarder - - -def main(): - config = load_config(MqttConf(), "mqtt") - for topic in config.topics: - print(f"Subscribing to {topic}") - - zmq_sub = get_subscriber("zmq", list(config.topics)) - mqtt_pub = get_publisher("mqtt") - forwarder = ZMQ_to_MQTT_Forwarder(config, subscriber=zmq_sub, publisher=mqtt_pub) - - # Wait for zmq messages, publish as mqtt message - forwarder.zmq_to_mqtt_loop() diff --git a/heisskleber/network/mqtt/publisher.py b/heisskleber/network/mqtt/publisher.py deleted file mode 100644 index 4acf4a9..0000000 --- a/heisskleber/network/mqtt/publisher.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations - -from heisskleber.config import load_config -from heisskleber.network.packer import get_packer -from heisskleber.network.pubsub.types import Publisher - -from .config import MqttConf -from .mqtt_base import MQTT_Base - - -class MqttPublisher(MQTT_Base, Publisher): - """ - MQTT publisher class. - Can be used everywhere that a flucto style publishing connection is required. - - Network message loop is handled in a separated thread. - """ - - def __init__(self, config: MqttConf): - super().__init__(config) - self.pack = get_packer(config.packstyle) - - def send(self, topic: str | bytes, data: dict): - """ - Takes python dictionary, serializes it according to the packstyle - and sends it to the broker. - - Publishing is asynchronous - """ - self._raise_if_thread_died() - if isinstance(topic, bytes): - topic = topic.decode() - - payload = self.pack(data) - self.client.publish( - topic, payload, qos=self.config.qos, retain=self.config.retain - ) - - -def get_mqtt_publisher() -> MqttPublisher: - """ - Generate mqtt publisher with configuration from yaml file, - falls back to default values if no config is found - """ - import os - - if "MSB_CONFIG_DIR" in os.environ: - print("loading mqtt config") - config = load_config(MqttConf(), "mqtt", read_commandline=False) - else: - print("using default mqtt config") - config = MqttConf() - return MqttPublisher(config) - - -def get_default_publisher() -> MqttPublisher: - """ - Generate mqtt publisher with configuration from yaml file, - falls back to default values if no config is found - - Deprecated, use get_mqtt_publisher() instead - """ - return get_mqtt_publisher() diff --git a/heisskleber/network/mqtt/subscriber.py b/heisskleber/network/mqtt/subscriber.py deleted file mode 100644 index e684bff..0000000 --- a/heisskleber/network/mqtt/subscriber.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import annotations - -from queue import SimpleQueue - -from heisskleber.config import load_config -from heisskleber.network.packer import get_unpacker -from heisskleber.network.pubsub.types import Subscriber - -from .config import MqttConf -from .mqtt_base import MQTT_Base - - -class MqttSubscriber(MQTT_Base, Subscriber): - """ - MQTT subscriber, wraps around ecplipse's paho mqtt client. - Network message loop is handled in a separated thread. - - Incoming messages are saved as a stack when not processed via the receive() function. - """ - - def __init__(self, topics, config: MqttConf): - super().__init__(config) - self._message_queue = SimpleQueue() - self.subscribe(topics) - self.client.on_message = self._on_message - self.unpack = get_unpacker(config.packstyle) - - def _subscribe_single_topic(self, topic: bytes | str): - if isinstance(topic, bytes): - topic = topic.decode() - if self.config.verbose: - print(f"Subscribed to: {topic}") - self.client.subscribe(topic, self.config.qos) - - def _subscribe_multiple_topics(self, topics: list[bytes] | list[str]): - topics = [ - topic.decode() if isinstance(topic, bytes) else topic for topic in topics - ] - subscription_list = [(topic, self.config.qos) for topic in topics] - if self.config.verbose: - print(f"Subscribed to: {topics}") - self.client.subscribe(subscription_list) - - def subscribe(self, topics): - """ - Subscribe to one or multiple topics - """ - # if subscribing to multiple topics, use a list of tuples - if isinstance(topics, (list, tuple)): - self._subscribe_multiple_topics(topics) - else: - self.client.subscribe(topics, self.config.qos) - - def receive(self) -> tuple[bytes, dict]: - """ - Reads a message from mqtt and returns it - - Messages are saved in a stack, if no message is available, this function blocks. - - Returns: - tuple(topic: bytes, message: dict): the message received - """ - self._raise_if_thread_died() - mqtt_message = self._message_queue.get( - block=True, timeout=self.config.timeout_s - ) - - topic = mqtt_message.topic.encode("utf-8") - message_returned = self.unpack(mqtt_message.payload.decode()) - return (topic, message_returned) - - # callback to add incoming messages onto stack - def _on_message(self, client, userdata, message): - self._message_queue.put(message) - - if self.config.verbose: - print(f"Topic: {message.topic}") - print(f"MQTT message: {message.payload.decode()}") - - -def get_mqtt_subscriber(topic: bytes | str) -> MqttSubscriber: - """ - Generate mqtt subscriber with configuration from yaml file, - falls back to default values if no config is found - """ - import os - - if "MSB_CONFIG_DIR" in os.environ: - print("loading mqtt config") - config = load_config(MqttConf(), "mqtt", read_commandline=False) - else: - print("using default mqtt config") - config = MqttConf() - return MqttSubscriber(topic, config) - - -def get_default_subscriber(topic: bytes | str) -> MqttSubscriber: - """ - Generate mqtt subscriber with configuration from yaml file, - falls back to default values if no config is found - - Deprecated, use get_mqtt_subscriber(topic) instead. - """ - return get_mqtt_subscriber(topic) diff --git a/heisskleber/network/pubsub/types.py b/heisskleber/network/pubsub/types.py deleted file mode 100644 index 8fe5184..0000000 --- a/heisskleber/network/pubsub/types.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod - - -class Publisher(ABC): - """ - Publisher interface. - """ - - @abstractmethod - def send(self, topic: str | bytes, data: dict): - """ - Send data via the implemented output stream. - """ - pass - - -class Subscriber(ABC): - """ - Subscriber interface - """ - - @abstractmethod - def receive(self) -> tuple[bytes, dict]: - """ - Blocking function to receive data from the implemented input stream. - - Data is returned as a tuple of (topic, data). - """ - pass diff --git a/heisskleber/network/serial/__init__.py b/heisskleber/network/serial/__init__.py deleted file mode 100644 index 12cd737..0000000 --- a/heisskleber/network/serial/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .config import SerialConf # noqa: F401 -from .publisher import SerialPublisher # noqa: F401 -from .subscriber import SerialSubscriber # noqa: F401 diff --git a/heisskleber/network/types.py b/heisskleber/network/types.py deleted file mode 100644 index d701017..0000000 --- a/heisskleber/network/types.py +++ /dev/null @@ -1 +0,0 @@ -from .pubsub.types import Publisher, Subscriber # noqa: F401 diff --git a/heisskleber/network/zmq/__init__.py b/heisskleber/network/zmq/__init__.py deleted file mode 100644 index e66b53b..0000000 --- a/heisskleber/network/zmq/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .config import ZmqConf # noqa: F401 -from .publisher import ZmqPublisher # noqa: F401 -from .subscriber import ZmqSubscriber # noqa: F401 diff --git a/heisskleber/serial/__init__.py b/heisskleber/serial/__init__.py new file mode 100644 index 0000000..7921657 --- /dev/null +++ b/heisskleber/serial/__init__.py @@ -0,0 +1,5 @@ +from .config import SerialConf +from .publisher import SerialPublisher +from .subscriber import SerialSubscriber + +__all__ = ["SerialConf", "SerialPublisher", "SerialSubscriber"] diff --git a/heisskleber/network/serial/config.py b/heisskleber/serial/config.py similarity index 100% rename from heisskleber/network/serial/config.py rename to heisskleber/serial/config.py diff --git a/heisskleber/network/serial/forwarder.py b/heisskleber/serial/forwarder.py similarity index 75% rename from heisskleber/network/serial/forwarder.py rename to heisskleber/serial/forwarder.py index 098b661..9b2e538 100644 --- a/heisskleber/network/serial/forwarder.py +++ b/heisskleber/serial/forwarder.py @@ -1,10 +1,10 @@ -from heisskleber.network.types import Subscriber +from heisskleber.core.types import Subscriber from .publisher import SerialPublisher class SerialForwarder: - def __init__(self, subscriber: Subscriber, publisher: SerialPublisher): + def __init__(self, subscriber: Subscriber, publisher: SerialPublisher) -> None: self.sub = subscriber self.pub = publisher @@ -12,19 +12,19 @@ class SerialForwarder: Wait for message and forward """ - def forward_message(self): + def forward_message(self) -> None: # collected = {} # for sub in self.sub: # topic, data = sub.receive() # collected.update(data) _, collected = self.sub.receive() - self.pub.send(collected) + self.pub.send("", collected) """ Enter loop and continuously forward messages """ - def sub_pub_loop(self): + def sub_pub_loop(self) -> None: while True: self.forward_message() diff --git a/heisskleber/network/serial/publisher.py b/heisskleber/serial/publisher.py similarity index 71% rename from heisskleber/network/serial/publisher.py rename to heisskleber/serial/publisher.py index 7f67138..646f829 100644 --- a/heisskleber/network/serial/publisher.py +++ b/heisskleber/serial/publisher.py @@ -1,11 +1,11 @@ from __future__ import annotations -from types import FunctionType +from typing import Callable, Optional import serial -from heisskleber.network.packer import get_packer -from heisskleber.network.pubsub.types import Publisher +from heisskleber.core.packer import get_packer +from heisskleber.core.types import Publisher from .config import SerialConf @@ -23,12 +23,16 @@ class SerialPublisher(Publisher): Function to translate from a dict to a serialized string. """ - def __init__(self, config: SerialConf, pack_func: FunctionType | None = None): + def __init__( + self, + config: SerialConf, + pack_func: Optional[Callable] = None, # noqa: UP007 + ): self.config = config - self.packer = pack_func if pack_func else get_packer("serial") + self.pack = pack_func if pack_func else get_packer("serial") self._connect() - def _connect(self): + def _connect(self) -> None: self.serial: serial.Serial = serial.Serial( port=self.config.port, baudrate=self.config.baudrate, @@ -38,25 +42,23 @@ class SerialPublisher(Publisher): ) print(f"Successfully connected to serial device at port {self.config.port}") - def send(self, message: object): + def send(self, topic, message: dict) -> None: """ Takes python dictionary, serializes it according to the packstyle and sends it to the broker. - Please note that this does not adhere to the interface, as there is no topic. - Parameters ---------- - message : object + message : dict object to be serialized and sent via the serial connection. Usually a dict. """ - payload = self.packer(message) + payload = self.pack(message) self.serial.write(payload.encode(self.config.encoding)) self.serial.flush() if self.config.verbose: - print(payload) + print(f"{topic}: {payload}") - def __del__(self): + def __del__(self) -> None: if not hasattr(self, "serial"): return if not self.serial.is_open: diff --git a/heisskleber/network/serial/subscriber.py b/heisskleber/serial/subscriber.py similarity index 78% rename from heisskleber/network/serial/subscriber.py rename to heisskleber/serial/subscriber.py index d6c203f..6c7821b 100644 --- a/heisskleber/network/serial/subscriber.py +++ b/heisskleber/serial/subscriber.py @@ -1,10 +1,11 @@ from __future__ import annotations +from collections.abc import Generator from typing import Callable, Optional import serial -from heisskleber.network.pubsub.types import Subscriber +from heisskleber.core.types import Subscriber from .config import SerialConf @@ -45,7 +46,7 @@ class SerialSubscriber(Subscriber): ) print(f"Successfully connected to serial device at port {self.config.port}") - def receive(self) -> dict: + def receive(self) -> tuple[str, dict]: """ Wait for data to arrive on the serial port and return it. @@ -62,29 +63,30 @@ class SerialSubscriber(Subscriber): # port is a placeholder for topic return self.config.port, payload - def read_serial_port(self) -> str: + def read_serial_port(self) -> Generator[str, None, None]: + """ + Generator function reading from the serial port. + + Returns + ------- + :return: Generator[str, None, None] + Generator yielding strings read from the serial port + """ buffer = "" while True: try: - buffer = self.serial.readline().decode() + buffer = self.serial.readline().decode(self.config.encoding, "ignore") yield buffer except UnicodeError as e: if self.config.verbose: - print(f"Could not decode: {message}") + print(f"Could not decode: {buffer!r}") print(e) continue - def __del__(self): + def __del__(self) -> None: if not hasattr(self, "serial"): return if not self.serial.is_open: return self.serial.flush() self.serial.close() - - -if __name__ == "__main__": - config = SerialConf() - serial_reader = SerialSubscriber(config) - for message in serial_reader.receive(): - print(message) diff --git a/heisskleber/network/udp/__init__.py b/heisskleber/udp/__init__.py similarity index 100% rename from heisskleber/network/udp/__init__.py rename to heisskleber/udp/__init__.py diff --git a/heisskleber/network/udp/config.py b/heisskleber/udp/config.py similarity index 100% rename from heisskleber/network/udp/config.py rename to heisskleber/udp/config.py diff --git a/heisskleber/network/udp/publisher.py b/heisskleber/udp/publisher.py similarity index 89% rename from heisskleber/network/udp/publisher.py rename to heisskleber/udp/publisher.py index 5499a69..0929d0c 100644 --- a/heisskleber/network/udp/publisher.py +++ b/heisskleber/udp/publisher.py @@ -1,8 +1,8 @@ import socket -from heisskleber.network.packer import get_packer -from heisskleber.network.pubsub.types import Publisher -from heisskleber.network.udp.config import UDPConf +from heisskleber.core.packer import get_packer +from heisskleber.core.types import Publisher +from heisskleber.udp.config import UDPConf class UDP_Publisher(Publisher): diff --git a/heisskleber/network/udp/subscriber.py b/heisskleber/udp/subscriber.py similarity index 83% rename from heisskleber/network/udp/subscriber.py rename to heisskleber/udp/subscriber.py index ffaab08..566110d 100644 --- a/heisskleber/network/udp/subscriber.py +++ b/heisskleber/udp/subscriber.py @@ -1,8 +1,8 @@ import socket -from heisskleber.network.packer import get_unpacker -from heisskleber.network.pubsub.types import Subscriber -from heisskleber.network.udp.config import UDPConf +from heisskleber.core.packer import get_unpacker +from heisskleber.core.types import Subscriber +from heisskleber.udp.config import UDPConf class UDP_Subscriber(Subscriber): diff --git a/heisskleber/zmq/__init__.py b/heisskleber/zmq/__init__.py new file mode 100644 index 0000000..bbdfd1e --- /dev/null +++ b/heisskleber/zmq/__init__.py @@ -0,0 +1,5 @@ +from .config import ZmqConf +from .publisher import ZmqPublisher +from .subscriber import ZmqSubscriber + +__all__ = ["ZmqConf", "ZmqPublisher", "ZmqSubscriber"] diff --git a/heisskleber/network/zmq/config.py b/heisskleber/zmq/config.py similarity index 100% rename from heisskleber/network/zmq/config.py rename to heisskleber/zmq/config.py diff --git a/heisskleber/network/zmq/publisher.py b/heisskleber/zmq/publisher.py similarity index 68% rename from heisskleber/network/zmq/publisher.py rename to heisskleber/zmq/publisher.py index 3d50841..6f29130 100644 --- a/heisskleber/network/zmq/publisher.py +++ b/heisskleber/zmq/publisher.py @@ -1,9 +1,10 @@ import sys +from typing import Any import zmq -from heisskleber.network.packer import get_packer -from heisskleber.network.pubsub.types import Publisher +from heisskleber.core.packer import get_packer +from heisskleber.core.types import Publisher from .config import ZmqConf @@ -25,9 +26,9 @@ class ZmqPublisher(Publisher): print(f"failed to bind to zeromq socket: {e}") sys.exit(-1) - def send(self, topic: str, data: dict) -> None: - data = self.pack(data) - self.socket.send_multipart([topic, data.encode("utf-8")]) + def send(self, topic: str, data: dict[str, Any]) -> None: + payload = self.pack(data) + self.socket.send_multipart([topic, payload.encode("utf-8")]) def __del__(self): self.socket.close() diff --git a/heisskleber/network/zmq/subscriber.py b/heisskleber/zmq/subscriber.py similarity index 93% rename from heisskleber/network/zmq/subscriber.py rename to heisskleber/zmq/subscriber.py index 245a9cb..124c2b3 100644 --- a/heisskleber/network/zmq/subscriber.py +++ b/heisskleber/zmq/subscriber.py @@ -4,8 +4,8 @@ import sys import zmq -from heisskleber.network.packer import get_unpacker -from heisskleber.network.pubsub.types import Subscriber +from heisskleber.core.packer import get_unpacker +from heisskleber.core.types import Subscriber from .config import ZmqConf diff --git a/pyproject.toml b/pyproject.toml index 7bb5834..5cde37d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,9 +85,10 @@ select = [ "TRY", # tryceratops ] ignore = [ - "E501", # LineTooLong - "E731", # DoNotAssignLambda - "A001", # + "E501", # LineTooLong + "E731", # DoNotAssignLambda + "A001", # + "PGH003", # Use specific rules when ignoring type issues ] [tool.ruff.per-file-ignores] diff --git a/tests/test_import.py b/tests/test_import.py index c440a99..438fdf2 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -1,31 +1,26 @@ def test_import_mqtt(): - from heisskleber.network.mqtt import ( - MqttConf, - MqttPublisher, - MqttSubscriber, - ) + import heisskleber + from heisskleber.mqtt import MqttConf, MqttPublisher, MqttSubscriber + + assert heisskleber.__all__ == [ + "get_publisher", + "get_subscriber", + "Publisher", + "Subscriber", + ] def test_import_zmq(): - from heisskleber.network.zmq import ( - ZmqConf, - ZmqPublisher, - ZmqSubscriber, - ) + from heisskleber.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber def test_import_serial(): - from heisskleber.network.serial import ( - SerialConf, - SerialPublisher, - SerialSubscriber, - ) + from heisskleber.serial import SerialConf, SerialPublisher, SerialSubscriber def test_import_utils(): - from heisskleber.network import get_publisher, get_subscriber - from heisskleber.network.types import Publisher, Subscriber + from heisskleber import Publisher, Subscriber, get_publisher, get_subscriber def test_import_config(): - pass + from heisskleber.config import BaseConf, load_config diff --git a/tests/test_mqtt.py b/tests/test_mqtt.py new file mode 100644 index 0000000..03ef3d7 --- /dev/null +++ b/tests/test_mqtt.py @@ -0,0 +1,162 @@ +import json +from queue import SimpleQueue +from unittest.mock import call, patch + +import pytest +from paho.mqtt.client import MQTTMessage + +from heisskleber.mqtt.config import MqttConf +from heisskleber.mqtt.mqtt_base import MqttBase +from heisskleber.mqtt.subscriber import MqttSubscriber + + +# Mock configuration for MQTT_Base +@pytest.fixture +def mock_mqtt_conf() -> MqttConf: + return MqttConf( + broker="localhost", + port=1883, + user="user", + password="passwd", # noqa: S106 + ssl=False, + verbose=False, + qos=1, + ) + + +# Mock the paho mqtt client +@pytest.fixture +def mock_mqtt_client(): + with patch("heisskleber.mqtt.mqtt_base.mqtt_client", autospec=True) as mock: + yield mock + + +@pytest.fixture +def mock_queue(): + with patch("heisskleber.mqtt.subscriber.SimpleQueue", spec=SimpleQueue) as mock: + yield mock + + +def test_mqtt_base_intialization(mock_mqtt_client, mock_mqtt_conf): + """Test that the intialization of the mqtt client is as expected.""" + base = MqttBase(config=mock_mqtt_conf) + + mock_mqtt_client.assert_called_once() + mock_mqtt_client.return_value.loop_start.assert_called_once() + mock_client_instance = mock_mqtt_client.return_value + mock_client_instance.username_pw_set.assert_called_with( + mock_mqtt_conf.user, mock_mqtt_conf.password + ) + mock_client_instance.connect.assert_called_with( + mock_mqtt_conf.broker, mock_mqtt_conf.port + ) + assert base.client.on_connect == base._on_connect + assert base.client.on_disconnect == base._on_disconnect + assert base.client.on_publish == base._on_publish + assert base.client.on_message == base._on_message + + +def test_mqtt_base_on_connect(mock_mqtt_client, mock_mqtt_conf, capsys): + base = MqttBase(config=mock_mqtt_conf) + base._on_connect(None, None, {}, 0) + captured = capsys.readouterr() + assert ( + f"MQTT node connected to {mock_mqtt_conf.broker}:{mock_mqtt_conf.port}" + in captured.out + ) + + +def test_mqtt_base_on_disconnect_with_error(mock_mqtt_client, mock_mqtt_conf, capsys): + """Assert that the mqtt client shuts down when disconnect callback is received.""" + base = MqttBase(config=mock_mqtt_conf) + with pytest.raises(SystemExit): + base._on_disconnect(None, None, 1) + captured = capsys.readouterr() + assert "Killing this service" in captured.out + print(captured.out) + + +def test_mqtt_subscribes_single_topic(mock_mqtt_client, mock_mqtt_conf): + """Test that the mqtt client subscribes to a single topic.""" + _ = MqttSubscriber(topics="singleTopic", config=mock_mqtt_conf) + + actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list + assert actual_calls == [call("singleTopic", mock_mqtt_conf.qos)] + + +def test_mqtt_subscribes_multiple_topics(mock_mqtt_client, mock_mqtt_conf): + """Test that the mqtt client subscribes to multiple topics passed as list. + + I would love to do this via parametrization, but the call argument is built differently for single size lists and longer lists. + """ + _ = MqttSubscriber(topics=["multiple1", "multiple2"], config=mock_mqtt_conf) + + actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list + assert actual_calls == [ + call([("multiple1", mock_mqtt_conf.qos), ("multiple2", mock_mqtt_conf.qos)]), + ] + + +def test_mqtt_subscribes_multiple_topics_tuple(mock_mqtt_client, mock_mqtt_conf): + """Test that the mqtt client subscribes to multiple topics passed as tuple.""" + _ = MqttSubscriber(topics=("multiple1", "multiple2"), config=mock_mqtt_conf) + + actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list + assert actual_calls == [ + call([("multiple1", mock_mqtt_conf.qos), ("multiple2", mock_mqtt_conf.qos)]), + ] + + +def create_fake_mqtt_message(topic: bytes, payload: bytes) -> MQTTMessage: + msg = MQTTMessage() + msg.topic = topic + msg.payload = payload + return msg + + +def test_receive_with_message(mock_mqtt_conf: MqttConf, mock_mqtt_client, mock_queue): + """Test the mqtt receive function with fake MQTT messages.""" + topic = b"test/topic" + payload = json.dumps({"key": "value"}).encode() + fake_message = create_fake_mqtt_message(topic, payload) + + mock_queue.return_value.get.side_effect = [fake_message] + subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf) + + received_topic, received_payload = subscriber.receive() + + assert received_topic == "test/topic" + assert received_payload == {"key": "value"} + + +def test_message_is_put_into_queue( + mock_mqtt_conf: MqttConf, mock_mqtt_client, mock_queue +): + """Test that values a put into a queue when on_message callback is called.""" + topic = b"test/topic" + payload = json.dumps({"key": "value"}).encode() + fake_message = create_fake_mqtt_message(topic, payload) + + mock_queue.return_value.get.side_effect = [fake_message] + subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf) + + subscriber._on_message(None, None, fake_message) + + mock_queue.return_value.put.assert_called_once_with(fake_message) + + +def test_message_is_put_into_queue_with_actual_queue(mock_mqtt_conf, mock_mqtt_client): + """Test that the buffering via queue works as expected.""" + topic = b"test/topic" + payload = json.dumps({"key": "value"}).encode() + fake_message = create_fake_mqtt_message(topic, payload) + + # mock_queue.return_value.get.side_effect = [fake_message] + subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf) + + subscriber._on_message(None, None, fake_message) + + topic, return_dict = subscriber.receive() + + assert topic == "test/topic" + assert return_dict == {"key": "value"} diff --git a/tests/test_packer.py b/tests/test_packer.py new file mode 100644 index 0000000..62d5dcc --- /dev/null +++ b/tests/test_packer.py @@ -0,0 +1,35 @@ +import json +import pickle +from typing import Any + +import pytest + +from heisskleber.core.packer import get_packer, get_unpacker, serialpacker + + +def test_get_packer() -> None: + assert get_packer("json") == json.dumps + assert get_packer("pickle") == pickle.dumps + assert get_packer("default") == json.dumps + assert get_packer("foobar") == json.dumps + assert get_packer("serial") == serialpacker + + +def test_get_unpacker() -> None: + assert get_unpacker("json") == json.loads + assert get_unpacker("pickle") == pickle.loads + assert get_unpacker("default") == json.loads + assert get_unpacker("foobar") == json.loads + + +@pytest.mark.parametrize( + "message,expected", + [ + ({"hi": 1, "da": 2, "nei": 3}, "1,2,3"), + ({"er": 1, "ma": "ga", "gerd": 3, "jo": 4}, "1,ga,3,4"), + ({"": 1, "ho": 0.0, "lee": 0.1, "shit": 1_000}, "1,0.0,0.1,1000"), + ({"be": 1e6, "li": 1_000}, "1000000.0,1000"), + ], +) +def test_serial_packer_functionality(message: dict[str, Any], expected: str) -> None: + assert serialpacker(message) == expected diff --git a/tests/test_serial.py b/tests/test_serial.py new file mode 100644 index 0000000..812e4c0 --- /dev/null +++ b/tests/test_serial.py @@ -0,0 +1,115 @@ +from unittest.mock import Mock, patch + +import pytest +import serial + +from heisskleber.core.packer import serialpacker +from heisskleber.serial.config import SerialConf +from heisskleber.serial.publisher import SerialPublisher +from heisskleber.serial.subscriber import SerialSubscriber + + +@pytest.fixture +def serial_conf(): + return SerialConf(port="/dev/test", baudrate=9600, bytesize=8, verbose=False) + + +@pytest.fixture +def mock_serial_device_subscriber(): + with patch("heisskleber.serial.subscriber.serial.Serial") as mock: + yield mock + + +@pytest.fixture +def mock_serial_device_publisher(): + with patch("heisskleber.serial.publisher.serial.Serial") as mock: + yield mock + + +def test_serial_subscriber_initialization(mock_serial_device_subscriber, serial_conf): + """Test that the SerialSubscriber class initializes correctly. + Mocks the serial.Serial class to avoid opening a serial port.""" + _ = SerialSubscriber(topics="", config=serial_conf) + mock_serial_device_subscriber.assert_called_with( + port=serial_conf.port, + baudrate=serial_conf.baudrate, + bytesize=serial_conf.bytesize, + parity=serial.PARITY_NONE, + stopbits=serial.STOPBITS_ONE, + ) + + +def test_serial_subscriber_receive(mock_serial_device_subscriber, serial_conf): + """Test that the SerialSubscriber class calls readline and unpack as expected.""" + subscriber = SerialSubscriber(topics="", config=serial_conf) + + # Set up the readline return value + mock_serial_instance = mock_serial_device_subscriber.return_value + mock_serial_instance.readline.return_value = b"test message\n" + + # Set up the unpack function to convert message to dict + unpack_func = Mock(return_value={"data": "test message"}) + subscriber.unpack = unpack_func + + # Call the receive method and assert it behaves as expected + _, payload = subscriber.receive() + + # Was readline called? + mock_serial_instance.readline.assert_called_once() + + # Was unpack called? + assert payload == {"data": "test message"} + unpack_func.assert_called_once_with("test message\n") + + +def test_serial_subscriber_converts_bytes_to_str(): + """Test that the SerialSubscriber class converts bytes to str as expected.""" + with patch("heisskleber.serial.subscriber.serial.Serial") as mock_serial: + subscriber = SerialSubscriber( + topics="", config=SerialConf(), unpack_func=lambda x: x + ) + + # Set the readline method to raise UnicodeError + mock_serial_instance = mock_serial.return_value + mock_serial_instance.readline.side_effect = [b"test message", b"test\x86more"] + + _, payload = subscriber.receive() + assert isinstance(payload, str) + assert payload == "test message" + + # Assert that none-unicode is skipped + _, payload = subscriber.receive() + assert payload == "testmore" + + +def test_serial_publisher_initialization(mock_serial_device_publisher, serial_conf): + """Test that the SerialPublisher class initializes correctly. + Mocks the serial.Serial class to avoid opening a serial port.""" + publisher = SerialPublisher(config=serial_conf) + mock_serial_device_publisher.assert_called_with( + port=serial_conf.port, + baudrate=serial_conf.baudrate, + bytesize=serial_conf.bytesize, + parity=serial.PARITY_NONE, + stopbits=serial.STOPBITS_ONE, + ) + assert publisher.serial + + +def test_serial_publisher_send(mock_serial_device_publisher, serial_conf): + """Test that the SerialPublisher class calls write and pack as expected.""" + publisher = SerialPublisher(config=serial_conf) + + # Set up the readline return value + mock_serial_instance = mock_serial_device_publisher.return_value + mock_serial_instance.readline.return_value = b"test message\n" + + # Set up the pack function to convert dict to comma separated string of values + publisher.pack = serialpacker + + # Call the receive method and assert it behaves as expected + publisher.send("test", {"data": "test message", "more_data": "more message"}) + + # Was write called with encoded payload? + mock_serial_instance.write.assert_called_once_with(b"test message,more message") + mock_serial_instance.flush.assert_called_once() diff --git a/tests/test_version.py b/tests/test_version.py index 411bd58..f08109c 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -4,4 +4,4 @@ import heisskleber def test_heisskleber_version() -> None: """Test that the glue version is correct.""" - assert heisskleber.__version__ == "0.1.0" + assert heisskleber.__version__ == "0.2.0"