mirror of
https://github.com/OMGeeky/flucto-heisskleber.git
synced 2026-02-23 15:38:33 +01:00
WIP: Test, refactor
This commit is contained in:
2
Makefile
2
Makefile
@@ -47,7 +47,7 @@ docs-test: ## Test if documentation can be built without warnings or errors
|
||||
|
||||
.PHONY: docs
|
||||
docs: ## Build and serve the documentation
|
||||
@poetry run mkdocs serve
|
||||
@poetry run python3 -m sphinx docs docs/_build -b html
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
|
||||
@@ -3,22 +3,21 @@
|
||||
## Network
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: heisskleber.network
|
||||
.. automodule:: heisskleber
|
||||
:members:
|
||||
.. automodule:: heisskleber.network.mqtt
|
||||
.. autoclass:: MqttPublisher
|
||||
.. autoclass:: MqttSubscriber
|
||||
.. automodule:: heisskleber.network.zmq
|
||||
.. automodule:: heisskleber.mqtt
|
||||
:members:
|
||||
.. automodule:: heisskleber.zmq
|
||||
:members:
|
||||
.. automodule:: heisskleber.serial
|
||||
:members:
|
||||
.. automodule:: heisskleber.config
|
||||
:members:
|
||||
.. autoclass:: ZmqPublisher
|
||||
.. autoclass:: ZmqSubscriber
|
||||
```
|
||||
|
||||
### Broker
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: heisskleber.broker
|
||||
:members:
|
||||
```
|
||||
|
||||
## Config
|
||||
@@ -34,7 +33,7 @@
|
||||
Configs are extended dataclasses, which inherit from the BaseConf class.
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.config.BaseConf
|
||||
.. autoclass:: heisskleber.network.mqtt.config.MqttConf
|
||||
.. autoclass:: heisskleber.network.zmq.config.ZmqConf
|
||||
.. autoclass:: heisskleber.network.serial.config.SerialConf
|
||||
.. autoclass:: heisskleber.mqtt.config.MqttConf
|
||||
.. autoclass:: heisskleber.zmq.config.ZmqConf
|
||||
.. autoclass:: heisskleber.serial.config.SerialConf
|
||||
```
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Heisskleber."""
|
||||
from .network import get_publisher, get_subscriber
|
||||
from .core.factories import get_publisher, get_subscriber
|
||||
from .core.types import Publisher, Subscriber
|
||||
|
||||
__all__ = ["get_publisher", "get_subscriber"]
|
||||
__version__ = "0.1.0"
|
||||
__all__ = ["get_publisher", "get_subscriber", "Publisher", "Subscriber"]
|
||||
__version__ = "0.2.0"
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .msb_broker import msb_broker as start_zmq_broker
|
||||
from .msb_broker import zmq_broker as start_zmq_broker
|
||||
|
||||
__all__ = ["start_zmq_broker"]
|
||||
|
||||
@@ -2,9 +2,10 @@ import signal
|
||||
import sys
|
||||
|
||||
import zmq
|
||||
from zmq import Socket
|
||||
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.network.zmq.config import ZmqConf as BrokerConf
|
||||
from heisskleber.zmq.config import ZmqConf as BrokerConf
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
@@ -16,29 +17,31 @@ class BrokerBindingError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def bind_socket(socket, address, socket_type, verbose=False):
|
||||
def bind_socket(socket: Socket, address: str, socket_type: str, verbose=False) -> None:
|
||||
"""Bind a ZMQ socket and handle errors."""
|
||||
if verbose:
|
||||
print(f"creating {socket_type} socket")
|
||||
try:
|
||||
socket.bind(address)
|
||||
except Exception as err:
|
||||
raise BrokerBindingError(f"failed to bind to {socket_type}: {err}") from err
|
||||
error_message = f"failed to bind to {socket_type}: {err}"
|
||||
raise BrokerBindingError(error_message) from err
|
||||
if verbose:
|
||||
print(f"successfully bound to {socket_type} socket: {address}")
|
||||
|
||||
|
||||
def create_proxy(xpub, xsub, verbose=False):
|
||||
def create_proxy(xpub: Socket, xsub: Socket, verbose=False) -> None:
|
||||
"""Create a ZMQ proxy to connect XPUB and XSUB sockets."""
|
||||
if verbose:
|
||||
print("creating proxy")
|
||||
try:
|
||||
zmq.proxy(xpub, xsub)
|
||||
except Exception as err:
|
||||
raise BrokerBindingError(f"failed to create proxy: {err}") from err
|
||||
error_message = f"failed to create proxy: {err}"
|
||||
raise BrokerBindingError(error_message) from err
|
||||
|
||||
|
||||
def msb_broker(config: BrokerConf) -> None:
|
||||
def zmq_broker(config: BrokerConf) -> None:
|
||||
"""Start a zmq broker.
|
||||
|
||||
Binds to a publisher and subscriber port, allowing many to many connections."""
|
||||
@@ -60,4 +63,4 @@ def main() -> None:
|
||||
"""Start a zmq broker, with a user specified configuration."""
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
broker_config = load_config(BrokerConf(), "zmq")
|
||||
msb_broker(broker_config)
|
||||
zmq_broker(broker_config)
|
||||
@@ -1,4 +1,4 @@
|
||||
from heisskleber.config.MSBConfig import BaseConf
|
||||
from heisskleber.config.parse import load_config
|
||||
from .config import BaseConf
|
||||
from .parse import load_config
|
||||
|
||||
__all__ = ["load_config", "BaseConf"]
|
||||
|
||||
@@ -2,18 +2,7 @@ import argparse
|
||||
|
||||
|
||||
class KeyValue(argparse.Action):
|
||||
# Constructor calling
|
||||
"""
|
||||
def __call__( self , parser, namespace, values : list, option_string = None):
|
||||
setattr(namespace, self.dest, dict())
|
||||
for value in values:
|
||||
# split it into key and value
|
||||
key, value = value.split('=')
|
||||
# assign into dictionary
|
||||
getattr(namespace, self.dest)[key] = value
|
||||
"""
|
||||
|
||||
def __call__(self, parser, args, values, option_string=None):
|
||||
def __call__(self, parser, args, values, option_string=None) -> None:
|
||||
try:
|
||||
params = dict(x.split("=") for x in values)
|
||||
except ValueError as ex:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import socket
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -12,18 +13,18 @@ class BaseConf:
|
||||
verbose: bool = False
|
||||
print_stdout: bool = False
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
if hasattr(self, key):
|
||||
self.__setattr__(key, value)
|
||||
else:
|
||||
warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2)
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
if hasattr(self, key):
|
||||
return getattr(self, key)
|
||||
else:
|
||||
warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2)
|
||||
|
||||
@property
|
||||
def serial_number(self):
|
||||
def serial_number(self) -> str:
|
||||
return socket.gethostname().upper()
|
||||
@@ -8,7 +8,9 @@ import yaml
|
||||
from heisskleber.config import BaseConf
|
||||
from heisskleber.config.cmdline import get_cmdline
|
||||
|
||||
ConfigType = TypeVar("ConfigType", bound=BaseConf)
|
||||
ConfigType = TypeVar(
|
||||
"ConfigType", bound=BaseConf
|
||||
) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar
|
||||
|
||||
|
||||
def get_msb_config_filepath(config_filename: str = "heisskleber.conf") -> str:
|
||||
@@ -17,10 +19,10 @@ def get_msb_config_filepath(config_filename: str = "heisskleber.conf") -> str:
|
||||
config_filepath = os.path.join(os.environ["MSB_CONFIG_DIR"], config_subpath)
|
||||
except Exception as e:
|
||||
print(f"could no get MSB_CONFIG from PATH: {e}")
|
||||
sys.exit() # TODO use 1 or the error str as exit value
|
||||
sys.exit(1)
|
||||
if not os.path.isfile(config_filepath):
|
||||
print(f"not a file: {config_filepath}!")
|
||||
sys.exit()
|
||||
sys.exit(1)
|
||||
return config_filepath
|
||||
|
||||
|
||||
@@ -33,16 +35,12 @@ def update_config(config: ConfigType, config_dict: dict) -> ConfigType:
|
||||
for config_key, config_value in config_dict.items():
|
||||
# get expected type of element from config_object:
|
||||
if not hasattr(config, config_key):
|
||||
warnings.warn(
|
||||
f"no such configuration parameter: {config_key}, skipping", stacklevel=2
|
||||
)
|
||||
error_msg = f"no such configuration parameter: {config_key}, skipping"
|
||||
warnings.warn(error_msg, stacklevel=2)
|
||||
continue
|
||||
cast_func = type(config[config_key])
|
||||
try:
|
||||
if config_key == "topic":
|
||||
config[config_key] = config_value.encode("utf-8")
|
||||
else:
|
||||
config[config_key] = cast_func(config_value)
|
||||
config[config_key] = cast_func(config_value)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"failed to cast {config_value} to {type(config[config_key])}: {e}. skipping"
|
||||
@@ -51,11 +49,6 @@ def update_config(config: ConfigType, config_dict: dict) -> ConfigType:
|
||||
return config
|
||||
|
||||
|
||||
ConfigType = TypeVar(
|
||||
"ConfigType", bound=BaseConf
|
||||
) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar
|
||||
|
||||
|
||||
def load_config(
|
||||
config: ConfigType, config_filename: str, read_commandline: bool = True
|
||||
) -> ConfigType:
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
import sys
|
||||
|
||||
import zmq
|
||||
|
||||
|
||||
def open_zmq_sub_socket(connect_to: str, topic=b""):
|
||||
ctx = zmq.Context()
|
||||
zmq_socket = ctx.socket(zmq.SUB)
|
||||
try:
|
||||
zmq_socket.connect(connect_to)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
zmq_socket.setsockopt(zmq.SUBSCRIBE, topic)
|
||||
return zmq_socket
|
||||
|
||||
|
||||
def open_zmq_pub_socket(connect_to: str):
|
||||
ctx = zmq.Context()
|
||||
zmq_socket = ctx.socket(zmq.PUB)
|
||||
try:
|
||||
zmq_socket.connect(connect_to)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
return zmq_socket
|
||||
|
||||
|
||||
def get_zmq_xpub_socketstring(msb_config: dict) -> str:
|
||||
zmq_config = msb_config["zeromq"]
|
||||
return (
|
||||
f"{zmq_config['protocol']}://{zmq_config['address']}:{zmq_config['xpub_port']}"
|
||||
)
|
||||
|
||||
|
||||
def get_zmq_xsub_socketstring(msb_config: dict) -> str:
|
||||
zmq_config = msb_config["zeromq"]
|
||||
return (
|
||||
f"{zmq_config['protocol']}://{zmq_config['address']}:{zmq_config['xsub_port']}"
|
||||
)
|
||||
@@ -1,24 +1,25 @@
|
||||
import os
|
||||
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.network.mqtt import MqttConf, MqttPublisher, MqttSubscriber
|
||||
from heisskleber.network.serial import SerialConf, SerialPublisher, SerialSubscriber
|
||||
from heisskleber.network.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber
|
||||
from heisskleber.config import BaseConf, load_config
|
||||
from heisskleber.core.types import Publisher, Subscriber
|
||||
from heisskleber.mqtt import MqttConf, MqttPublisher, MqttSubscriber
|
||||
from heisskleber.serial import SerialConf, SerialPublisher, SerialSubscriber
|
||||
from heisskleber.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber
|
||||
|
||||
_registered_publishers = {
|
||||
_registered_publishers: dict[str, tuple[type[Publisher], type[BaseConf]]] = {
|
||||
"zmq": (ZmqPublisher, ZmqConf),
|
||||
"mqtt": (MqttPublisher, MqttConf),
|
||||
"serial": (SerialPublisher, SerialConf),
|
||||
}
|
||||
|
||||
_registered_subscribers = {
|
||||
_registered_subscribers: dict[str, tuple[type[Subscriber], type[BaseConf]]] = {
|
||||
"zmq": (ZmqSubscriber, ZmqConf),
|
||||
"mqtt": (MqttSubscriber, MqttConf),
|
||||
"serial": (SerialSubscriber, SerialConf),
|
||||
}
|
||||
|
||||
|
||||
def get_publisher(name: str):
|
||||
def get_publisher(name: str) -> Publisher:
|
||||
if name not in _registered_publishers:
|
||||
error_message = f"{name} is not a registered Publisher."
|
||||
raise KeyError(error_message)
|
||||
@@ -35,7 +36,7 @@ def get_publisher(name: str):
|
||||
return pub_cls(config)
|
||||
|
||||
|
||||
def get_subscriber(name: str, topic):
|
||||
def get_subscriber(name: str, topic: str | list[str]) -> Subscriber:
|
||||
if name not in _registered_publishers:
|
||||
error_message = f"{name} is not a registered Subscriber."
|
||||
raise KeyError(error_message)
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Packer and unpacker for network data."""
|
||||
import json
|
||||
import pickle
|
||||
from typing import Callable
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
def get_packer(style) -> Callable[[dict], str]:
|
||||
def get_packer(style: str) -> Callable[[dict[str, Any]], Any]:
|
||||
"""Return a packer function for the given style.
|
||||
|
||||
Packer func serializes a given dict."""
|
||||
@@ -14,7 +14,7 @@ def get_packer(style) -> Callable[[dict], str]:
|
||||
return _packstyles["default"]
|
||||
|
||||
|
||||
def get_unpacker(style) -> Callable[[str], dict]:
|
||||
def get_unpacker(style: str) -> Callable[[str | bytes], dict[str, Any] | str | bytes]:
|
||||
"""Return an unpacker function for the given style.
|
||||
|
||||
Unpacker func deserializes a string."""
|
||||
@@ -24,21 +24,21 @@ def get_unpacker(style) -> Callable[[str], dict]:
|
||||
return _unpackstyles["default"]
|
||||
|
||||
|
||||
def serialpacker(data: dict) -> str:
|
||||
def serialpacker(data: dict[str, Any]) -> str:
|
||||
return ",".join([str(v) for v in data.values()])
|
||||
|
||||
|
||||
_packstyles = {
|
||||
_packstyles: dict[str, Callable[[dict[str, Any]], str | bytes]] = {
|
||||
"default": json.dumps,
|
||||
"json": json.dumps,
|
||||
"pickle": pickle.dumps,
|
||||
"serial": serialpacker,
|
||||
"raw": lambda x: x,
|
||||
"raw": lambda x: x, # type: ignore
|
||||
}
|
||||
|
||||
_unpackstyles = {
|
||||
_unpackstyles: dict[str, Callable[[str | bytes], dict[str, Any] | Any]] = {
|
||||
"default": json.loads,
|
||||
"json": json.loads,
|
||||
"pickle": pickle.loads,
|
||||
"pickle": pickle.loads, # type: ignore
|
||||
"raw": lambda x: x,
|
||||
}
|
||||
52
heisskleber/core/types.py
Normal file
52
heisskleber/core/types.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
PayloadType = Union[str, int, float]
|
||||
|
||||
|
||||
class Publisher(ABC):
|
||||
"""
|
||||
Publisher interface.
|
||||
"""
|
||||
|
||||
pack: Callable[[dict[str, Any] | Any], str]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: Any) -> None:
|
||||
"""
|
||||
Initialize the publisher with a configuration object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def send(self, topic: str, data: dict[str, Any]) -> None:
|
||||
"""
|
||||
Send data via the implemented output stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Subscriber(ABC):
|
||||
"""
|
||||
Subscriber interface
|
||||
"""
|
||||
|
||||
unpack: Callable[[bytes], dict[str, PayloadType] | Any]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, topic: str | list[str], config: Any) -> None:
|
||||
"""
|
||||
Initialize the subscriber with a topic and a configuration object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def receive(self) -> tuple[str, dict[str, PayloadType]]:
|
||||
"""
|
||||
Blocking function to receive data from the implemented input stream.
|
||||
|
||||
Data is returned as a tuple of (topic, data).
|
||||
"""
|
||||
pass
|
||||
@@ -1,6 +1,8 @@
|
||||
import pandas as pd
|
||||
from influxdb_client import InfluxDBClient
|
||||
|
||||
from heisskleber.core.types import Subscriber
|
||||
|
||||
from .config import InfluxDBConf
|
||||
|
||||
|
||||
@@ -28,7 +30,7 @@ def build_query(options: dict) -> str:
|
||||
return query
|
||||
|
||||
|
||||
class Influx_Subscriber:
|
||||
class Influx_Subscriber(Subscriber):
|
||||
def __init__(self, config: InfluxDBConf, query: str):
|
||||
self.config = config
|
||||
self.query = query
|
||||
@@ -45,7 +47,7 @@ class Influx_Subscriber:
|
||||
self._run_query()
|
||||
self.index = 0
|
||||
|
||||
def receive(self) -> dict:
|
||||
def receive(self) -> tuple[str, dict]:
|
||||
row = self.df.iloc[self.index].to_dict()
|
||||
self.index += 1
|
||||
return "influx", row
|
||||
5
heisskleber/mqtt/__init__.py
Normal file
5
heisskleber/mqtt/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .config import MqttConf
|
||||
from .publisher import MqttPublisher
|
||||
from .subscriber import MqttSubscriber
|
||||
|
||||
__all__ = ["MqttConf", "MqttPublisher", "MqttSubscriber"]
|
||||
@@ -16,7 +16,7 @@ class MqttConf(BaseConf):
|
||||
ssl: bool = False
|
||||
qos: int = 0
|
||||
retain: bool = False
|
||||
topics: list[bytes] = field(default_factory=list)
|
||||
topics: list[str] = field(default_factory=list)
|
||||
mapping: str = "/msb/"
|
||||
packstyle: str = "json"
|
||||
max_saved_messages: int = 100
|
||||
@@ -1,14 +1,14 @@
|
||||
from heisskleber import get_publisher, get_subscriber
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.network import get_publisher, get_subscriber
|
||||
|
||||
from .config import MqttConf
|
||||
|
||||
|
||||
def map_topic(zmq_topic, mapping):
|
||||
return mapping + zmq_topic.decode()
|
||||
def map_topic(zmq_topic: str, mapping: str) -> str:
|
||||
return mapping + zmq_topic
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
config: MqttConf = load_config(MqttConf(), "mqtt")
|
||||
sub = get_subscriber("zmq", config.topics)
|
||||
pub = get_publisher("mqtt")
|
||||
@@ -25,19 +25,19 @@ def _set_thread_died_excepthook(args, /):
|
||||
threading.excepthook = _set_thread_died_excepthook
|
||||
|
||||
|
||||
class MQTT_Base:
|
||||
class MqttBase:
|
||||
"""
|
||||
Wrapper around eclipse paho mqtt client.
|
||||
Handles connection and callbacks.
|
||||
Callbacks may be overwritten in subclasses.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf):
|
||||
def __init__(self, config: MqttConf) -> None:
|
||||
self.config = config
|
||||
self.connect()
|
||||
self.client.loop_start()
|
||||
|
||||
def connect(self):
|
||||
def connect(self) -> None:
|
||||
self.client = mqtt_client()
|
||||
self.client.username_pw_set(self.config.user, self.config.password)
|
||||
|
||||
@@ -55,13 +55,13 @@ class MQTT_Base:
|
||||
self.client.connect(self.config.broker, self.config.port)
|
||||
|
||||
@staticmethod
|
||||
def _raise_if_thread_died():
|
||||
def _raise_if_thread_died() -> None:
|
||||
global _thread_died
|
||||
if _thread_died.is_set():
|
||||
raise ThreadDiedError()
|
||||
|
||||
# MQTT callbacks
|
||||
def _on_connect(self, client, userdata, flags, return_code):
|
||||
def _on_connect(self, client, userdata, flags, return_code) -> None:
|
||||
if return_code == 0:
|
||||
print(f"MQTT node connected to {self.config.broker}:{self.config.port}")
|
||||
else:
|
||||
@@ -69,21 +69,21 @@ class MQTT_Base:
|
||||
if self.config.verbose:
|
||||
print(flags)
|
||||
|
||||
def _on_disconnect(self, client, userdata, return_code):
|
||||
def _on_disconnect(self, client, userdata, return_code) -> None:
|
||||
print(f"Disconnected from broker with return code {return_code}")
|
||||
if return_code != 0:
|
||||
print("Killing this service")
|
||||
sys.exit(-1)
|
||||
|
||||
def _on_publish(self, client, userdata, message_id):
|
||||
def _on_publish(self, client, userdata, message_id) -> None:
|
||||
if self.config.verbose:
|
||||
print(f"Published message with id {message_id}, qos={self.config.qos}")
|
||||
|
||||
def _on_message(self, client, userdata, message):
|
||||
def _on_message(self, client, userdata, message) -> None:
|
||||
if self.config.verbose:
|
||||
print(
|
||||
f"Received message: {message.payload!s}, topic: {message.topic}, qos: {message.qos}"
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
self.client.loop_stop()
|
||||
36
heisskleber/mqtt/publisher.py
Normal file
36
heisskleber/mqtt/publisher.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import Publisher
|
||||
|
||||
from .config import MqttConf
|
||||
from .mqtt_base import MqttBase
|
||||
|
||||
|
||||
class MqttPublisher(MqttBase, Publisher):
|
||||
"""
|
||||
MQTT publisher class.
|
||||
Can be used everywhere that a flucto style publishing connection is required.
|
||||
|
||||
Network message loop is handled in a separated thread.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf) -> None:
|
||||
super().__init__(config)
|
||||
self.pack = get_packer(config.packstyle)
|
||||
|
||||
def send(self, topic: str, data: dict[str, Any]) -> None:
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
|
||||
Publishing is asynchronous
|
||||
"""
|
||||
self._raise_if_thread_died()
|
||||
|
||||
payload = self.pack(data)
|
||||
self.client.publish(
|
||||
topic, payload, qos=self.config.qos, retain=self.config.retain
|
||||
)
|
||||
66
heisskleber/mqtt/subscriber.py
Normal file
66
heisskleber/mqtt/subscriber.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from queue import SimpleQueue
|
||||
from typing import Any
|
||||
|
||||
from paho.mqtt.client import MQTTMessage
|
||||
|
||||
from heisskleber.core.packer import get_unpacker
|
||||
from heisskleber.core.types import Subscriber
|
||||
|
||||
from .config import MqttConf
|
||||
from .mqtt_base import MqttBase
|
||||
|
||||
|
||||
class MqttSubscriber(MqttBase, Subscriber):
|
||||
"""
|
||||
MQTT subscriber, wraps around ecplipse's paho mqtt client.
|
||||
Network message loop is handled in a separated thread.
|
||||
|
||||
Incoming messages are saved as a stack when not processed via the receive() function.
|
||||
"""
|
||||
|
||||
def __init__(self, topics: str | list[str], config: MqttConf):
|
||||
super().__init__(config)
|
||||
self._message_queue: SimpleQueue[MQTTMessage] = SimpleQueue()
|
||||
self.subscribe(topics)
|
||||
self.client.on_message = self._on_message
|
||||
self.unpack = get_unpacker(config.packstyle)
|
||||
|
||||
def subscribe(self, topics: str | list[str] | tuple[str]) -> None:
|
||||
"""
|
||||
Subscribe to one or multiple topics
|
||||
"""
|
||||
if isinstance(topics, (list, tuple)):
|
||||
# if subscribing to multiple topics, use a list of tuples
|
||||
subscription_list = [(topic, self.config.qos) for topic in topics]
|
||||
self.client.subscribe(subscription_list)
|
||||
else:
|
||||
self.client.subscribe(topics, self.config.qos)
|
||||
if self.config.verbose:
|
||||
print(f"Subscribed to: {topics}")
|
||||
|
||||
def receive(self) -> tuple[str, dict[str, Any]]:
|
||||
"""
|
||||
Reads a message from mqtt and returns it
|
||||
|
||||
Messages are saved in a stack, if no message is available, this function blocks.
|
||||
|
||||
Returns:
|
||||
tuple(topic: bytes, message: dict): the message received
|
||||
"""
|
||||
self._raise_if_thread_died()
|
||||
mqtt_message = self._message_queue.get(
|
||||
block=True, timeout=self.config.timeout_s
|
||||
)
|
||||
|
||||
message_returned = self.unpack(mqtt_message.payload.decode())
|
||||
return (mqtt_message.topic, message_returned)
|
||||
|
||||
# callback to add incoming messages onto stack
|
||||
def _on_message(self, client, userdata, message) -> None:
|
||||
self._message_queue.put(message)
|
||||
|
||||
if self.config.verbose:
|
||||
print(f"Topic: {message.topic}")
|
||||
print(f"MQTT message: {message.payload.decode()}")
|
||||
@@ -1 +0,0 @@
|
||||
from .pubsub.factories import get_publisher, get_subscriber # noqa: F401
|
||||
@@ -1,3 +0,0 @@
|
||||
from .config import MqttConf # noqa: F401
|
||||
from .publisher import MqttPublisher # noqa: F401
|
||||
from .subscriber import MqttSubscriber # noqa: F401
|
||||
@@ -1,18 +0,0 @@
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.network import get_publisher, get_subscriber
|
||||
|
||||
from .config import MqttConf
|
||||
from .forwarder import ZMQ_to_MQTT_Forwarder
|
||||
|
||||
|
||||
def main():
|
||||
config = load_config(MqttConf(), "mqtt")
|
||||
for topic in config.topics:
|
||||
print(f"Subscribing to {topic}")
|
||||
|
||||
zmq_sub = get_subscriber("zmq", list(config.topics))
|
||||
mqtt_pub = get_publisher("mqtt")
|
||||
forwarder = ZMQ_to_MQTT_Forwarder(config, subscriber=zmq_sub, publisher=mqtt_pub)
|
||||
|
||||
# Wait for zmq messages, publish as mqtt message
|
||||
forwarder.zmq_to_mqtt_loop()
|
||||
@@ -1,63 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.network.packer import get_packer
|
||||
from heisskleber.network.pubsub.types import Publisher
|
||||
|
||||
from .config import MqttConf
|
||||
from .mqtt_base import MQTT_Base
|
||||
|
||||
|
||||
class MqttPublisher(MQTT_Base, Publisher):
|
||||
"""
|
||||
MQTT publisher class.
|
||||
Can be used everywhere that a flucto style publishing connection is required.
|
||||
|
||||
Network message loop is handled in a separated thread.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf):
|
||||
super().__init__(config)
|
||||
self.pack = get_packer(config.packstyle)
|
||||
|
||||
def send(self, topic: str | bytes, data: dict):
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
|
||||
Publishing is asynchronous
|
||||
"""
|
||||
self._raise_if_thread_died()
|
||||
if isinstance(topic, bytes):
|
||||
topic = topic.decode()
|
||||
|
||||
payload = self.pack(data)
|
||||
self.client.publish(
|
||||
topic, payload, qos=self.config.qos, retain=self.config.retain
|
||||
)
|
||||
|
||||
|
||||
def get_mqtt_publisher() -> MqttPublisher:
|
||||
"""
|
||||
Generate mqtt publisher with configuration from yaml file,
|
||||
falls back to default values if no config is found
|
||||
"""
|
||||
import os
|
||||
|
||||
if "MSB_CONFIG_DIR" in os.environ:
|
||||
print("loading mqtt config")
|
||||
config = load_config(MqttConf(), "mqtt", read_commandline=False)
|
||||
else:
|
||||
print("using default mqtt config")
|
||||
config = MqttConf()
|
||||
return MqttPublisher(config)
|
||||
|
||||
|
||||
def get_default_publisher() -> MqttPublisher:
|
||||
"""
|
||||
Generate mqtt publisher with configuration from yaml file,
|
||||
falls back to default values if no config is found
|
||||
|
||||
Deprecated, use get_mqtt_publisher() instead
|
||||
"""
|
||||
return get_mqtt_publisher()
|
||||
@@ -1,104 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from queue import SimpleQueue
|
||||
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.network.packer import get_unpacker
|
||||
from heisskleber.network.pubsub.types import Subscriber
|
||||
|
||||
from .config import MqttConf
|
||||
from .mqtt_base import MQTT_Base
|
||||
|
||||
|
||||
class MqttSubscriber(MQTT_Base, Subscriber):
|
||||
"""
|
||||
MQTT subscriber, wraps around ecplipse's paho mqtt client.
|
||||
Network message loop is handled in a separated thread.
|
||||
|
||||
Incoming messages are saved as a stack when not processed via the receive() function.
|
||||
"""
|
||||
|
||||
def __init__(self, topics, config: MqttConf):
|
||||
super().__init__(config)
|
||||
self._message_queue = SimpleQueue()
|
||||
self.subscribe(topics)
|
||||
self.client.on_message = self._on_message
|
||||
self.unpack = get_unpacker(config.packstyle)
|
||||
|
||||
def _subscribe_single_topic(self, topic: bytes | str):
|
||||
if isinstance(topic, bytes):
|
||||
topic = topic.decode()
|
||||
if self.config.verbose:
|
||||
print(f"Subscribed to: {topic}")
|
||||
self.client.subscribe(topic, self.config.qos)
|
||||
|
||||
def _subscribe_multiple_topics(self, topics: list[bytes] | list[str]):
|
||||
topics = [
|
||||
topic.decode() if isinstance(topic, bytes) else topic for topic in topics
|
||||
]
|
||||
subscription_list = [(topic, self.config.qos) for topic in topics]
|
||||
if self.config.verbose:
|
||||
print(f"Subscribed to: {topics}")
|
||||
self.client.subscribe(subscription_list)
|
||||
|
||||
def subscribe(self, topics):
|
||||
"""
|
||||
Subscribe to one or multiple topics
|
||||
"""
|
||||
# if subscribing to multiple topics, use a list of tuples
|
||||
if isinstance(topics, (list, tuple)):
|
||||
self._subscribe_multiple_topics(topics)
|
||||
else:
|
||||
self.client.subscribe(topics, self.config.qos)
|
||||
|
||||
def receive(self) -> tuple[bytes, dict]:
|
||||
"""
|
||||
Reads a message from mqtt and returns it
|
||||
|
||||
Messages are saved in a stack, if no message is available, this function blocks.
|
||||
|
||||
Returns:
|
||||
tuple(topic: bytes, message: dict): the message received
|
||||
"""
|
||||
self._raise_if_thread_died()
|
||||
mqtt_message = self._message_queue.get(
|
||||
block=True, timeout=self.config.timeout_s
|
||||
)
|
||||
|
||||
topic = mqtt_message.topic.encode("utf-8")
|
||||
message_returned = self.unpack(mqtt_message.payload.decode())
|
||||
return (topic, message_returned)
|
||||
|
||||
# callback to add incoming messages onto stack
|
||||
def _on_message(self, client, userdata, message):
|
||||
self._message_queue.put(message)
|
||||
|
||||
if self.config.verbose:
|
||||
print(f"Topic: {message.topic}")
|
||||
print(f"MQTT message: {message.payload.decode()}")
|
||||
|
||||
|
||||
def get_mqtt_subscriber(topic: bytes | str) -> MqttSubscriber:
|
||||
"""
|
||||
Generate mqtt subscriber with configuration from yaml file,
|
||||
falls back to default values if no config is found
|
||||
"""
|
||||
import os
|
||||
|
||||
if "MSB_CONFIG_DIR" in os.environ:
|
||||
print("loading mqtt config")
|
||||
config = load_config(MqttConf(), "mqtt", read_commandline=False)
|
||||
else:
|
||||
print("using default mqtt config")
|
||||
config = MqttConf()
|
||||
return MqttSubscriber(topic, config)
|
||||
|
||||
|
||||
def get_default_subscriber(topic: bytes | str) -> MqttSubscriber:
|
||||
"""
|
||||
Generate mqtt subscriber with configuration from yaml file,
|
||||
falls back to default values if no config is found
|
||||
|
||||
Deprecated, use get_mqtt_subscriber(topic) instead.
|
||||
"""
|
||||
return get_mqtt_subscriber(topic)
|
||||
@@ -1,31 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Publisher(ABC):
|
||||
"""
|
||||
Publisher interface.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def send(self, topic: str | bytes, data: dict):
|
||||
"""
|
||||
Send data via the implemented output stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Subscriber(ABC):
|
||||
"""
|
||||
Subscriber interface
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def receive(self) -> tuple[bytes, dict]:
|
||||
"""
|
||||
Blocking function to receive data from the implemented input stream.
|
||||
|
||||
Data is returned as a tuple of (topic, data).
|
||||
"""
|
||||
pass
|
||||
@@ -1,3 +0,0 @@
|
||||
from .config import SerialConf # noqa: F401
|
||||
from .publisher import SerialPublisher # noqa: F401
|
||||
from .subscriber import SerialSubscriber # noqa: F401
|
||||
@@ -1 +0,0 @@
|
||||
from .pubsub.types import Publisher, Subscriber # noqa: F401
|
||||
@@ -1,3 +0,0 @@
|
||||
from .config import ZmqConf # noqa: F401
|
||||
from .publisher import ZmqPublisher # noqa: F401
|
||||
from .subscriber import ZmqSubscriber # noqa: F401
|
||||
5
heisskleber/serial/__init__.py
Normal file
5
heisskleber/serial/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .config import SerialConf
|
||||
from .publisher import SerialPublisher
|
||||
from .subscriber import SerialSubscriber
|
||||
|
||||
__all__ = ["SerialConf", "SerialPublisher", "SerialSubscriber"]
|
||||
@@ -1,10 +1,10 @@
|
||||
from heisskleber.network.types import Subscriber
|
||||
from heisskleber.core.types import Subscriber
|
||||
|
||||
from .publisher import SerialPublisher
|
||||
|
||||
|
||||
class SerialForwarder:
|
||||
def __init__(self, subscriber: Subscriber, publisher: SerialPublisher):
|
||||
def __init__(self, subscriber: Subscriber, publisher: SerialPublisher) -> None:
|
||||
self.sub = subscriber
|
||||
self.pub = publisher
|
||||
|
||||
@@ -12,19 +12,19 @@ class SerialForwarder:
|
||||
Wait for message and forward
|
||||
"""
|
||||
|
||||
def forward_message(self):
|
||||
def forward_message(self) -> None:
|
||||
# collected = {}
|
||||
# for sub in self.sub:
|
||||
# topic, data = sub.receive()
|
||||
# collected.update(data)
|
||||
_, collected = self.sub.receive()
|
||||
|
||||
self.pub.send(collected)
|
||||
self.pub.send("", collected)
|
||||
|
||||
"""
|
||||
Enter loop and continuously forward messages
|
||||
"""
|
||||
|
||||
def sub_pub_loop(self):
|
||||
def sub_pub_loop(self) -> None:
|
||||
while True:
|
||||
self.forward_message()
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import FunctionType
|
||||
from typing import Callable, Optional
|
||||
|
||||
import serial
|
||||
|
||||
from heisskleber.network.packer import get_packer
|
||||
from heisskleber.network.pubsub.types import Publisher
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import Publisher
|
||||
|
||||
from .config import SerialConf
|
||||
|
||||
@@ -23,12 +23,16 @@ class SerialPublisher(Publisher):
|
||||
Function to translate from a dict to a serialized string.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SerialConf, pack_func: FunctionType | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: SerialConf,
|
||||
pack_func: Optional[Callable] = None, # noqa: UP007
|
||||
):
|
||||
self.config = config
|
||||
self.packer = pack_func if pack_func else get_packer("serial")
|
||||
self.pack = pack_func if pack_func else get_packer("serial")
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
def _connect(self) -> None:
|
||||
self.serial: serial.Serial = serial.Serial(
|
||||
port=self.config.port,
|
||||
baudrate=self.config.baudrate,
|
||||
@@ -38,25 +42,23 @@ class SerialPublisher(Publisher):
|
||||
)
|
||||
print(f"Successfully connected to serial device at port {self.config.port}")
|
||||
|
||||
def send(self, message: object):
|
||||
def send(self, topic, message: dict) -> None:
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
|
||||
Please note that this does not adhere to the interface, as there is no topic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
message : object
|
||||
message : dict
|
||||
object to be serialized and sent via the serial connection. Usually a dict.
|
||||
"""
|
||||
payload = self.packer(message)
|
||||
payload = self.pack(message)
|
||||
self.serial.write(payload.encode(self.config.encoding))
|
||||
self.serial.flush()
|
||||
if self.config.verbose:
|
||||
print(payload)
|
||||
print(f"{topic}: {payload}")
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
if not hasattr(self, "serial"):
|
||||
return
|
||||
if not self.serial.is_open:
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Callable, Optional
|
||||
|
||||
import serial
|
||||
|
||||
from heisskleber.network.pubsub.types import Subscriber
|
||||
from heisskleber.core.types import Subscriber
|
||||
|
||||
from .config import SerialConf
|
||||
|
||||
@@ -45,7 +46,7 @@ class SerialSubscriber(Subscriber):
|
||||
)
|
||||
print(f"Successfully connected to serial device at port {self.config.port}")
|
||||
|
||||
def receive(self) -> dict:
|
||||
def receive(self) -> tuple[str, dict]:
|
||||
"""
|
||||
Wait for data to arrive on the serial port and return it.
|
||||
|
||||
@@ -62,29 +63,30 @@ class SerialSubscriber(Subscriber):
|
||||
# port is a placeholder for topic
|
||||
return self.config.port, payload
|
||||
|
||||
def read_serial_port(self) -> str:
|
||||
def read_serial_port(self) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function reading from the serial port.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:return: Generator[str, None, None]
|
||||
Generator yielding strings read from the serial port
|
||||
"""
|
||||
buffer = ""
|
||||
while True:
|
||||
try:
|
||||
buffer = self.serial.readline().decode()
|
||||
buffer = self.serial.readline().decode(self.config.encoding, "ignore")
|
||||
yield buffer
|
||||
except UnicodeError as e:
|
||||
if self.config.verbose:
|
||||
print(f"Could not decode: {message}")
|
||||
print(f"Could not decode: {buffer!r}")
|
||||
print(e)
|
||||
continue
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
if not hasattr(self, "serial"):
|
||||
return
|
||||
if not self.serial.is_open:
|
||||
return
|
||||
self.serial.flush()
|
||||
self.serial.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = SerialConf()
|
||||
serial_reader = SerialSubscriber(config)
|
||||
for message in serial_reader.receive():
|
||||
print(message)
|
||||
@@ -1,8 +1,8 @@
|
||||
import socket
|
||||
|
||||
from heisskleber.network.packer import get_packer
|
||||
from heisskleber.network.pubsub.types import Publisher
|
||||
from heisskleber.network.udp.config import UDPConf
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import Publisher
|
||||
from heisskleber.udp.config import UDPConf
|
||||
|
||||
|
||||
class UDP_Publisher(Publisher):
|
||||
@@ -1,8 +1,8 @@
|
||||
import socket
|
||||
|
||||
from heisskleber.network.packer import get_unpacker
|
||||
from heisskleber.network.pubsub.types import Subscriber
|
||||
from heisskleber.network.udp.config import UDPConf
|
||||
from heisskleber.core.packer import get_unpacker
|
||||
from heisskleber.core.types import Subscriber
|
||||
from heisskleber.udp.config import UDPConf
|
||||
|
||||
|
||||
class UDP_Subscriber(Subscriber):
|
||||
5
heisskleber/zmq/__init__.py
Normal file
5
heisskleber/zmq/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .config import ZmqConf
|
||||
from .publisher import ZmqPublisher
|
||||
from .subscriber import ZmqSubscriber
|
||||
|
||||
__all__ = ["ZmqConf", "ZmqPublisher", "ZmqSubscriber"]
|
||||
@@ -1,9 +1,10 @@
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import zmq
|
||||
|
||||
from heisskleber.network.packer import get_packer
|
||||
from heisskleber.network.pubsub.types import Publisher
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import Publisher
|
||||
|
||||
from .config import ZmqConf
|
||||
|
||||
@@ -25,9 +26,9 @@ class ZmqPublisher(Publisher):
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
|
||||
def send(self, topic: str, data: dict) -> None:
|
||||
data = self.pack(data)
|
||||
self.socket.send_multipart([topic, data.encode("utf-8")])
|
||||
def send(self, topic: str, data: dict[str, Any]) -> None:
|
||||
payload = self.pack(data)
|
||||
self.socket.send_multipart([topic, payload.encode("utf-8")])
|
||||
|
||||
def __del__(self):
|
||||
self.socket.close()
|
||||
@@ -4,8 +4,8 @@ import sys
|
||||
|
||||
import zmq
|
||||
|
||||
from heisskleber.network.packer import get_unpacker
|
||||
from heisskleber.network.pubsub.types import Subscriber
|
||||
from heisskleber.core.packer import get_unpacker
|
||||
from heisskleber.core.types import Subscriber
|
||||
|
||||
from .config import ZmqConf
|
||||
|
||||
@@ -85,9 +85,10 @@ select = [
|
||||
"TRY", # tryceratops
|
||||
]
|
||||
ignore = [
|
||||
"E501", # LineTooLong
|
||||
"E731", # DoNotAssignLambda
|
||||
"A001", #
|
||||
"E501", # LineTooLong
|
||||
"E731", # DoNotAssignLambda
|
||||
"A001", #
|
||||
"PGH003", # Use specific rules when ignoring type issues
|
||||
]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
|
||||
@@ -1,31 +1,26 @@
|
||||
def test_import_mqtt():
|
||||
from heisskleber.network.mqtt import (
|
||||
MqttConf,
|
||||
MqttPublisher,
|
||||
MqttSubscriber,
|
||||
)
|
||||
import heisskleber
|
||||
from heisskleber.mqtt import MqttConf, MqttPublisher, MqttSubscriber
|
||||
|
||||
assert heisskleber.__all__ == [
|
||||
"get_publisher",
|
||||
"get_subscriber",
|
||||
"Publisher",
|
||||
"Subscriber",
|
||||
]
|
||||
|
||||
|
||||
def test_import_zmq():
|
||||
from heisskleber.network.zmq import (
|
||||
ZmqConf,
|
||||
ZmqPublisher,
|
||||
ZmqSubscriber,
|
||||
)
|
||||
from heisskleber.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber
|
||||
|
||||
|
||||
def test_import_serial():
|
||||
from heisskleber.network.serial import (
|
||||
SerialConf,
|
||||
SerialPublisher,
|
||||
SerialSubscriber,
|
||||
)
|
||||
from heisskleber.serial import SerialConf, SerialPublisher, SerialSubscriber
|
||||
|
||||
|
||||
def test_import_utils():
|
||||
from heisskleber.network import get_publisher, get_subscriber
|
||||
from heisskleber.network.types import Publisher, Subscriber
|
||||
from heisskleber import Publisher, Subscriber, get_publisher, get_subscriber
|
||||
|
||||
|
||||
def test_import_config():
|
||||
pass
|
||||
from heisskleber.config import BaseConf, load_config
|
||||
|
||||
162
tests/test_mqtt.py
Normal file
162
tests/test_mqtt.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import json
|
||||
from queue import SimpleQueue
|
||||
from unittest.mock import call, patch
|
||||
|
||||
import pytest
|
||||
from paho.mqtt.client import MQTTMessage
|
||||
|
||||
from heisskleber.mqtt.config import MqttConf
|
||||
from heisskleber.mqtt.mqtt_base import MqttBase
|
||||
from heisskleber.mqtt.subscriber import MqttSubscriber
|
||||
|
||||
|
||||
# Mock configuration for MQTT_Base
|
||||
@pytest.fixture
|
||||
def mock_mqtt_conf() -> MqttConf:
|
||||
return MqttConf(
|
||||
broker="localhost",
|
||||
port=1883,
|
||||
user="user",
|
||||
password="passwd", # noqa: S106
|
||||
ssl=False,
|
||||
verbose=False,
|
||||
qos=1,
|
||||
)
|
||||
|
||||
|
||||
# Mock the paho mqtt client
|
||||
@pytest.fixture
|
||||
def mock_mqtt_client():
|
||||
with patch("heisskleber.mqtt.mqtt_base.mqtt_client", autospec=True) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue():
|
||||
with patch("heisskleber.mqtt.subscriber.SimpleQueue", spec=SimpleQueue) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def test_mqtt_base_intialization(mock_mqtt_client, mock_mqtt_conf):
|
||||
"""Test that the intialization of the mqtt client is as expected."""
|
||||
base = MqttBase(config=mock_mqtt_conf)
|
||||
|
||||
mock_mqtt_client.assert_called_once()
|
||||
mock_mqtt_client.return_value.loop_start.assert_called_once()
|
||||
mock_client_instance = mock_mqtt_client.return_value
|
||||
mock_client_instance.username_pw_set.assert_called_with(
|
||||
mock_mqtt_conf.user, mock_mqtt_conf.password
|
||||
)
|
||||
mock_client_instance.connect.assert_called_with(
|
||||
mock_mqtt_conf.broker, mock_mqtt_conf.port
|
||||
)
|
||||
assert base.client.on_connect == base._on_connect
|
||||
assert base.client.on_disconnect == base._on_disconnect
|
||||
assert base.client.on_publish == base._on_publish
|
||||
assert base.client.on_message == base._on_message
|
||||
|
||||
|
||||
def test_mqtt_base_on_connect(mock_mqtt_client, mock_mqtt_conf, capsys):
|
||||
base = MqttBase(config=mock_mqtt_conf)
|
||||
base._on_connect(None, None, {}, 0)
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
f"MQTT node connected to {mock_mqtt_conf.broker}:{mock_mqtt_conf.port}"
|
||||
in captured.out
|
||||
)
|
||||
|
||||
|
||||
def test_mqtt_base_on_disconnect_with_error(mock_mqtt_client, mock_mqtt_conf, capsys):
|
||||
"""Assert that the mqtt client shuts down when disconnect callback is received."""
|
||||
base = MqttBase(config=mock_mqtt_conf)
|
||||
with pytest.raises(SystemExit):
|
||||
base._on_disconnect(None, None, 1)
|
||||
captured = capsys.readouterr()
|
||||
assert "Killing this service" in captured.out
|
||||
print(captured.out)
|
||||
|
||||
|
||||
def test_mqtt_subscribes_single_topic(mock_mqtt_client, mock_mqtt_conf):
|
||||
"""Test that the mqtt client subscribes to a single topic."""
|
||||
_ = MqttSubscriber(topics="singleTopic", config=mock_mqtt_conf)
|
||||
|
||||
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
|
||||
assert actual_calls == [call("singleTopic", mock_mqtt_conf.qos)]
|
||||
|
||||
|
||||
def test_mqtt_subscribes_multiple_topics(mock_mqtt_client, mock_mqtt_conf):
|
||||
"""Test that the mqtt client subscribes to multiple topics passed as list.
|
||||
|
||||
I would love to do this via parametrization, but the call argument is built differently for single size lists and longer lists.
|
||||
"""
|
||||
_ = MqttSubscriber(topics=["multiple1", "multiple2"], config=mock_mqtt_conf)
|
||||
|
||||
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
|
||||
assert actual_calls == [
|
||||
call([("multiple1", mock_mqtt_conf.qos), ("multiple2", mock_mqtt_conf.qos)]),
|
||||
]
|
||||
|
||||
|
||||
def test_mqtt_subscribes_multiple_topics_tuple(mock_mqtt_client, mock_mqtt_conf):
|
||||
"""Test that the mqtt client subscribes to multiple topics passed as tuple."""
|
||||
_ = MqttSubscriber(topics=("multiple1", "multiple2"), config=mock_mqtt_conf)
|
||||
|
||||
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
|
||||
assert actual_calls == [
|
||||
call([("multiple1", mock_mqtt_conf.qos), ("multiple2", mock_mqtt_conf.qos)]),
|
||||
]
|
||||
|
||||
|
||||
def create_fake_mqtt_message(topic: bytes, payload: bytes) -> MQTTMessage:
|
||||
msg = MQTTMessage()
|
||||
msg.topic = topic
|
||||
msg.payload = payload
|
||||
return msg
|
||||
|
||||
|
||||
def test_receive_with_message(mock_mqtt_conf: MqttConf, mock_mqtt_client, mock_queue):
|
||||
"""Test the mqtt receive function with fake MQTT messages."""
|
||||
topic = b"test/topic"
|
||||
payload = json.dumps({"key": "value"}).encode()
|
||||
fake_message = create_fake_mqtt_message(topic, payload)
|
||||
|
||||
mock_queue.return_value.get.side_effect = [fake_message]
|
||||
subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf)
|
||||
|
||||
received_topic, received_payload = subscriber.receive()
|
||||
|
||||
assert received_topic == "test/topic"
|
||||
assert received_payload == {"key": "value"}
|
||||
|
||||
|
||||
def test_message_is_put_into_queue(
|
||||
mock_mqtt_conf: MqttConf, mock_mqtt_client, mock_queue
|
||||
):
|
||||
"""Test that values a put into a queue when on_message callback is called."""
|
||||
topic = b"test/topic"
|
||||
payload = json.dumps({"key": "value"}).encode()
|
||||
fake_message = create_fake_mqtt_message(topic, payload)
|
||||
|
||||
mock_queue.return_value.get.side_effect = [fake_message]
|
||||
subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf)
|
||||
|
||||
subscriber._on_message(None, None, fake_message)
|
||||
|
||||
mock_queue.return_value.put.assert_called_once_with(fake_message)
|
||||
|
||||
|
||||
def test_message_is_put_into_queue_with_actual_queue(mock_mqtt_conf, mock_mqtt_client):
|
||||
"""Test that the buffering via queue works as expected."""
|
||||
topic = b"test/topic"
|
||||
payload = json.dumps({"key": "value"}).encode()
|
||||
fake_message = create_fake_mqtt_message(topic, payload)
|
||||
|
||||
# mock_queue.return_value.get.side_effect = [fake_message]
|
||||
subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf)
|
||||
|
||||
subscriber._on_message(None, None, fake_message)
|
||||
|
||||
topic, return_dict = subscriber.receive()
|
||||
|
||||
assert topic == "test/topic"
|
||||
assert return_dict == {"key": "value"}
|
||||
35
tests/test_packer.py
Normal file
35
tests/test_packer.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import json
|
||||
import pickle
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from heisskleber.core.packer import get_packer, get_unpacker, serialpacker
|
||||
|
||||
|
||||
def test_get_packer() -> None:
|
||||
assert get_packer("json") == json.dumps
|
||||
assert get_packer("pickle") == pickle.dumps
|
||||
assert get_packer("default") == json.dumps
|
||||
assert get_packer("foobar") == json.dumps
|
||||
assert get_packer("serial") == serialpacker
|
||||
|
||||
|
||||
def test_get_unpacker() -> None:
|
||||
assert get_unpacker("json") == json.loads
|
||||
assert get_unpacker("pickle") == pickle.loads
|
||||
assert get_unpacker("default") == json.loads
|
||||
assert get_unpacker("foobar") == json.loads
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message,expected",
|
||||
[
|
||||
({"hi": 1, "da": 2, "nei": 3}, "1,2,3"),
|
||||
({"er": 1, "ma": "ga", "gerd": 3, "jo": 4}, "1,ga,3,4"),
|
||||
({"": 1, "ho": 0.0, "lee": 0.1, "shit": 1_000}, "1,0.0,0.1,1000"),
|
||||
({"be": 1e6, "li": 1_000}, "1000000.0,1000"),
|
||||
],
|
||||
)
|
||||
def test_serial_packer_functionality(message: dict[str, Any], expected: str) -> None:
|
||||
assert serialpacker(message) == expected
|
||||
115
tests/test_serial.py
Normal file
115
tests/test_serial.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import serial
|
||||
|
||||
from heisskleber.core.packer import serialpacker
|
||||
from heisskleber.serial.config import SerialConf
|
||||
from heisskleber.serial.publisher import SerialPublisher
|
||||
from heisskleber.serial.subscriber import SerialSubscriber
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def serial_conf():
|
||||
return SerialConf(port="/dev/test", baudrate=9600, bytesize=8, verbose=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_serial_device_subscriber():
|
||||
with patch("heisskleber.serial.subscriber.serial.Serial") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_serial_device_publisher():
|
||||
with patch("heisskleber.serial.publisher.serial.Serial") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def test_serial_subscriber_initialization(mock_serial_device_subscriber, serial_conf):
|
||||
"""Test that the SerialSubscriber class initializes correctly.
|
||||
Mocks the serial.Serial class to avoid opening a serial port."""
|
||||
_ = SerialSubscriber(topics="", config=serial_conf)
|
||||
mock_serial_device_subscriber.assert_called_with(
|
||||
port=serial_conf.port,
|
||||
baudrate=serial_conf.baudrate,
|
||||
bytesize=serial_conf.bytesize,
|
||||
parity=serial.PARITY_NONE,
|
||||
stopbits=serial.STOPBITS_ONE,
|
||||
)
|
||||
|
||||
|
||||
def test_serial_subscriber_receive(mock_serial_device_subscriber, serial_conf):
|
||||
"""Test that the SerialSubscriber class calls readline and unpack as expected."""
|
||||
subscriber = SerialSubscriber(topics="", config=serial_conf)
|
||||
|
||||
# Set up the readline return value
|
||||
mock_serial_instance = mock_serial_device_subscriber.return_value
|
||||
mock_serial_instance.readline.return_value = b"test message\n"
|
||||
|
||||
# Set up the unpack function to convert message to dict
|
||||
unpack_func = Mock(return_value={"data": "test message"})
|
||||
subscriber.unpack = unpack_func
|
||||
|
||||
# Call the receive method and assert it behaves as expected
|
||||
_, payload = subscriber.receive()
|
||||
|
||||
# Was readline called?
|
||||
mock_serial_instance.readline.assert_called_once()
|
||||
|
||||
# Was unpack called?
|
||||
assert payload == {"data": "test message"}
|
||||
unpack_func.assert_called_once_with("test message\n")
|
||||
|
||||
|
||||
def test_serial_subscriber_converts_bytes_to_str():
|
||||
"""Test that the SerialSubscriber class converts bytes to str as expected."""
|
||||
with patch("heisskleber.serial.subscriber.serial.Serial") as mock_serial:
|
||||
subscriber = SerialSubscriber(
|
||||
topics="", config=SerialConf(), unpack_func=lambda x: x
|
||||
)
|
||||
|
||||
# Set the readline method to raise UnicodeError
|
||||
mock_serial_instance = mock_serial.return_value
|
||||
mock_serial_instance.readline.side_effect = [b"test message", b"test\x86more"]
|
||||
|
||||
_, payload = subscriber.receive()
|
||||
assert isinstance(payload, str)
|
||||
assert payload == "test message"
|
||||
|
||||
# Assert that none-unicode is skipped
|
||||
_, payload = subscriber.receive()
|
||||
assert payload == "testmore"
|
||||
|
||||
|
||||
def test_serial_publisher_initialization(mock_serial_device_publisher, serial_conf):
|
||||
"""Test that the SerialPublisher class initializes correctly.
|
||||
Mocks the serial.Serial class to avoid opening a serial port."""
|
||||
publisher = SerialPublisher(config=serial_conf)
|
||||
mock_serial_device_publisher.assert_called_with(
|
||||
port=serial_conf.port,
|
||||
baudrate=serial_conf.baudrate,
|
||||
bytesize=serial_conf.bytesize,
|
||||
parity=serial.PARITY_NONE,
|
||||
stopbits=serial.STOPBITS_ONE,
|
||||
)
|
||||
assert publisher.serial
|
||||
|
||||
|
||||
def test_serial_publisher_send(mock_serial_device_publisher, serial_conf):
|
||||
"""Test that the SerialPublisher class calls write and pack as expected."""
|
||||
publisher = SerialPublisher(config=serial_conf)
|
||||
|
||||
# Set up the readline return value
|
||||
mock_serial_instance = mock_serial_device_publisher.return_value
|
||||
mock_serial_instance.readline.return_value = b"test message\n"
|
||||
|
||||
# Set up the pack function to convert dict to comma separated string of values
|
||||
publisher.pack = serialpacker
|
||||
|
||||
# Call the receive method and assert it behaves as expected
|
||||
publisher.send("test", {"data": "test message", "more_data": "more message"})
|
||||
|
||||
# Was write called with encoded payload?
|
||||
mock_serial_instance.write.assert_called_once_with(b"test message,more message")
|
||||
mock_serial_instance.flush.assert_called_once()
|
||||
@@ -4,4 +4,4 @@ import heisskleber
|
||||
|
||||
def test_heisskleber_version() -> None:
|
||||
"""Test that the glue version is correct."""
|
||||
assert heisskleber.__version__ == "0.1.0"
|
||||
assert heisskleber.__version__ == "0.2.0"
|
||||
|
||||
Reference in New Issue
Block a user