diff --git a/heisskleber/config/MSBConfig.py b/heisskleber/config/MSBConfig.py new file mode 100644 index 0000000..7d05478 --- /dev/null +++ b/heisskleber/config/MSBConfig.py @@ -0,0 +1,29 @@ +import socket +import warnings +from dataclasses import dataclass + + +@dataclass +class BaseConf: + """ + default configuration class for generic configuration info + """ + + verbose: bool = False + print_stdout: bool = False + + def __setitem__(self, key, value): + if hasattr(self, key): + self.__setattr__(key, value) + else: + warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2) + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + else: + warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2) + + @property + def serial_number(self): + return socket.gethostname().upper() diff --git a/heisskleber/config/__init__.py b/heisskleber/config/__init__.py new file mode 100644 index 0000000..c554645 --- /dev/null +++ b/heisskleber/config/__init__.py @@ -0,0 +1,4 @@ +from heisskleber.config.MSBConfig import BaseConf +from heisskleber.config.parse import load_config + +__all__ = ["load_config", "BaseConf"] diff --git a/heisskleber/config/cmdline.py b/heisskleber/config/cmdline.py new file mode 100644 index 0000000..5a39f58 --- /dev/null +++ b/heisskleber/config/cmdline.py @@ -0,0 +1,58 @@ +import argparse + + +class KeyValue(argparse.Action): + # Constructor calling + """ + def __call__( self , parser, namespace, values : list, option_string = None): + setattr(namespace, self.dest, dict()) + for value in values: + # split it into key and value + key, value = value.split('=') + # assign into dictionary + getattr(namespace, self.dest)[key] = value + """ + + def __call__(self, parser, args, values, option_string=None): + try: + params = dict(x.split("=") for x in values) + except ValueError as ex: + raise argparse.ArgumentError( + self, + f'Could not parse argument "{values}" as k1=v1 k2=v2 ... format: {ex}', + ) from ex + setattr(args, self.dest, params) + + +def get_cmdline(args=None) -> dict: + """ + get commandline arguments and return a dictionary of + the provided arguments. + + available commandline arguments are: + --verbose: flag to toggle debugging output + --print-stdout: flag to toggle all data printed to stdout + --param key1=value1 key2=value2: allows to pass service specific + parameters + """ + arp = argparse.ArgumentParser() + arp.add_argument("--verbose", action="store_true", help="debug output flag") + arp.add_argument( + "--print-stdout", + action="store_true", + help="toggles output of all data to stdout", + ) + arp.add_argument( + "--params", + nargs="*", + action=KeyValue, + ) + args = arp.parse_args(args) + config = {} + if args.verbose: + config["verbose"] = args.verbose + if args.print_stdout: + config["print_stdout"] = args.print_stdout + if args.params: + config |= args.params + return config diff --git a/heisskleber/config/parse.py b/heisskleber/config/parse.py new file mode 100644 index 0000000..f6cfd2b --- /dev/null +++ b/heisskleber/config/parse.py @@ -0,0 +1,86 @@ +import os +import sys +import warnings +from typing import TypeVar + +import yaml + +from heisskleber.config import BaseConf +from heisskleber.config.cmdline import get_cmdline + +ConfigType = TypeVar("ConfigType", bound=BaseConf) + + +def get_msb_config_filepath(config_filename: str = "heisskleber.conf") -> str: + config_subpath = os.path.join("msb/conf.d/", config_filename) + try: + config_filepath = os.path.join(os.environ["MSB_CONFIG_DIR"], config_subpath) + except Exception as e: + print(f"could no get MSB_CONFIG from PATH: {e}") + sys.exit() # TODO use 1 or the error str as exit value + if not os.path.isfile(config_filepath): + print(f"not a file: {config_filepath}!") + sys.exit() + return config_filepath + + +def read_yaml_config_file(config_fpath: str) -> dict: + with open(config_fpath) as config_filehandle: + return yaml.safe_load(config_filehandle) + + +def update_config(config: ConfigType, config_dict: dict) -> ConfigType: + for config_key, config_value in config_dict.items(): + # get expected type of element from config_object: + if not hasattr(config, config_key): + warnings.warn( + f"no such configuration parameter: {config_key}, skipping", stacklevel=2 + ) + continue + cast_func = type(config[config_key]) + try: + if config_key == "topic": + config[config_key] = config_value.encode("utf-8") + else: + config[config_key] = cast_func(config_value) + except Exception as e: + print( + f"failed to cast {config_value} to {type(config[config_key])}: {e}. skipping" + ) + continue + return config + + +ConfigType = TypeVar( + "ConfigType", bound=BaseConf +) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar + + +def load_config( + config: ConfigType, config_filename: str, read_commandline: bool = True +) -> ConfigType: + """Load the config file and update the config object. + + Parameters + ---------- + config : MSBConf + The config object to fill with values. + config_filename : str + The name of the config file in $MSB_CONF/msb/conf.d/. + If the file does not have an extension the default extension .yaml is appended. + read_commandline : bool + Whether to read arguments from the command line. Optional. Defaults to True. + """ + config_filename = ( + config_filename if "." in config_filename else config_filename + ".yaml" + ) + config_filepath = get_msb_config_filepath(config_filename) + config_dict = read_yaml_config_file(config_filepath) + config = update_config(config, config_dict) + + if not read_commandline: + return config + + config_dict = get_cmdline() + config = update_config(config, config_dict) + return config diff --git a/heisskleber/config/zeromq.py b/heisskleber/config/zeromq.py new file mode 100644 index 0000000..571440a --- /dev/null +++ b/heisskleber/config/zeromq.py @@ -0,0 +1,40 @@ +import sys + +import zmq + + +def open_zmq_sub_socket(connect_to: str, topic=b""): + ctx = zmq.Context() + zmq_socket = ctx.socket(zmq.SUB) + try: + zmq_socket.connect(connect_to) + except Exception as e: + print(f"failed to bind to zeromq socket: {e}") + sys.exit(-1) + zmq_socket.setsockopt(zmq.SUBSCRIBE, topic) + return zmq_socket + + +def open_zmq_pub_socket(connect_to: str): + ctx = zmq.Context() + zmq_socket = ctx.socket(zmq.PUB) + try: + zmq_socket.connect(connect_to) + except Exception as e: + print(f"failed to bind to zeromq socket: {e}") + sys.exit(-1) + return zmq_socket + + +def get_zmq_xpub_socketstring(msb_config: dict) -> str: + zmq_config = msb_config["zeromq"] + return ( + f"{zmq_config['protocol']}://{zmq_config['address']}:{zmq_config['xpub_port']}" + ) + + +def get_zmq_xsub_socketstring(msb_config: dict) -> str: + zmq_config = msb_config["zeromq"] + return ( + f"{zmq_config['protocol']}://{zmq_config['address']}:{zmq_config['xsub_port']}" + ) diff --git a/heisskleber/network/__init__.py b/heisskleber/network/__init__.py new file mode 100644 index 0000000..31256a0 --- /dev/null +++ b/heisskleber/network/__init__.py @@ -0,0 +1 @@ +from .pubsub.factories import get_publisher, get_subscriber # noqa: F401 diff --git a/heisskleber/network/config.py b/heisskleber/network/config.py new file mode 100644 index 0000000..e69de29 diff --git a/heisskleber/network/influxdb/__init__.py b/heisskleber/network/influxdb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/heisskleber/network/influxdb/config.py b/heisskleber/network/influxdb/config.py new file mode 100644 index 0000000..62251a1 --- /dev/null +++ b/heisskleber/network/influxdb/config.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +from heisskleber.config import BaseConf + + +@dataclass +class InfluxDBConf(BaseConf): + host: str = "localhost" + port: int = 8086 + bucket: str = "test" + org: str = "test" + ssl: bool = False + read_token: str = "" + write_token: str = "" + all_access_token: str = "" + + @property + def url(self) -> str: + protocol = "https" if self.ssl else "http" + return f"{protocol}://{self.host}:{self.port}" diff --git a/heisskleber/network/influxdb/subscriber.py b/heisskleber/network/influxdb/subscriber.py new file mode 100644 index 0000000..a49433a --- /dev/null +++ b/heisskleber/network/influxdb/subscriber.py @@ -0,0 +1,74 @@ +import pandas as pd +from influxdb_client import InfluxDBClient + +from .config import InfluxDBConf + + +def build_query(options: dict) -> str: + query = ( + f'from(bucket:"{options["bucket"]}")' + + f'|> range(start: {options["start"].isoformat("T")}, stop: {options["end"].isoformat("T")})' + + f'|> filter(fn:(r) => r._measurement == "{options["measurement"]}")' + ) + if options["filter"]: + for attribute, value in options["filter"].items(): + if isinstance(value, list): + query += f'|> filter(fn:(r) => r.{attribute} == "{value[0]}"' + for vv in value[1:]: + query += f' or r.{attribute} == "{vv}"' + query += ")" + else: + query += f'|> filter(fn:(r) => r.{attribute} == "{value}")' + + query += ( + f'|> aggregateWindow(every: {options["resample"]}, fn: mean)' + + '|> pivot(rowKey:["_time"], columnKey: ["_field"], valueColumn: "_value")' + ) + + return query + + +class Influx_Subscriber: + def __init__(self, config: InfluxDBConf, query: str): + self.config = config + self.query = query + self.df: pd.DataFrame = None + + self.client: InfluxDBClient = InfluxDBClient( + url=self.config.url, + token=self.config.all_access_token or self.config.read_token, + org=self.config.org, + timeout=60_000, + ) + self.reader = self.client.query_api() + + self._run_query() + self.index = 0 + + def receive(self) -> dict: + row = self.df.iloc[self.index].to_dict() + self.index += 1 + return "influx", row + + def _run_query(self): + self.df = self.reader.query_data_frame(self.query, org=self.config.org) + self.df["epoch"] = pd.to_numeric(self.df["_time"]) / 1e9 + self.df.drop( + columns=[ + "result", + "table", + "_start", + "_stop", + "_measurement", + "_time", + "topic", + ], + inplace=True, + ) + + def __iter__(self): + for _, row in self.df.iterrows(): + yield "influx", row.to_dict() + + def __next__(self): + return self.__iter__().__next__() diff --git a/heisskleber/network/influxdb/writer.py b/heisskleber/network/influxdb/writer.py new file mode 100644 index 0000000..cabc5ca --- /dev/null +++ b/heisskleber/network/influxdb/writer.py @@ -0,0 +1,50 @@ +from config import InfluxDBConf +from influxdb_client import InfluxDBClient, WriteOptions + +from heisskleber.config import load_config + + +class Influx_Writer: + def __init__(self, config: InfluxDBConf): + self.config = config + # self.write_options = SYNCHRONOUS + self.write_options = WriteOptions( + batch_size=500, + flush_interval=10_000, + jitter_interval=2_000, + retry_interval=5_000, + max_retries=5, + max_retry_delay=30_000, + exponential_base=2, + ) + self.client = InfluxDBClient( + url=self.config.url, token=self.config.token, org=self.config.org + ) + self.writer = self.client.write_api( + write_options=self.write_options, + ) + + def __del__(self): + self.writer.close() + self.client.close() + + def write_line(self, line): + self.writer.write(bucket=self.config.bucket, record=line) + + def write_from_generator(self, generator): + for line in generator: + self.writer.write(bucket=self.config.bucket, record=line) + + def write_from_line_generator(self, generator): + with InfluxDBClient( + url=self.config.url, token=self.config.token, org=self.config.org + ) as client, client.write_api( + write_options=self.write_options, + ) as write_api: + for line in generator: + write_api.write(bucket=self.config.bucket, record=line) + + +def get_parsed_flux_writer(): + config = load_config(InfluxDBConf(), "flux", read_commandline=False) + return Influx_Writer(config) diff --git a/heisskleber/network/mqtt/__init__.py b/heisskleber/network/mqtt/__init__.py new file mode 100644 index 0000000..e71a954 --- /dev/null +++ b/heisskleber/network/mqtt/__init__.py @@ -0,0 +1,3 @@ +from .config import MqttConf # noqa: F401 +from .publisher import MqttPublisher # noqa: F401 +from .subscriber import MqttSubscriber # noqa: F401 diff --git a/heisskleber/network/mqtt/config.py b/heisskleber/network/mqtt/config.py new file mode 100644 index 0000000..c716cb3 --- /dev/null +++ b/heisskleber/network/mqtt/config.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass, field + +from heisskleber.config import BaseConf + + +@dataclass +class MqttConf(BaseConf): + """ + MQTT configuration class. + """ + + broker: str = "localhost" + user: str = "" + password: str = "" + port: int = 1883 + ssl: bool = False + qos: int = 0 + retain: bool = False + topics: list[bytes] = field(default_factory=list) + mapping: str = "/msb/" + packstyle: str = "json" + max_saved_messages: int = 100 + timeout_s: int = 60 diff --git a/heisskleber/network/mqtt/forwarder.py b/heisskleber/network/mqtt/forwarder.py new file mode 100644 index 0000000..e3a959a --- /dev/null +++ b/heisskleber/network/mqtt/forwarder.py @@ -0,0 +1,22 @@ +from heisskleber.config import load_config +from heisskleber.network import get_publisher, get_subscriber + +from .config import MqttConf + + +def map_topic(zmq_topic, mapping): + return mapping + zmq_topic.decode() + + +def main(): + config: MqttConf = load_config(MqttConf(), "mqtt") + sub = get_subscriber("zmq", config.topics) + pub = get_publisher("mqtt") + + sub.unpack = pub.pack = lambda x: x + + while True: + (zmq_topic, data) = sub.receive() + mqtt_topic = map_topic(zmq_topic, config.mapping) + + pub.send(mqtt_topic, data) diff --git a/heisskleber/network/mqtt/mqtt_base.py b/heisskleber/network/mqtt/mqtt_base.py new file mode 100644 index 0000000..ed9fb2a --- /dev/null +++ b/heisskleber/network/mqtt/mqtt_base.py @@ -0,0 +1,89 @@ +import ssl +import sys +import threading + +from paho.mqtt.client import Client as mqtt_client + +from .config import MqttConf + + +class ThreadDiedError(RuntimeError): + pass + + +_thread_died = threading.Event() + +_default_excepthook = threading.excepthook + + +def _set_thread_died_excepthook(args, /): + _default_excepthook(args) + global _thread_died + _thread_died.set() + + +threading.excepthook = _set_thread_died_excepthook + + +class MQTT_Base: + """ + Wrapper around eclipse paho mqtt client. + Handles connection and callbacks. + Callbacks may be overwritten in subclasses. + """ + + def __init__(self, config: MqttConf): + self.config = config + self.connect() + self.client.loop_start() + + def connect(self): + self.client = mqtt_client() + self.client.username_pw_set(self.config.user, self.config.password) + + # Add callbacks + self.client.on_connect = self._on_connect + self.client.on_disconnect = self._on_disconnect + self.client.on_publish = self._on_publish + self.client.on_message = self._on_message + + if self.config.ssl: + # By default, on Python 2.7.9+ or 3.4+, + # the default certification authority of the system is used. + self.client.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT) + + self.client.connect(self.config.broker, self.config.port) + + @staticmethod + def _raise_if_thread_died(): + global _thread_died + if _thread_died.is_set(): + raise ThreadDiedError() + + # MQTT callbacks + def _on_connect(self, client, userdata, flags, return_code): + if return_code == 0: + print(f"MQTT node connected to {self.config.broker}:{self.config.port}") + else: + print("Connection failed!") + if self.config.verbose: + print(flags) + + def _on_disconnect(self, client, userdata, return_code): + print(f"Disconnected from broker with return code {return_code}") + if return_code != 0: + print("Killing this service") + sys.exit(-1) + + def _on_publish(self, client, userdata, message_id): + if self.config.verbose: + print(f"Published message with id {message_id}, qos={self.config.qos}") + + def _on_message(self, client, userdata, message): + if self.config.verbose: + print( + f"Received message: {message.payload!s}, topic: {message.topic}, qos: {message.qos}" + ) + + def __del__(self): + self.client.loop_stop() diff --git a/heisskleber/network/mqtt/msb_mqtt.py b/heisskleber/network/mqtt/msb_mqtt.py new file mode 100644 index 0000000..344add1 --- /dev/null +++ b/heisskleber/network/mqtt/msb_mqtt.py @@ -0,0 +1,18 @@ +from heisskleber.config import load_config +from heisskleber.network import get_publisher, get_subscriber + +from .config import MqttConf +from .forwarder import ZMQ_to_MQTT_Forwarder + + +def main(): + config = load_config(MqttConf(), "mqtt") + for topic in config.topics: + print(f"Subscribing to {topic}") + + zmq_sub = get_subscriber("zmq", list(config.topics)) + mqtt_pub = get_publisher("mqtt") + forwarder = ZMQ_to_MQTT_Forwarder(config, subscriber=zmq_sub, publisher=mqtt_pub) + + # Wait for zmq messages, publish as mqtt message + forwarder.zmq_to_mqtt_loop() diff --git a/heisskleber/network/mqtt/publisher.py b/heisskleber/network/mqtt/publisher.py new file mode 100644 index 0000000..4acf4a9 --- /dev/null +++ b/heisskleber/network/mqtt/publisher.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from heisskleber.config import load_config +from heisskleber.network.packer import get_packer +from heisskleber.network.pubsub.types import Publisher + +from .config import MqttConf +from .mqtt_base import MQTT_Base + + +class MqttPublisher(MQTT_Base, Publisher): + """ + MQTT publisher class. + Can be used everywhere that a flucto style publishing connection is required. + + Network message loop is handled in a separated thread. + """ + + def __init__(self, config: MqttConf): + super().__init__(config) + self.pack = get_packer(config.packstyle) + + def send(self, topic: str | bytes, data: dict): + """ + Takes python dictionary, serializes it according to the packstyle + and sends it to the broker. + + Publishing is asynchronous + """ + self._raise_if_thread_died() + if isinstance(topic, bytes): + topic = topic.decode() + + payload = self.pack(data) + self.client.publish( + topic, payload, qos=self.config.qos, retain=self.config.retain + ) + + +def get_mqtt_publisher() -> MqttPublisher: + """ + Generate mqtt publisher with configuration from yaml file, + falls back to default values if no config is found + """ + import os + + if "MSB_CONFIG_DIR" in os.environ: + print("loading mqtt config") + config = load_config(MqttConf(), "mqtt", read_commandline=False) + else: + print("using default mqtt config") + config = MqttConf() + return MqttPublisher(config) + + +def get_default_publisher() -> MqttPublisher: + """ + Generate mqtt publisher with configuration from yaml file, + falls back to default values if no config is found + + Deprecated, use get_mqtt_publisher() instead + """ + return get_mqtt_publisher() diff --git a/heisskleber/network/mqtt/subscriber.py b/heisskleber/network/mqtt/subscriber.py new file mode 100644 index 0000000..e684bff --- /dev/null +++ b/heisskleber/network/mqtt/subscriber.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from queue import SimpleQueue + +from heisskleber.config import load_config +from heisskleber.network.packer import get_unpacker +from heisskleber.network.pubsub.types import Subscriber + +from .config import MqttConf +from .mqtt_base import MQTT_Base + + +class MqttSubscriber(MQTT_Base, Subscriber): + """ + MQTT subscriber, wraps around ecplipse's paho mqtt client. + Network message loop is handled in a separated thread. + + Incoming messages are saved as a stack when not processed via the receive() function. + """ + + def __init__(self, topics, config: MqttConf): + super().__init__(config) + self._message_queue = SimpleQueue() + self.subscribe(topics) + self.client.on_message = self._on_message + self.unpack = get_unpacker(config.packstyle) + + def _subscribe_single_topic(self, topic: bytes | str): + if isinstance(topic, bytes): + topic = topic.decode() + if self.config.verbose: + print(f"Subscribed to: {topic}") + self.client.subscribe(topic, self.config.qos) + + def _subscribe_multiple_topics(self, topics: list[bytes] | list[str]): + topics = [ + topic.decode() if isinstance(topic, bytes) else topic for topic in topics + ] + subscription_list = [(topic, self.config.qos) for topic in topics] + if self.config.verbose: + print(f"Subscribed to: {topics}") + self.client.subscribe(subscription_list) + + def subscribe(self, topics): + """ + Subscribe to one or multiple topics + """ + # if subscribing to multiple topics, use a list of tuples + if isinstance(topics, (list, tuple)): + self._subscribe_multiple_topics(topics) + else: + self.client.subscribe(topics, self.config.qos) + + def receive(self) -> tuple[bytes, dict]: + """ + Reads a message from mqtt and returns it + + Messages are saved in a stack, if no message is available, this function blocks. + + Returns: + tuple(topic: bytes, message: dict): the message received + """ + self._raise_if_thread_died() + mqtt_message = self._message_queue.get( + block=True, timeout=self.config.timeout_s + ) + + topic = mqtt_message.topic.encode("utf-8") + message_returned = self.unpack(mqtt_message.payload.decode()) + return (topic, message_returned) + + # callback to add incoming messages onto stack + def _on_message(self, client, userdata, message): + self._message_queue.put(message) + + if self.config.verbose: + print(f"Topic: {message.topic}") + print(f"MQTT message: {message.payload.decode()}") + + +def get_mqtt_subscriber(topic: bytes | str) -> MqttSubscriber: + """ + Generate mqtt subscriber with configuration from yaml file, + falls back to default values if no config is found + """ + import os + + if "MSB_CONFIG_DIR" in os.environ: + print("loading mqtt config") + config = load_config(MqttConf(), "mqtt", read_commandline=False) + else: + print("using default mqtt config") + config = MqttConf() + return MqttSubscriber(topic, config) + + +def get_default_subscriber(topic: bytes | str) -> MqttSubscriber: + """ + Generate mqtt subscriber with configuration from yaml file, + falls back to default values if no config is found + + Deprecated, use get_mqtt_subscriber(topic) instead. + """ + return get_mqtt_subscriber(topic) diff --git a/heisskleber/network/packer.py b/heisskleber/network/packer.py new file mode 100644 index 0000000..3150c2c --- /dev/null +++ b/heisskleber/network/packer.py @@ -0,0 +1,34 @@ +import json +import pickle + + +def get_packer(style): + if style in _packstyles: + return _packstyles[style] + else: + return _packstyles["default"] + + +def get_unpacker(style): + if style in _unpackstyles: + return _unpackstyles[style] + else: + return _unpackstyles["default"] + + +def serialpacker(data: dict): + return ",".join([str(v) for v in data.values()]) + + +_packstyles = { + "json": json.dumps, + "pickle": pickle.dumps, + "serial": serialpacker, + "default": json.dumps, +} + +_unpackstyles = { + "json": json.loads, + "pickle": pickle.loads, + "default": json.loads, +} diff --git a/heisskleber/network/pubsub/__init__.py b/heisskleber/network/pubsub/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/heisskleber/network/pubsub/factories.py b/heisskleber/network/pubsub/factories.py new file mode 100644 index 0000000..a7dfdf9 --- /dev/null +++ b/heisskleber/network/pubsub/factories.py @@ -0,0 +1,52 @@ +import os + +from heisskleber.config import load_config +from heisskleber.network.mqtt import MqttConf, MqttPublisher, MqttSubscriber +from heisskleber.network.serial import SerialConf, SerialPublisher, SerialSubscriber +from heisskleber.network.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber + +_registered_publishers = { + "zmq": (ZmqPublisher, ZmqConf), + "mqtt": (MqttPublisher, MqttConf), + "serial": (SerialPublisher, SerialConf), +} + +_registered_subscribers = { + "zmq": (ZmqSubscriber, ZmqConf), + "mqtt": (MqttSubscriber, MqttConf), + "serial": (SerialSubscriber, SerialConf), +} + + +def get_publisher(name: str): + if name not in _registered_publishers: + error_message = f"{name} is not a registered Publisher." + raise KeyError(error_message) + + pub_cls, conf_cls = _registered_publishers[name] + + if "MSB_CONFIG_DIR" in os.environ: + print(f"loading {name} config") + config = load_config(conf_cls(), name, read_commandline=False) + else: + print(f"using default {name} config") + config = conf_cls() + + return pub_cls(config) + + +def get_subscriber(name: str, topic): + if name not in _registered_publishers: + error_message = f"{name} is not a registered Subscriber." + raise KeyError(error_message) + + sub_cls, conf_cls = _registered_subscribers[name] + + if "MSB_CONFIG_DIR" in os.environ: + print(f"loading {name} config") + config = load_config(conf_cls(), name, read_commandline=False) + else: + print(f"using default {name} config") + config = conf_cls() + + return sub_cls(topic, config) diff --git a/heisskleber/network/pubsub/types.py b/heisskleber/network/pubsub/types.py new file mode 100644 index 0000000..8fe5184 --- /dev/null +++ b/heisskleber/network/pubsub/types.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + + +class Publisher(ABC): + """ + Publisher interface. + """ + + @abstractmethod + def send(self, topic: str | bytes, data: dict): + """ + Send data via the implemented output stream. + """ + pass + + +class Subscriber(ABC): + """ + Subscriber interface + """ + + @abstractmethod + def receive(self) -> tuple[bytes, dict]: + """ + Blocking function to receive data from the implemented input stream. + + Data is returned as a tuple of (topic, data). + """ + pass diff --git a/heisskleber/network/serial/__init__.py b/heisskleber/network/serial/__init__.py new file mode 100644 index 0000000..12cd737 --- /dev/null +++ b/heisskleber/network/serial/__init__.py @@ -0,0 +1,3 @@ +from .config import SerialConf # noqa: F401 +from .publisher import SerialPublisher # noqa: F401 +from .subscriber import SerialSubscriber # noqa: F401 diff --git a/heisskleber/network/serial/config.py b/heisskleber/network/serial/config.py new file mode 100644 index 0000000..a845f1e --- /dev/null +++ b/heisskleber/network/serial/config.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + +from heisskleber.config import BaseConf + + +@dataclass +class SerialConf(BaseConf): + port: str = "/dev/serial0" + baudrate: int = 9600 + bytesize: int = 8 + encoding: str = "ascii" diff --git a/heisskleber/network/serial/forwarder.py b/heisskleber/network/serial/forwarder.py new file mode 100644 index 0000000..098b661 --- /dev/null +++ b/heisskleber/network/serial/forwarder.py @@ -0,0 +1,30 @@ +from heisskleber.network.types import Subscriber + +from .publisher import SerialPublisher + + +class SerialForwarder: + def __init__(self, subscriber: Subscriber, publisher: SerialPublisher): + self.sub = subscriber + self.pub = publisher + + """ + Wait for message and forward + """ + + def forward_message(self): + # collected = {} + # for sub in self.sub: + # topic, data = sub.receive() + # collected.update(data) + _, collected = self.sub.receive() + + self.pub.send(collected) + + """ + Enter loop and continuously forward messages + """ + + def sub_pub_loop(self): + while True: + self.forward_message() diff --git a/heisskleber/network/serial/publisher.py b/heisskleber/network/serial/publisher.py new file mode 100644 index 0000000..7f67138 --- /dev/null +++ b/heisskleber/network/serial/publisher.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from types import FunctionType + +import serial + +from heisskleber.network.packer import get_packer +from heisskleber.network.pubsub.types import Publisher + +from .config import SerialConf + + +class SerialPublisher(Publisher): + """ + Publisher for serial devices. + Can be used everywhere that a flucto style publishing connection is required. + + Parameters + ---------- + config : SerialConf + Configuration for the serial connection. + pack_func : FunctionType + Function to translate from a dict to a serialized string. + """ + + def __init__(self, config: SerialConf, pack_func: FunctionType | None = None): + self.config = config + self.packer = pack_func if pack_func else get_packer("serial") + self._connect() + + def _connect(self): + self.serial: serial.Serial = serial.Serial( + port=self.config.port, + baudrate=self.config.baudrate, + bytesize=self.config.bytesize, + parity=serial.PARITY_NONE, + stopbits=serial.STOPBITS_ONE, + ) + print(f"Successfully connected to serial device at port {self.config.port}") + + def send(self, message: object): + """ + Takes python dictionary, serializes it according to the packstyle + and sends it to the broker. + + Please note that this does not adhere to the interface, as there is no topic. + + Parameters + ---------- + message : object + object to be serialized and sent via the serial connection. Usually a dict. + """ + payload = self.packer(message) + self.serial.write(payload.encode(self.config.encoding)) + self.serial.flush() + if self.config.verbose: + print(payload) + + def __del__(self): + if not hasattr(self, "serial"): + return + if not self.serial.is_open: + return + self.serial.flush() + self.serial.close() diff --git a/heisskleber/network/serial/subscriber.py b/heisskleber/network/serial/subscriber.py new file mode 100644 index 0000000..d6c203f --- /dev/null +++ b/heisskleber/network/serial/subscriber.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import Callable, Optional + +import serial + +from heisskleber.network.pubsub.types import Subscriber + +from .config import SerialConf + + +class SerialSubscriber(Subscriber): + """ + Subscriber for serial devices. Connects to a serial port and reads from it. + + Parameters + ---------- + topics : + Placeholder for topic. Not used. + + config : SerialConf + Configuration class for the serial connection. + + unpack_func : FunctionType + Function to translate from a serialized string to a dict. + """ + + def __init__( + self, + topics, + config: SerialConf, + unpack_func: Optional[Callable] = None, # noqa: UP007 + ): + self.config = config + self.unpack = unpack_func if unpack_func else lambda x: x + self._connect() + + def _connect(self): + self.serial: serial.Serial = serial.Serial( + port=self.config.port, + baudrate=self.config.baudrate, + bytesize=self.config.bytesize, + parity=serial.PARITY_NONE, + stopbits=serial.STOPBITS_ONE, + ) + print(f"Successfully connected to serial device at port {self.config.port}") + + def receive(self) -> dict: + """ + Wait for data to arrive on the serial port and return it. + + Returns + ------- + :return: (topic, payload) + topic is a placeholder to adhere to the Subscriber interface + payload is a dictionary containing the data from the serial port + """ + # message is a string + message = next(self.read_serial_port()) + # payload is a dictionary + payload = self.unpack(message) + # port is a placeholder for topic + return self.config.port, payload + + def read_serial_port(self) -> str: + buffer = "" + while True: + try: + buffer = self.serial.readline().decode() + yield buffer + except UnicodeError as e: + if self.config.verbose: + print(f"Could not decode: {message}") + print(e) + continue + + def __del__(self): + if not hasattr(self, "serial"): + return + if not self.serial.is_open: + return + self.serial.flush() + self.serial.close() + + +if __name__ == "__main__": + config = SerialConf() + serial_reader = SerialSubscriber(config) + for message in serial_reader.receive(): + print(message) diff --git a/heisskleber/network/types.py b/heisskleber/network/types.py new file mode 100644 index 0000000..d701017 --- /dev/null +++ b/heisskleber/network/types.py @@ -0,0 +1 @@ +from .pubsub.types import Publisher, Subscriber # noqa: F401 diff --git a/heisskleber/network/udp/__init__.py b/heisskleber/network/udp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/heisskleber/network/udp/config.py b/heisskleber/network/udp/config.py new file mode 100644 index 0000000..51c9ad9 --- /dev/null +++ b/heisskleber/network/udp/config.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +from heisskleber.config import BaseConf + + +@dataclass +class UDPConf(BaseConf): + """ + UDP configuration. + """ + + port: int = 1234 + ip: str = "127.0.0.1" + packer: str = "json" diff --git a/heisskleber/network/udp/publisher.py b/heisskleber/network/udp/publisher.py new file mode 100644 index 0000000..5499a69 --- /dev/null +++ b/heisskleber/network/udp/publisher.py @@ -0,0 +1,49 @@ +import socket + +from heisskleber.network.packer import get_packer +from heisskleber.network.pubsub.types import Publisher +from heisskleber.network.udp.config import UDPConf + + +class UDP_Publisher(Publisher): + def __init__(self, config): + self.config = config + self.ip = self.config.ip + self.port = self.config.port + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.packer = get_packer(self.config.packer) + + def send(self, topic, message): + payload = self.packer(message) + payload = payload.encode("utf-8") + self.socket.sendto(payload, (self.ip, self.port)) + + def __del__(self): + self.socket.close() + + +def udp_sender(): + target_ip = "127.0.0.1" # Replace this with the receiver's IP address + target_port = 12345 # Replace this with the receiver's port number + + message = "Hello, UDP Receiver!" + + # Create a UDP socket + udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + try: + # Send the message to the receiver + udp_socket.sendto(message.encode("utf-8"), (target_ip, target_port)) + print("Message sent successfully!") + except Exception as e: + print("Error occurred while sending the message:", str(e)) + finally: + udp_socket.close() + + +if __name__ == "__main__": + conf = UDPConf(ip="192.168.1.122", port=12345) + pub = UDP_Publisher(conf) + + pub.send("test", {"test": "test"}) + # pub.send("test", "Hi from pub") diff --git a/heisskleber/network/udp/subscriber.py b/heisskleber/network/udp/subscriber.py new file mode 100644 index 0000000..ffaab08 --- /dev/null +++ b/heisskleber/network/udp/subscriber.py @@ -0,0 +1,34 @@ +import socket + +from heisskleber.network.packer import get_unpacker +from heisskleber.network.pubsub.types import Subscriber +from heisskleber.network.udp.config import UDPConf + + +class UDP_Subscriber(Subscriber): + def __init__(self, config, topic=None): + self.config = config + self.ip = self.config.ip + self.port = self.config.port + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.bind((self.ip, self.port)) + self.unpacker = get_unpacker(self.config.packer) + + def receive(self): + payload, addr = self.socket.recvfrom(1024) + return addr, self.unpacker(payload.decode("utf-8")) + + def listen_loop(self): + while True: + addr, data = self.receive() + print(type(data)) + print(data) + + def __del__(self): + self.socket.close() + + +if __name__ == "__main__": + conf = UDPConf(ip="192.168.1.122", port=12345) + sub = UDP_Subscriber(conf) + sub.listen_loop() diff --git a/heisskleber/network/zmq/__init__.py b/heisskleber/network/zmq/__init__.py new file mode 100644 index 0000000..e66b53b --- /dev/null +++ b/heisskleber/network/zmq/__init__.py @@ -0,0 +1,3 @@ +from .config import ZmqConf # noqa: F401 +from .publisher import ZmqPublisher # noqa: F401 +from .subscriber import ZmqSubscriber # noqa: F401 diff --git a/heisskleber/network/zmq/config.py b/heisskleber/network/zmq/config.py new file mode 100644 index 0000000..2243d2d --- /dev/null +++ b/heisskleber/network/zmq/config.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +from heisskleber.config import BaseConf + + +@dataclass +class ZmqConf(BaseConf): + protocol: str = "tcp" + interface: str = "127.0.0.1" + publisher_port: int = 5555 + subscriber_port: int = 5556 + packstyle: str = "json" + + @property + def publisher_address(self): + return f"{self.protocol}://{self.interface}:{self.publisher_port}" + + @property + def subscriber_address(self): + return f"{self.protocol}://{self.interface}:{self.subscriber_port}" diff --git a/heisskleber/network/zmq/publisher.py b/heisskleber/network/zmq/publisher.py new file mode 100644 index 0000000..622506b --- /dev/null +++ b/heisskleber/network/zmq/publisher.py @@ -0,0 +1,33 @@ +import sys + +import zmq + +from heisskleber.network.packer import get_packer +from heisskleber.network.pubsub.types import Publisher + +from .config import ZmqConf + + +class ZmqPublisher(Publisher): + def __init__(self, config: ZmqConf): + self.config = config + + self.context = zmq.Context.instance() + self.socket = self.context.socket(zmq.PUB) + + self.pack = get_packer(config.packstyle) + self.connect() + + def connect(self): + try: + self.socket.connect(self.config.publisher_address) + except Exception as e: + print(f"failed to bind to zeromq socket: {e}") + sys.exit(-1) + + def send(self, topic: bytes, data: dict): + data = self.pack(data) + self.socket.send_multipart([topic, data.encode("utf-8")]) + + def __del__(self): + self.socket.close() diff --git a/heisskleber/network/zmq/subscriber.py b/heisskleber/network/zmq/subscriber.py new file mode 100644 index 0000000..245a9cb --- /dev/null +++ b/heisskleber/network/zmq/subscriber.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import sys + +import zmq + +from heisskleber.network.packer import get_unpacker +from heisskleber.network.pubsub.types import Subscriber + +from .config import ZmqConf + + +class ZmqSubscriber(Subscriber): + def __init__(self, topic: bytes | str | list[bytes] | list[str], config: ZmqConf): + self.config = config + + self.context = zmq.Context.instance() + self.socket = self.context.socket(zmq.SUB) + self.connect() + self.subscribe(topic) + + self.unpack = get_unpacker(config.packstyle) + + def connect(self): + try: + # print(f"Connecting to { self.config.consumer_connection }") + self.socket.connect(self.config.subscriber_address) + except Exception as e: + print(f"failed to bind to zeromq socket: {e}") + sys.exit(-1) + + def _subscribe_single_topic(self, topic: bytes | str): + if isinstance(topic, str): + topic = topic.encode() + self.socket.setsockopt(zmq.SUBSCRIBE, topic) + + def subscribe(self, topic: bytes | str | list[bytes] | list[str]): + # Accepts single topic or list of topics + if isinstance(topic, (list, tuple)): + for t in topic: + self._subscribe_single_topic(t) + else: + self._subscribe_single_topic(topic) + + def receive(self) -> tuple[bytes, dict]: + """ + reads a message from the zmq bus and returns it + + Returns: + tuple(topic: bytes, message: dict): the message received + """ + (topic, message) = self.socket.recv_multipart() + message = self.unpack(message.decode()) + return (topic, message) + + def __del__(self): + self.socket.close()