WIP: Test, refactor

This commit is contained in:
Felix Weiler
2023-11-06 22:52:48 +01:00
parent 16837b4324
commit 8cfc66ac45
53 changed files with 625 additions and 421 deletions

View File

@@ -47,7 +47,7 @@ docs-test: ## Test if documentation can be built without warnings or errors
.PHONY: docs
docs: ## Build and serve the documentation
@poetry run mkdocs serve
@poetry run python3 -m sphinx docs docs/_build -b html
.PHONY: help
help:

View File

@@ -3,22 +3,21 @@
## Network
```{eval-rst}
.. automodule:: heisskleber.network
.. automodule:: heisskleber
:members:
.. automodule:: heisskleber.network.mqtt
.. autoclass:: MqttPublisher
.. autoclass:: MqttSubscriber
.. automodule:: heisskleber.network.zmq
.. automodule:: heisskleber.mqtt
:members:
.. automodule:: heisskleber.zmq
:members:
.. automodule:: heisskleber.serial
:members:
.. automodule:: heisskleber.config
:members:
.. autoclass:: ZmqPublisher
.. autoclass:: ZmqSubscriber
```
### Broker
```{eval-rst}
.. automodule:: heisskleber.broker
:members:
```
## Config
@@ -34,7 +33,7 @@
Configs are extended dataclasses, which inherit from the BaseConf class.
```{eval-rst}
.. autoclass:: heisskleber.config.BaseConf
.. autoclass:: heisskleber.network.mqtt.config.MqttConf
.. autoclass:: heisskleber.network.zmq.config.ZmqConf
.. autoclass:: heisskleber.network.serial.config.SerialConf
.. autoclass:: heisskleber.mqtt.config.MqttConf
.. autoclass:: heisskleber.zmq.config.ZmqConf
.. autoclass:: heisskleber.serial.config.SerialConf
```

View File

@@ -1,5 +1,6 @@
"""Heisskleber."""
from .network import get_publisher, get_subscriber
from .core.factories import get_publisher, get_subscriber
from .core.types import Publisher, Subscriber
__all__ = ["get_publisher", "get_subscriber"]
__version__ = "0.1.0"
__all__ = ["get_publisher", "get_subscriber", "Publisher", "Subscriber"]
__version__ = "0.2.0"

View File

@@ -1,3 +1,3 @@
from .msb_broker import msb_broker as start_zmq_broker
from .msb_broker import zmq_broker as start_zmq_broker
__all__ = ["start_zmq_broker"]

View File

@@ -2,9 +2,10 @@ import signal
import sys
import zmq
from zmq import Socket
from heisskleber.config import load_config
from heisskleber.network.zmq.config import ZmqConf as BrokerConf
from heisskleber.zmq.config import ZmqConf as BrokerConf
def signal_handler(sig, frame):
@@ -16,29 +17,31 @@ class BrokerBindingError(Exception):
pass
def bind_socket(socket, address, socket_type, verbose=False):
def bind_socket(socket: Socket, address: str, socket_type: str, verbose=False) -> None:
"""Bind a ZMQ socket and handle errors."""
if verbose:
print(f"creating {socket_type} socket")
try:
socket.bind(address)
except Exception as err:
raise BrokerBindingError(f"failed to bind to {socket_type}: {err}") from err
error_message = f"failed to bind to {socket_type}: {err}"
raise BrokerBindingError(error_message) from err
if verbose:
print(f"successfully bound to {socket_type} socket: {address}")
def create_proxy(xpub, xsub, verbose=False):
def create_proxy(xpub: Socket, xsub: Socket, verbose=False) -> None:
"""Create a ZMQ proxy to connect XPUB and XSUB sockets."""
if verbose:
print("creating proxy")
try:
zmq.proxy(xpub, xsub)
except Exception as err:
raise BrokerBindingError(f"failed to create proxy: {err}") from err
error_message = f"failed to create proxy: {err}"
raise BrokerBindingError(error_message) from err
def msb_broker(config: BrokerConf) -> None:
def zmq_broker(config: BrokerConf) -> None:
"""Start a zmq broker.
Binds to a publisher and subscriber port, allowing many to many connections."""
@@ -60,4 +63,4 @@ def main() -> None:
"""Start a zmq broker, with a user specified configuration."""
signal.signal(signal.SIGINT, signal_handler)
broker_config = load_config(BrokerConf(), "zmq")
msb_broker(broker_config)
zmq_broker(broker_config)

View File

@@ -1,4 +1,4 @@
from heisskleber.config.MSBConfig import BaseConf
from heisskleber.config.parse import load_config
from .config import BaseConf
from .parse import load_config
__all__ = ["load_config", "BaseConf"]

View File

@@ -2,18 +2,7 @@ import argparse
class KeyValue(argparse.Action):
# Constructor calling
"""
def __call__( self , parser, namespace, values : list, option_string = None):
setattr(namespace, self.dest, dict())
for value in values:
# split it into key and value
key, value = value.split('=')
# assign into dictionary
getattr(namespace, self.dest)[key] = value
"""
def __call__(self, parser, args, values, option_string=None):
def __call__(self, parser, args, values, option_string=None) -> None:
try:
params = dict(x.split("=") for x in values)
except ValueError as ex:

View File

@@ -1,6 +1,7 @@
import socket
import warnings
from dataclasses import dataclass
from typing import Any
@dataclass
@@ -12,18 +13,18 @@ class BaseConf:
verbose: bool = False
print_stdout: bool = False
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
if hasattr(self, key):
self.__setattr__(key, value)
else:
warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2)
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
if hasattr(self, key):
return getattr(self, key)
else:
warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2)
@property
def serial_number(self):
def serial_number(self) -> str:
return socket.gethostname().upper()

View File

@@ -8,7 +8,9 @@ import yaml
from heisskleber.config import BaseConf
from heisskleber.config.cmdline import get_cmdline
ConfigType = TypeVar("ConfigType", bound=BaseConf)
ConfigType = TypeVar(
"ConfigType", bound=BaseConf
) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar
def get_msb_config_filepath(config_filename: str = "heisskleber.conf") -> str:
@@ -17,10 +19,10 @@ def get_msb_config_filepath(config_filename: str = "heisskleber.conf") -> str:
config_filepath = os.path.join(os.environ["MSB_CONFIG_DIR"], config_subpath)
except Exception as e:
print(f"could no get MSB_CONFIG from PATH: {e}")
sys.exit() # TODO use 1 or the error str as exit value
sys.exit(1)
if not os.path.isfile(config_filepath):
print(f"not a file: {config_filepath}!")
sys.exit()
sys.exit(1)
return config_filepath
@@ -33,16 +35,12 @@ def update_config(config: ConfigType, config_dict: dict) -> ConfigType:
for config_key, config_value in config_dict.items():
# get expected type of element from config_object:
if not hasattr(config, config_key):
warnings.warn(
f"no such configuration parameter: {config_key}, skipping", stacklevel=2
)
error_msg = f"no such configuration parameter: {config_key}, skipping"
warnings.warn(error_msg, stacklevel=2)
continue
cast_func = type(config[config_key])
try:
if config_key == "topic":
config[config_key] = config_value.encode("utf-8")
else:
config[config_key] = cast_func(config_value)
config[config_key] = cast_func(config_value)
except Exception as e:
print(
f"failed to cast {config_value} to {type(config[config_key])}: {e}. skipping"
@@ -51,11 +49,6 @@ def update_config(config: ConfigType, config_dict: dict) -> ConfigType:
return config
ConfigType = TypeVar(
"ConfigType", bound=BaseConf
) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar
def load_config(
config: ConfigType, config_filename: str, read_commandline: bool = True
) -> ConfigType:

View File

@@ -1,40 +0,0 @@
import sys
import zmq
def open_zmq_sub_socket(connect_to: str, topic=b""):
ctx = zmq.Context()
zmq_socket = ctx.socket(zmq.SUB)
try:
zmq_socket.connect(connect_to)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
zmq_socket.setsockopt(zmq.SUBSCRIBE, topic)
return zmq_socket
def open_zmq_pub_socket(connect_to: str):
ctx = zmq.Context()
zmq_socket = ctx.socket(zmq.PUB)
try:
zmq_socket.connect(connect_to)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
return zmq_socket
def get_zmq_xpub_socketstring(msb_config: dict) -> str:
zmq_config = msb_config["zeromq"]
return (
f"{zmq_config['protocol']}://{zmq_config['address']}:{zmq_config['xpub_port']}"
)
def get_zmq_xsub_socketstring(msb_config: dict) -> str:
zmq_config = msb_config["zeromq"]
return (
f"{zmq_config['protocol']}://{zmq_config['address']}:{zmq_config['xsub_port']}"
)

View File

@@ -1,24 +1,25 @@
import os
from heisskleber.config import load_config
from heisskleber.network.mqtt import MqttConf, MqttPublisher, MqttSubscriber
from heisskleber.network.serial import SerialConf, SerialPublisher, SerialSubscriber
from heisskleber.network.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber
from heisskleber.config import BaseConf, load_config
from heisskleber.core.types import Publisher, Subscriber
from heisskleber.mqtt import MqttConf, MqttPublisher, MqttSubscriber
from heisskleber.serial import SerialConf, SerialPublisher, SerialSubscriber
from heisskleber.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber
_registered_publishers = {
_registered_publishers: dict[str, tuple[type[Publisher], type[BaseConf]]] = {
"zmq": (ZmqPublisher, ZmqConf),
"mqtt": (MqttPublisher, MqttConf),
"serial": (SerialPublisher, SerialConf),
}
_registered_subscribers = {
_registered_subscribers: dict[str, tuple[type[Subscriber], type[BaseConf]]] = {
"zmq": (ZmqSubscriber, ZmqConf),
"mqtt": (MqttSubscriber, MqttConf),
"serial": (SerialSubscriber, SerialConf),
}
def get_publisher(name: str):
def get_publisher(name: str) -> Publisher:
if name not in _registered_publishers:
error_message = f"{name} is not a registered Publisher."
raise KeyError(error_message)
@@ -35,7 +36,7 @@ def get_publisher(name: str):
return pub_cls(config)
def get_subscriber(name: str, topic):
def get_subscriber(name: str, topic: str | list[str]) -> Subscriber:
if name not in _registered_publishers:
error_message = f"{name} is not a registered Subscriber."
raise KeyError(error_message)

View File

@@ -1,10 +1,10 @@
"""Packer and unpacker for network data."""
import json
import pickle
from typing import Callable
from typing import Any, Callable
def get_packer(style) -> Callable[[dict], str]:
def get_packer(style: str) -> Callable[[dict[str, Any]], Any]:
"""Return a packer function for the given style.
Packer func serializes a given dict."""
@@ -14,7 +14,7 @@ def get_packer(style) -> Callable[[dict], str]:
return _packstyles["default"]
def get_unpacker(style) -> Callable[[str], dict]:
def get_unpacker(style: str) -> Callable[[str | bytes], dict[str, Any] | str | bytes]:
"""Return an unpacker function for the given style.
Unpacker func deserializes a string."""
@@ -24,21 +24,21 @@ def get_unpacker(style) -> Callable[[str], dict]:
return _unpackstyles["default"]
def serialpacker(data: dict) -> str:
def serialpacker(data: dict[str, Any]) -> str:
return ",".join([str(v) for v in data.values()])
_packstyles = {
_packstyles: dict[str, Callable[[dict[str, Any]], str | bytes]] = {
"default": json.dumps,
"json": json.dumps,
"pickle": pickle.dumps,
"serial": serialpacker,
"raw": lambda x: x,
"raw": lambda x: x, # type: ignore
}
_unpackstyles = {
_unpackstyles: dict[str, Callable[[str | bytes], dict[str, Any] | Any]] = {
"default": json.loads,
"json": json.loads,
"pickle": pickle.loads,
"pickle": pickle.loads, # type: ignore
"raw": lambda x: x,
}

52
heisskleber/core/types.py Normal file
View File

@@ -0,0 +1,52 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Callable, Union
PayloadType = Union[str, int, float]
class Publisher(ABC):
"""
Publisher interface.
"""
pack: Callable[[dict[str, Any] | Any], str]
@abstractmethod
def __init__(self, config: Any) -> None:
"""
Initialize the publisher with a configuration object.
"""
pass
@abstractmethod
def send(self, topic: str, data: dict[str, Any]) -> None:
"""
Send data via the implemented output stream.
"""
pass
class Subscriber(ABC):
"""
Subscriber interface
"""
unpack: Callable[[bytes], dict[str, PayloadType] | Any]
@abstractmethod
def __init__(self, topic: str | list[str], config: Any) -> None:
"""
Initialize the subscriber with a topic and a configuration object.
"""
pass
@abstractmethod
def receive(self) -> tuple[str, dict[str, PayloadType]]:
"""
Blocking function to receive data from the implemented input stream.
Data is returned as a tuple of (topic, data).
"""
pass

View File

@@ -1,6 +1,8 @@
import pandas as pd
from influxdb_client import InfluxDBClient
from heisskleber.core.types import Subscriber
from .config import InfluxDBConf
@@ -28,7 +30,7 @@ def build_query(options: dict) -> str:
return query
class Influx_Subscriber:
class Influx_Subscriber(Subscriber):
def __init__(self, config: InfluxDBConf, query: str):
self.config = config
self.query = query
@@ -45,7 +47,7 @@ class Influx_Subscriber:
self._run_query()
self.index = 0
def receive(self) -> dict:
def receive(self) -> tuple[str, dict]:
row = self.df.iloc[self.index].to_dict()
self.index += 1
return "influx", row

View File

@@ -0,0 +1,5 @@
from .config import MqttConf
from .publisher import MqttPublisher
from .subscriber import MqttSubscriber
__all__ = ["MqttConf", "MqttPublisher", "MqttSubscriber"]

View File

@@ -16,7 +16,7 @@ class MqttConf(BaseConf):
ssl: bool = False
qos: int = 0
retain: bool = False
topics: list[bytes] = field(default_factory=list)
topics: list[str] = field(default_factory=list)
mapping: str = "/msb/"
packstyle: str = "json"
max_saved_messages: int = 100

View File

@@ -1,14 +1,14 @@
from heisskleber import get_publisher, get_subscriber
from heisskleber.config import load_config
from heisskleber.network import get_publisher, get_subscriber
from .config import MqttConf
def map_topic(zmq_topic, mapping):
return mapping + zmq_topic.decode()
def map_topic(zmq_topic: str, mapping: str) -> str:
return mapping + zmq_topic
def main():
def main() -> None:
config: MqttConf = load_config(MqttConf(), "mqtt")
sub = get_subscriber("zmq", config.topics)
pub = get_publisher("mqtt")

View File

@@ -25,19 +25,19 @@ def _set_thread_died_excepthook(args, /):
threading.excepthook = _set_thread_died_excepthook
class MQTT_Base:
class MqttBase:
"""
Wrapper around eclipse paho mqtt client.
Handles connection and callbacks.
Callbacks may be overwritten in subclasses.
"""
def __init__(self, config: MqttConf):
def __init__(self, config: MqttConf) -> None:
self.config = config
self.connect()
self.client.loop_start()
def connect(self):
def connect(self) -> None:
self.client = mqtt_client()
self.client.username_pw_set(self.config.user, self.config.password)
@@ -55,13 +55,13 @@ class MQTT_Base:
self.client.connect(self.config.broker, self.config.port)
@staticmethod
def _raise_if_thread_died():
def _raise_if_thread_died() -> None:
global _thread_died
if _thread_died.is_set():
raise ThreadDiedError()
# MQTT callbacks
def _on_connect(self, client, userdata, flags, return_code):
def _on_connect(self, client, userdata, flags, return_code) -> None:
if return_code == 0:
print(f"MQTT node connected to {self.config.broker}:{self.config.port}")
else:
@@ -69,21 +69,21 @@ class MQTT_Base:
if self.config.verbose:
print(flags)
def _on_disconnect(self, client, userdata, return_code):
def _on_disconnect(self, client, userdata, return_code) -> None:
print(f"Disconnected from broker with return code {return_code}")
if return_code != 0:
print("Killing this service")
sys.exit(-1)
def _on_publish(self, client, userdata, message_id):
def _on_publish(self, client, userdata, message_id) -> None:
if self.config.verbose:
print(f"Published message with id {message_id}, qos={self.config.qos}")
def _on_message(self, client, userdata, message):
def _on_message(self, client, userdata, message) -> None:
if self.config.verbose:
print(
f"Received message: {message.payload!s}, topic: {message.topic}, qos: {message.qos}"
)
def __del__(self):
def __del__(self) -> None:
self.client.loop_stop()

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from typing import Any
from heisskleber.core.packer import get_packer
from heisskleber.core.types import Publisher
from .config import MqttConf
from .mqtt_base import MqttBase
class MqttPublisher(MqttBase, Publisher):
"""
MQTT publisher class.
Can be used everywhere that a flucto style publishing connection is required.
Network message loop is handled in a separated thread.
"""
def __init__(self, config: MqttConf) -> None:
super().__init__(config)
self.pack = get_packer(config.packstyle)
def send(self, topic: str, data: dict[str, Any]) -> None:
"""
Takes python dictionary, serializes it according to the packstyle
and sends it to the broker.
Publishing is asynchronous
"""
self._raise_if_thread_died()
payload = self.pack(data)
self.client.publish(
topic, payload, qos=self.config.qos, retain=self.config.retain
)

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
from queue import SimpleQueue
from typing import Any
from paho.mqtt.client import MQTTMessage
from heisskleber.core.packer import get_unpacker
from heisskleber.core.types import Subscriber
from .config import MqttConf
from .mqtt_base import MqttBase
class MqttSubscriber(MqttBase, Subscriber):
"""
MQTT subscriber, wraps around ecplipse's paho mqtt client.
Network message loop is handled in a separated thread.
Incoming messages are saved as a stack when not processed via the receive() function.
"""
def __init__(self, topics: str | list[str], config: MqttConf):
super().__init__(config)
self._message_queue: SimpleQueue[MQTTMessage] = SimpleQueue()
self.subscribe(topics)
self.client.on_message = self._on_message
self.unpack = get_unpacker(config.packstyle)
def subscribe(self, topics: str | list[str] | tuple[str]) -> None:
"""
Subscribe to one or multiple topics
"""
if isinstance(topics, (list, tuple)):
# if subscribing to multiple topics, use a list of tuples
subscription_list = [(topic, self.config.qos) for topic in topics]
self.client.subscribe(subscription_list)
else:
self.client.subscribe(topics, self.config.qos)
if self.config.verbose:
print(f"Subscribed to: {topics}")
def receive(self) -> tuple[str, dict[str, Any]]:
"""
Reads a message from mqtt and returns it
Messages are saved in a stack, if no message is available, this function blocks.
Returns:
tuple(topic: bytes, message: dict): the message received
"""
self._raise_if_thread_died()
mqtt_message = self._message_queue.get(
block=True, timeout=self.config.timeout_s
)
message_returned = self.unpack(mqtt_message.payload.decode())
return (mqtt_message.topic, message_returned)
# callback to add incoming messages onto stack
def _on_message(self, client, userdata, message) -> None:
self._message_queue.put(message)
if self.config.verbose:
print(f"Topic: {message.topic}")
print(f"MQTT message: {message.payload.decode()}")

View File

@@ -1 +0,0 @@
from .pubsub.factories import get_publisher, get_subscriber # noqa: F401

View File

@@ -1,3 +0,0 @@
from .config import MqttConf # noqa: F401
from .publisher import MqttPublisher # noqa: F401
from .subscriber import MqttSubscriber # noqa: F401

View File

@@ -1,18 +0,0 @@
from heisskleber.config import load_config
from heisskleber.network import get_publisher, get_subscriber
from .config import MqttConf
from .forwarder import ZMQ_to_MQTT_Forwarder
def main():
config = load_config(MqttConf(), "mqtt")
for topic in config.topics:
print(f"Subscribing to {topic}")
zmq_sub = get_subscriber("zmq", list(config.topics))
mqtt_pub = get_publisher("mqtt")
forwarder = ZMQ_to_MQTT_Forwarder(config, subscriber=zmq_sub, publisher=mqtt_pub)
# Wait for zmq messages, publish as mqtt message
forwarder.zmq_to_mqtt_loop()

View File

@@ -1,63 +0,0 @@
from __future__ import annotations
from heisskleber.config import load_config
from heisskleber.network.packer import get_packer
from heisskleber.network.pubsub.types import Publisher
from .config import MqttConf
from .mqtt_base import MQTT_Base
class MqttPublisher(MQTT_Base, Publisher):
"""
MQTT publisher class.
Can be used everywhere that a flucto style publishing connection is required.
Network message loop is handled in a separated thread.
"""
def __init__(self, config: MqttConf):
super().__init__(config)
self.pack = get_packer(config.packstyle)
def send(self, topic: str | bytes, data: dict):
"""
Takes python dictionary, serializes it according to the packstyle
and sends it to the broker.
Publishing is asynchronous
"""
self._raise_if_thread_died()
if isinstance(topic, bytes):
topic = topic.decode()
payload = self.pack(data)
self.client.publish(
topic, payload, qos=self.config.qos, retain=self.config.retain
)
def get_mqtt_publisher() -> MqttPublisher:
"""
Generate mqtt publisher with configuration from yaml file,
falls back to default values if no config is found
"""
import os
if "MSB_CONFIG_DIR" in os.environ:
print("loading mqtt config")
config = load_config(MqttConf(), "mqtt", read_commandline=False)
else:
print("using default mqtt config")
config = MqttConf()
return MqttPublisher(config)
def get_default_publisher() -> MqttPublisher:
"""
Generate mqtt publisher with configuration from yaml file,
falls back to default values if no config is found
Deprecated, use get_mqtt_publisher() instead
"""
return get_mqtt_publisher()

View File

@@ -1,104 +0,0 @@
from __future__ import annotations
from queue import SimpleQueue
from heisskleber.config import load_config
from heisskleber.network.packer import get_unpacker
from heisskleber.network.pubsub.types import Subscriber
from .config import MqttConf
from .mqtt_base import MQTT_Base
class MqttSubscriber(MQTT_Base, Subscriber):
"""
MQTT subscriber, wraps around ecplipse's paho mqtt client.
Network message loop is handled in a separated thread.
Incoming messages are saved as a stack when not processed via the receive() function.
"""
def __init__(self, topics, config: MqttConf):
super().__init__(config)
self._message_queue = SimpleQueue()
self.subscribe(topics)
self.client.on_message = self._on_message
self.unpack = get_unpacker(config.packstyle)
def _subscribe_single_topic(self, topic: bytes | str):
if isinstance(topic, bytes):
topic = topic.decode()
if self.config.verbose:
print(f"Subscribed to: {topic}")
self.client.subscribe(topic, self.config.qos)
def _subscribe_multiple_topics(self, topics: list[bytes] | list[str]):
topics = [
topic.decode() if isinstance(topic, bytes) else topic for topic in topics
]
subscription_list = [(topic, self.config.qos) for topic in topics]
if self.config.verbose:
print(f"Subscribed to: {topics}")
self.client.subscribe(subscription_list)
def subscribe(self, topics):
"""
Subscribe to one or multiple topics
"""
# if subscribing to multiple topics, use a list of tuples
if isinstance(topics, (list, tuple)):
self._subscribe_multiple_topics(topics)
else:
self.client.subscribe(topics, self.config.qos)
def receive(self) -> tuple[bytes, dict]:
"""
Reads a message from mqtt and returns it
Messages are saved in a stack, if no message is available, this function blocks.
Returns:
tuple(topic: bytes, message: dict): the message received
"""
self._raise_if_thread_died()
mqtt_message = self._message_queue.get(
block=True, timeout=self.config.timeout_s
)
topic = mqtt_message.topic.encode("utf-8")
message_returned = self.unpack(mqtt_message.payload.decode())
return (topic, message_returned)
# callback to add incoming messages onto stack
def _on_message(self, client, userdata, message):
self._message_queue.put(message)
if self.config.verbose:
print(f"Topic: {message.topic}")
print(f"MQTT message: {message.payload.decode()}")
def get_mqtt_subscriber(topic: bytes | str) -> MqttSubscriber:
"""
Generate mqtt subscriber with configuration from yaml file,
falls back to default values if no config is found
"""
import os
if "MSB_CONFIG_DIR" in os.environ:
print("loading mqtt config")
config = load_config(MqttConf(), "mqtt", read_commandline=False)
else:
print("using default mqtt config")
config = MqttConf()
return MqttSubscriber(topic, config)
def get_default_subscriber(topic: bytes | str) -> MqttSubscriber:
"""
Generate mqtt subscriber with configuration from yaml file,
falls back to default values if no config is found
Deprecated, use get_mqtt_subscriber(topic) instead.
"""
return get_mqtt_subscriber(topic)

View File

@@ -1,31 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
class Publisher(ABC):
"""
Publisher interface.
"""
@abstractmethod
def send(self, topic: str | bytes, data: dict):
"""
Send data via the implemented output stream.
"""
pass
class Subscriber(ABC):
"""
Subscriber interface
"""
@abstractmethod
def receive(self) -> tuple[bytes, dict]:
"""
Blocking function to receive data from the implemented input stream.
Data is returned as a tuple of (topic, data).
"""
pass

View File

@@ -1,3 +0,0 @@
from .config import SerialConf # noqa: F401
from .publisher import SerialPublisher # noqa: F401
from .subscriber import SerialSubscriber # noqa: F401

View File

@@ -1 +0,0 @@
from .pubsub.types import Publisher, Subscriber # noqa: F401

View File

@@ -1,3 +0,0 @@
from .config import ZmqConf # noqa: F401
from .publisher import ZmqPublisher # noqa: F401
from .subscriber import ZmqSubscriber # noqa: F401

View File

@@ -0,0 +1,5 @@
from .config import SerialConf
from .publisher import SerialPublisher
from .subscriber import SerialSubscriber
__all__ = ["SerialConf", "SerialPublisher", "SerialSubscriber"]

View File

@@ -1,10 +1,10 @@
from heisskleber.network.types import Subscriber
from heisskleber.core.types import Subscriber
from .publisher import SerialPublisher
class SerialForwarder:
def __init__(self, subscriber: Subscriber, publisher: SerialPublisher):
def __init__(self, subscriber: Subscriber, publisher: SerialPublisher) -> None:
self.sub = subscriber
self.pub = publisher
@@ -12,19 +12,19 @@ class SerialForwarder:
Wait for message and forward
"""
def forward_message(self):
def forward_message(self) -> None:
# collected = {}
# for sub in self.sub:
# topic, data = sub.receive()
# collected.update(data)
_, collected = self.sub.receive()
self.pub.send(collected)
self.pub.send("", collected)
"""
Enter loop and continuously forward messages
"""
def sub_pub_loop(self):
def sub_pub_loop(self) -> None:
while True:
self.forward_message()

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
from types import FunctionType
from typing import Callable, Optional
import serial
from heisskleber.network.packer import get_packer
from heisskleber.network.pubsub.types import Publisher
from heisskleber.core.packer import get_packer
from heisskleber.core.types import Publisher
from .config import SerialConf
@@ -23,12 +23,16 @@ class SerialPublisher(Publisher):
Function to translate from a dict to a serialized string.
"""
def __init__(self, config: SerialConf, pack_func: FunctionType | None = None):
def __init__(
self,
config: SerialConf,
pack_func: Optional[Callable] = None, # noqa: UP007
):
self.config = config
self.packer = pack_func if pack_func else get_packer("serial")
self.pack = pack_func if pack_func else get_packer("serial")
self._connect()
def _connect(self):
def _connect(self) -> None:
self.serial: serial.Serial = serial.Serial(
port=self.config.port,
baudrate=self.config.baudrate,
@@ -38,25 +42,23 @@ class SerialPublisher(Publisher):
)
print(f"Successfully connected to serial device at port {self.config.port}")
def send(self, message: object):
def send(self, topic, message: dict) -> None:
"""
Takes python dictionary, serializes it according to the packstyle
and sends it to the broker.
Please note that this does not adhere to the interface, as there is no topic.
Parameters
----------
message : object
message : dict
object to be serialized and sent via the serial connection. Usually a dict.
"""
payload = self.packer(message)
payload = self.pack(message)
self.serial.write(payload.encode(self.config.encoding))
self.serial.flush()
if self.config.verbose:
print(payload)
print(f"{topic}: {payload}")
def __del__(self):
def __del__(self) -> None:
if not hasattr(self, "serial"):
return
if not self.serial.is_open:

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Callable, Optional
import serial
from heisskleber.network.pubsub.types import Subscriber
from heisskleber.core.types import Subscriber
from .config import SerialConf
@@ -45,7 +46,7 @@ class SerialSubscriber(Subscriber):
)
print(f"Successfully connected to serial device at port {self.config.port}")
def receive(self) -> dict:
def receive(self) -> tuple[str, dict]:
"""
Wait for data to arrive on the serial port and return it.
@@ -62,29 +63,30 @@ class SerialSubscriber(Subscriber):
# port is a placeholder for topic
return self.config.port, payload
def read_serial_port(self) -> str:
def read_serial_port(self) -> Generator[str, None, None]:
"""
Generator function reading from the serial port.
Returns
-------
:return: Generator[str, None, None]
Generator yielding strings read from the serial port
"""
buffer = ""
while True:
try:
buffer = self.serial.readline().decode()
buffer = self.serial.readline().decode(self.config.encoding, "ignore")
yield buffer
except UnicodeError as e:
if self.config.verbose:
print(f"Could not decode: {message}")
print(f"Could not decode: {buffer!r}")
print(e)
continue
def __del__(self):
def __del__(self) -> None:
if not hasattr(self, "serial"):
return
if not self.serial.is_open:
return
self.serial.flush()
self.serial.close()
if __name__ == "__main__":
config = SerialConf()
serial_reader = SerialSubscriber(config)
for message in serial_reader.receive():
print(message)

View File

@@ -1,8 +1,8 @@
import socket
from heisskleber.network.packer import get_packer
from heisskleber.network.pubsub.types import Publisher
from heisskleber.network.udp.config import UDPConf
from heisskleber.core.packer import get_packer
from heisskleber.core.types import Publisher
from heisskleber.udp.config import UDPConf
class UDP_Publisher(Publisher):

View File

@@ -1,8 +1,8 @@
import socket
from heisskleber.network.packer import get_unpacker
from heisskleber.network.pubsub.types import Subscriber
from heisskleber.network.udp.config import UDPConf
from heisskleber.core.packer import get_unpacker
from heisskleber.core.types import Subscriber
from heisskleber.udp.config import UDPConf
class UDP_Subscriber(Subscriber):

View File

@@ -0,0 +1,5 @@
from .config import ZmqConf
from .publisher import ZmqPublisher
from .subscriber import ZmqSubscriber
__all__ = ["ZmqConf", "ZmqPublisher", "ZmqSubscriber"]

View File

@@ -1,9 +1,10 @@
import sys
from typing import Any
import zmq
from heisskleber.network.packer import get_packer
from heisskleber.network.pubsub.types import Publisher
from heisskleber.core.packer import get_packer
from heisskleber.core.types import Publisher
from .config import ZmqConf
@@ -25,9 +26,9 @@ class ZmqPublisher(Publisher):
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
def send(self, topic: str, data: dict) -> None:
data = self.pack(data)
self.socket.send_multipart([topic, data.encode("utf-8")])
def send(self, topic: str, data: dict[str, Any]) -> None:
payload = self.pack(data)
self.socket.send_multipart([topic, payload.encode("utf-8")])
def __del__(self):
self.socket.close()

View File

@@ -4,8 +4,8 @@ import sys
import zmq
from heisskleber.network.packer import get_unpacker
from heisskleber.network.pubsub.types import Subscriber
from heisskleber.core.packer import get_unpacker
from heisskleber.core.types import Subscriber
from .config import ZmqConf

View File

@@ -85,9 +85,10 @@ select = [
"TRY", # tryceratops
]
ignore = [
"E501", # LineTooLong
"E731", # DoNotAssignLambda
"A001", #
"E501", # LineTooLong
"E731", # DoNotAssignLambda
"A001", #
"PGH003", # Use specific rules when ignoring type issues
]
[tool.ruff.per-file-ignores]

View File

@@ -1,31 +1,26 @@
def test_import_mqtt():
from heisskleber.network.mqtt import (
MqttConf,
MqttPublisher,
MqttSubscriber,
)
import heisskleber
from heisskleber.mqtt import MqttConf, MqttPublisher, MqttSubscriber
assert heisskleber.__all__ == [
"get_publisher",
"get_subscriber",
"Publisher",
"Subscriber",
]
def test_import_zmq():
from heisskleber.network.zmq import (
ZmqConf,
ZmqPublisher,
ZmqSubscriber,
)
from heisskleber.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber
def test_import_serial():
from heisskleber.network.serial import (
SerialConf,
SerialPublisher,
SerialSubscriber,
)
from heisskleber.serial import SerialConf, SerialPublisher, SerialSubscriber
def test_import_utils():
from heisskleber.network import get_publisher, get_subscriber
from heisskleber.network.types import Publisher, Subscriber
from heisskleber import Publisher, Subscriber, get_publisher, get_subscriber
def test_import_config():
pass
from heisskleber.config import BaseConf, load_config

162
tests/test_mqtt.py Normal file
View File

@@ -0,0 +1,162 @@
import json
from queue import SimpleQueue
from unittest.mock import call, patch
import pytest
from paho.mqtt.client import MQTTMessage
from heisskleber.mqtt.config import MqttConf
from heisskleber.mqtt.mqtt_base import MqttBase
from heisskleber.mqtt.subscriber import MqttSubscriber
# Mock configuration for MQTT_Base
@pytest.fixture
def mock_mqtt_conf() -> MqttConf:
return MqttConf(
broker="localhost",
port=1883,
user="user",
password="passwd", # noqa: S106
ssl=False,
verbose=False,
qos=1,
)
# Mock the paho mqtt client
@pytest.fixture
def mock_mqtt_client():
with patch("heisskleber.mqtt.mqtt_base.mqtt_client", autospec=True) as mock:
yield mock
@pytest.fixture
def mock_queue():
with patch("heisskleber.mqtt.subscriber.SimpleQueue", spec=SimpleQueue) as mock:
yield mock
def test_mqtt_base_intialization(mock_mqtt_client, mock_mqtt_conf):
"""Test that the intialization of the mqtt client is as expected."""
base = MqttBase(config=mock_mqtt_conf)
mock_mqtt_client.assert_called_once()
mock_mqtt_client.return_value.loop_start.assert_called_once()
mock_client_instance = mock_mqtt_client.return_value
mock_client_instance.username_pw_set.assert_called_with(
mock_mqtt_conf.user, mock_mqtt_conf.password
)
mock_client_instance.connect.assert_called_with(
mock_mqtt_conf.broker, mock_mqtt_conf.port
)
assert base.client.on_connect == base._on_connect
assert base.client.on_disconnect == base._on_disconnect
assert base.client.on_publish == base._on_publish
assert base.client.on_message == base._on_message
def test_mqtt_base_on_connect(mock_mqtt_client, mock_mqtt_conf, capsys):
base = MqttBase(config=mock_mqtt_conf)
base._on_connect(None, None, {}, 0)
captured = capsys.readouterr()
assert (
f"MQTT node connected to {mock_mqtt_conf.broker}:{mock_mqtt_conf.port}"
in captured.out
)
def test_mqtt_base_on_disconnect_with_error(mock_mqtt_client, mock_mqtt_conf, capsys):
"""Assert that the mqtt client shuts down when disconnect callback is received."""
base = MqttBase(config=mock_mqtt_conf)
with pytest.raises(SystemExit):
base._on_disconnect(None, None, 1)
captured = capsys.readouterr()
assert "Killing this service" in captured.out
print(captured.out)
def test_mqtt_subscribes_single_topic(mock_mqtt_client, mock_mqtt_conf):
"""Test that the mqtt client subscribes to a single topic."""
_ = MqttSubscriber(topics="singleTopic", config=mock_mqtt_conf)
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
assert actual_calls == [call("singleTopic", mock_mqtt_conf.qos)]
def test_mqtt_subscribes_multiple_topics(mock_mqtt_client, mock_mqtt_conf):
"""Test that the mqtt client subscribes to multiple topics passed as list.
I would love to do this via parametrization, but the call argument is built differently for single size lists and longer lists.
"""
_ = MqttSubscriber(topics=["multiple1", "multiple2"], config=mock_mqtt_conf)
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
assert actual_calls == [
call([("multiple1", mock_mqtt_conf.qos), ("multiple2", mock_mqtt_conf.qos)]),
]
def test_mqtt_subscribes_multiple_topics_tuple(mock_mqtt_client, mock_mqtt_conf):
"""Test that the mqtt client subscribes to multiple topics passed as tuple."""
_ = MqttSubscriber(topics=("multiple1", "multiple2"), config=mock_mqtt_conf)
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
assert actual_calls == [
call([("multiple1", mock_mqtt_conf.qos), ("multiple2", mock_mqtt_conf.qos)]),
]
def create_fake_mqtt_message(topic: bytes, payload: bytes) -> MQTTMessage:
msg = MQTTMessage()
msg.topic = topic
msg.payload = payload
return msg
def test_receive_with_message(mock_mqtt_conf: MqttConf, mock_mqtt_client, mock_queue):
"""Test the mqtt receive function with fake MQTT messages."""
topic = b"test/topic"
payload = json.dumps({"key": "value"}).encode()
fake_message = create_fake_mqtt_message(topic, payload)
mock_queue.return_value.get.side_effect = [fake_message]
subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf)
received_topic, received_payload = subscriber.receive()
assert received_topic == "test/topic"
assert received_payload == {"key": "value"}
def test_message_is_put_into_queue(
mock_mqtt_conf: MqttConf, mock_mqtt_client, mock_queue
):
"""Test that values a put into a queue when on_message callback is called."""
topic = b"test/topic"
payload = json.dumps({"key": "value"}).encode()
fake_message = create_fake_mqtt_message(topic, payload)
mock_queue.return_value.get.side_effect = [fake_message]
subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf)
subscriber._on_message(None, None, fake_message)
mock_queue.return_value.put.assert_called_once_with(fake_message)
def test_message_is_put_into_queue_with_actual_queue(mock_mqtt_conf, mock_mqtt_client):
"""Test that the buffering via queue works as expected."""
topic = b"test/topic"
payload = json.dumps({"key": "value"}).encode()
fake_message = create_fake_mqtt_message(topic, payload)
# mock_queue.return_value.get.side_effect = [fake_message]
subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf)
subscriber._on_message(None, None, fake_message)
topic, return_dict = subscriber.receive()
assert topic == "test/topic"
assert return_dict == {"key": "value"}

35
tests/test_packer.py Normal file
View File

@@ -0,0 +1,35 @@
import json
import pickle
from typing import Any
import pytest
from heisskleber.core.packer import get_packer, get_unpacker, serialpacker
def test_get_packer() -> None:
assert get_packer("json") == json.dumps
assert get_packer("pickle") == pickle.dumps
assert get_packer("default") == json.dumps
assert get_packer("foobar") == json.dumps
assert get_packer("serial") == serialpacker
def test_get_unpacker() -> None:
assert get_unpacker("json") == json.loads
assert get_unpacker("pickle") == pickle.loads
assert get_unpacker("default") == json.loads
assert get_unpacker("foobar") == json.loads
@pytest.mark.parametrize(
"message,expected",
[
({"hi": 1, "da": 2, "nei": 3}, "1,2,3"),
({"er": 1, "ma": "ga", "gerd": 3, "jo": 4}, "1,ga,3,4"),
({"": 1, "ho": 0.0, "lee": 0.1, "shit": 1_000}, "1,0.0,0.1,1000"),
({"be": 1e6, "li": 1_000}, "1000000.0,1000"),
],
)
def test_serial_packer_functionality(message: dict[str, Any], expected: str) -> None:
assert serialpacker(message) == expected

115
tests/test_serial.py Normal file
View File

@@ -0,0 +1,115 @@
from unittest.mock import Mock, patch
import pytest
import serial
from heisskleber.core.packer import serialpacker
from heisskleber.serial.config import SerialConf
from heisskleber.serial.publisher import SerialPublisher
from heisskleber.serial.subscriber import SerialSubscriber
@pytest.fixture
def serial_conf():
return SerialConf(port="/dev/test", baudrate=9600, bytesize=8, verbose=False)
@pytest.fixture
def mock_serial_device_subscriber():
with patch("heisskleber.serial.subscriber.serial.Serial") as mock:
yield mock
@pytest.fixture
def mock_serial_device_publisher():
with patch("heisskleber.serial.publisher.serial.Serial") as mock:
yield mock
def test_serial_subscriber_initialization(mock_serial_device_subscriber, serial_conf):
"""Test that the SerialSubscriber class initializes correctly.
Mocks the serial.Serial class to avoid opening a serial port."""
_ = SerialSubscriber(topics="", config=serial_conf)
mock_serial_device_subscriber.assert_called_with(
port=serial_conf.port,
baudrate=serial_conf.baudrate,
bytesize=serial_conf.bytesize,
parity=serial.PARITY_NONE,
stopbits=serial.STOPBITS_ONE,
)
def test_serial_subscriber_receive(mock_serial_device_subscriber, serial_conf):
"""Test that the SerialSubscriber class calls readline and unpack as expected."""
subscriber = SerialSubscriber(topics="", config=serial_conf)
# Set up the readline return value
mock_serial_instance = mock_serial_device_subscriber.return_value
mock_serial_instance.readline.return_value = b"test message\n"
# Set up the unpack function to convert message to dict
unpack_func = Mock(return_value={"data": "test message"})
subscriber.unpack = unpack_func
# Call the receive method and assert it behaves as expected
_, payload = subscriber.receive()
# Was readline called?
mock_serial_instance.readline.assert_called_once()
# Was unpack called?
assert payload == {"data": "test message"}
unpack_func.assert_called_once_with("test message\n")
def test_serial_subscriber_converts_bytes_to_str():
"""Test that the SerialSubscriber class converts bytes to str as expected."""
with patch("heisskleber.serial.subscriber.serial.Serial") as mock_serial:
subscriber = SerialSubscriber(
topics="", config=SerialConf(), unpack_func=lambda x: x
)
# Set the readline method to raise UnicodeError
mock_serial_instance = mock_serial.return_value
mock_serial_instance.readline.side_effect = [b"test message", b"test\x86more"]
_, payload = subscriber.receive()
assert isinstance(payload, str)
assert payload == "test message"
# Assert that none-unicode is skipped
_, payload = subscriber.receive()
assert payload == "testmore"
def test_serial_publisher_initialization(mock_serial_device_publisher, serial_conf):
"""Test that the SerialPublisher class initializes correctly.
Mocks the serial.Serial class to avoid opening a serial port."""
publisher = SerialPublisher(config=serial_conf)
mock_serial_device_publisher.assert_called_with(
port=serial_conf.port,
baudrate=serial_conf.baudrate,
bytesize=serial_conf.bytesize,
parity=serial.PARITY_NONE,
stopbits=serial.STOPBITS_ONE,
)
assert publisher.serial
def test_serial_publisher_send(mock_serial_device_publisher, serial_conf):
"""Test that the SerialPublisher class calls write and pack as expected."""
publisher = SerialPublisher(config=serial_conf)
# Set up the readline return value
mock_serial_instance = mock_serial_device_publisher.return_value
mock_serial_instance.readline.return_value = b"test message\n"
# Set up the pack function to convert dict to comma separated string of values
publisher.pack = serialpacker
# Call the receive method and assert it behaves as expected
publisher.send("test", {"data": "test message", "more_data": "more message"})
# Was write called with encoded payload?
mock_serial_instance.write.assert_called_once_with(b"test message,more message")
mock_serial_instance.flush.assert_called_once()

View File

@@ -4,4 +4,4 @@ import heisskleber
def test_heisskleber_version() -> None:
"""Test that the glue version is correct."""
assert heisskleber.__version__ == "0.1.0"
assert heisskleber.__version__ == "0.2.0"