Refactor heisskleber core, remove synchronous implementations (#156)

* #129 AsyncTcpSource enhancements (see the sketch after this list)
- retry connection on startup (behavior is configurable)
- reconnect if receiving data fails (EOF received)
- add Python logging
- add unit tests
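
The retry and reconnect behavior could look roughly like this (a minimal sketch under assumed `host`, `port`, and `retry_interval` parameters, not the actual AsyncTcpSource code):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)

async def connect_with_retry(host: str, port: int, retry_interval: float = 1.0):
    """Keep trying to open a TCP connection until it succeeds."""
    while True:
        try:
            return await asyncio.open_connection(host, port)
        except OSError:
            logger.warning("Connection to %s:%s failed, retrying in %ss", host, port, retry_interval)
            await asyncio.sleep(retry_interval)

async def read_lines(host: str, port: int):
    """Yield lines from the socket, reconnecting when the peer sends EOF."""
    reader, writer = await connect_with_retry(host, port)
    while True:
        line = await reader.readline()
        if not line:  # EOF received: the peer closed the connection
            logger.info("EOF received, reconnecting")
            writer.close()
            reader, writer = await connect_with_retry(host, port)
            continue
        yield line
```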

* Remove synchronous implementations.

* WIP: Refactor packer/unpacker

* Refactor type hints and topic handling in console sink.

* Remove comma from tcp config enum definitions

* Remove references to deleted synchronous classes.

* Hopefully stable interface for Packer and Unpacker.

* WIP: Working with protocols and generics

* Finalized Sink, Source definition.

* Rename mqtt source and sink files

* Rename mqtt publisher and subscriber.

* Make start() async.

* Update documentation.

* Remove recursion from udp source.

* Rename unpack to unpacker for consistency.

* Renaming in tests.

* Make MqttSource generic.

* Configure pyproject.toml to move to uv

* Add nox support.

* Update documentation with myst-parser and sphinx.

* Mess with autogeneration of __call__ signatures.

* Add dynamic versioning to hatch

* Asyncio wrapper for pyserial (see the sketch below).
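
One common way to wrap pyserial for asyncio is to push the blocking reads onto a worker thread (a minimal sketch of the general technique, not the actual heisskleber implementation):

```python
import asyncio

import serial

class AsyncSerialReader:
    """Run blocking pyserial reads off the event loop thread."""

    def __init__(self, port: str, baudrate: int = 9600) -> None:
        self._ser = serial.Serial(port=port, baudrate=baudrate)

    async def readline(self) -> bytes:
        # serial.Serial.readline() blocks, so hand it to a worker thread.
        return await asyncio.to_thread(self._ser.readline)

    def close(self) -> None:
        self._ser.close()
```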

* Add docstrings for serial sink and source.

* Refactor config handling (#171)

* Removes deprecated "verbose" and "print_stdout" parameters

* Adds class methods for config generation from a dictionary or file (YAML or JSON at this point)

* Run-time type checking via the __post_init__() method (see the sketch below)
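
A rough sketch of what such a config class might look like (the class name, fields, and helpers below are illustrative assumptions, not heisskleber's actual API):

```python
import dataclasses
import json
from pathlib import Path
from typing import Any

import yaml

@dataclasses.dataclass
class ExampleConf:
    host: str = "localhost"
    port: int = 1883

    def __post_init__(self) -> None:
        # Run-time type checking against the declared field types.
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if isinstance(field.type, type) and not isinstance(value, field.type):
                msg = f"{field.name} must be {field.type.__name__}, got {type(value).__name__}"
                raise TypeError(msg)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ExampleConf":
        return cls(**data)

    @classmethod
    def from_file(cls, path: str | Path) -> "ExampleConf":
        path = Path(path)
        text = path.read_text()
        data = yaml.safe_load(text) if path.suffix in {".yaml", ".yml"} else json.loads(text)
        return cls.from_dict(data)
```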

* Add serial dependency.

* WIP

* Move broker to bin/

* Update docs.

* WIP: Need to update docstrings to make ruff happy.

* Move source files to src/

* Fix tests for TcpSource.

* WIP: Remove old tests.

* Fix docstrings in mqtt classes.

* Make json_unpacker the default TCP unpacker.

* Don't fail the test run if no tests are collected

* Update test pipeline

* Update ruff pre-commit

* Updated ruff formatting

* Format bin/

* Fix type hints

* No type checking

* Make stop() async

* Only test on ubuntu for now

* Don't be so strict about sphinx warnings.

* Rename TestConf for pytest naming compatibility.

* Install package in editable mode for CI tests.

* Update dependencies for docs generation.

* Add keepalive and will to MQTT, fixes #112 (see the sketch below).
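
A rough sketch of how keepalive and a last-will message can be wired up with aiomqtt (the topic and values are illustrative assumptions, not heisskleber's actual config fields):

```python
import aiomqtt

# Hypothetical values; MqttConf's actual field names may differ.
client = aiomqtt.Client(
    hostname="mqtt.example.com",
    port=1883,
    keepalive=60,  # ping the broker every 60 seconds
    will=aiomqtt.Will(
        topic="/status/box-01",
        payload=b"offline",  # published by the broker on ungraceful disconnect
        qos=1,
        retain=True,
    ),
)
```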

* Update readme to reflect changes in usage.

* Requested fixes for console adapters.

* Raise correct errors in unpacker and packer.

* Correct logger name for mqtt sink.

* Add config options for stopbits and parity to Serial (see the sketch below).
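
A sketch of what those options might look like on the config dataclass, assuming pyserial's constants are used (field names are assumptions):

```python
from dataclasses import dataclass

import serial

@dataclass
class SerialConf:
    port: str = "/dev/serial0"
    baudrate: int = 9600
    bytesize: int = 8
    stopbits: int = serial.STOPBITS_ONE  # 1, 1.5 or 2 stop bits
    parity: str = serial.PARITY_NONE     # 'N', 'E', 'O', ...
```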

* Remove exception logging call from yaml parser.

* Add comments to clarify a very implicit test.

* Rename Sink -> Sender, Source -> Receiver.

* Rename sink and source in tests.

* Fix tests.

---------

Co-authored-by: Adrian Weiler <a.weiler@aldea.de>
Author: Felix Weiler
Date: 2024-12-09 19:32:34 +01:00
Committed by: GitHub
Parent: 1f6e17bd42
Commit: 98099f5b00

145 changed files with 4394 additions and 6590 deletions

.git_archival.txt (new file)

@@ -0,0 +1,3 @@
node: $Format:%H$
node-date: $Format:%cI$
describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$

.gitattributes

@@ -1 +1,2 @@
* text=auto eol=lf
.git_archival.txt export-subst


@@ -1,51 +1,82 @@
name: Tests
name: CI
on:
workflow_dispatch:
pull_request:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
FORCE_COLOR: 3
jobs:
tests:
pre-commit:
name: Format
runs-on: ubuntu-latest
steps:
- name: Check out the repository
uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v5.0.0
- uses: actions/checkout@v4
with:
python-version: "3.10"
- name: Cache poetry install
uses: actions/cache@v4
fetch-depth: 0
- uses: actions/setup-python@v5
with:
path: ~/.local
key: poetry-1.7.1-0
- name: Install poetry
uses: snok/install-poetry@v1
python-version: "3.x"
- uses: pre-commit/action@v3.0.1
with:
version: 1.7.1
virtualenvs-create: true
virtualenvs-in-project: true
extra_args: --hook-stage manual --all-files
- name: Run Ruff
run: |
pip install ruff
ruff check .
- name: cache deps
id: cache-deps
uses: actions/cache@v4
tests:
name: Python ${{ matrix.python-version }} on ${{ matrix.runs-on }}
runs-on: ${{ matrix.runs-on }}
needs: [pre-commit]
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
# python-version: ["3.11", "3.12", "3.13"]
runs-on: [ubuntu-latest]
steps:
- uses: actions/checkout@v4
with:
path: .venv
key: pydeps-${{ hashFiles('**/poetry.lock') }}
fetch-depth: 0
- run: poetry install --no-interaction # --no-root
if: steps.cache-deps.outputs.cache-hit != 'true'
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
allow-prereleases: true
# - run: poetry install --no-interaction
- name: Install dependencies
run: python -m pip install -e ".[test,docs]"
- run: poetry run pytest --cov --cov-report xml
- name: Run tests
run: python -m pytest -ra --cov --cov-report=xml --cov-report=term --durations=20
- name: Upload coverage reports to Codecov
- name: Upload coverage
uses: codecov/codecov-action@v4
env:
CODECOV_TOKEN: ${{secrets.CODECOV_TOKEN}}
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
with:
files: ./coverage.xml
fail_ci_if_error: true
docs:
runs-on: ubuntu-latest
needs: [pre-commit]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Build docs
run: |
pip install ".[docs]"
sphinx-build -b html docs docs/_build/html


@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v4.4.0"
rev: "v5.0.0"
hooks:
- id: check-case-conflict
- id: check-merge-conflict
@@ -10,10 +10,8 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.5
rev: "v0.7.4"
hooks:
# Run the linter.
- id: ruff
# Run the formatter.
args: ["--fix", "--show-fixes"]
- id: ruff-format


@@ -1,12 +1,17 @@
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2
build:
os: ubuntu-20.04
os: ubuntu-22.04
tools:
python: "3.10"
sphinx:
configuration: docs/conf.py
formats: all
python:
install:
- requirements: docs/requirements.txt
- path: .
python: "3.12"
commands:
- asdf plugin add uv
- asdf install uv latest
- asdf global uv latest
- uv venv
- uv pip install .[docs]
- .venv/bin/python -m sphinx -T -b html -d docs/_build/doctrees -D language=en docs $READTHEDOCS_OUTPUT/html


@@ -23,9 +23,8 @@ Heisskleber is a versatile library designed to seamlessly "glue" together variou
## Features
- Multiple Protocol Support: Easy integration with zmq, mqtt, udp, serial, influxdb, and cmdline. Future plans include REST API and file operations.
- Custom Data Handling: Customizable "unpack" and "pack" functions allow for the translation of any data format (e.g., ascii encoded, comma-separated messages from a serial bus) into dictionaries for easy manipulation and transmission.
- Synchronous & Asynchronous Versions: Cater to different programming needs and scenarios with both sync and async interfaces.
- Multiple Protocol Support: Easy integration with zmq, mqtt, udp, serial, and cmdline. Future plans include REST API and file operations.
- Custom Data Handling: Customizable "unpacker" and "packer" functions allow for the translation of any data format (e.g., ascii encoded, comma-separated messages from a serial bus) into dictionaries for easy manipulation and transmission.
- Extensible: Designed for easy extension with additional protocols and data handling functions.
## Installation
@@ -36,56 +35,35 @@ You can install _Heisskleber_ via [pip] from [PyPI]:
$ pip install heisskleber
```
Configuration files for zmq, mqtt and other heisskleber related settings should be placed in the user's config directory, usually `$HOME/.config`. Config file templates can be found in the `config`
directory of the package.
## Quick Start
Here's a simple example to demonstrate how Heisskleber can be used to connect a zmq source to an mqtt sink:
```python
"""
A simple forwarder that takes messages from
A simple forwarder that takes messages from a serial device and publishes them via MQTT.
"""
import asyncio
from heisskleber.serial import SerialSubscriber, SerialConf
from heisskleber.mqtt import MqttPublisher, MqttConf
source = SerialSubscriber(config=SerialConf(port="/dev/ACM0"))
sink = MqttPublisher(config=MqttConf(host="127.0.0.1", port=1883, user="", password=""))
while True:
topic, data = source.receive()
sink.send(data, topic="/hostname/" + topic)
async def main():
source = SerialSubscriber(config=SerialConf(port="/dev/ACM0", baudrate=9600))
sink = MqttPublisher(config=MqttConf(host="mqtt.example.com", port=1883, user="", password=""))
while True:
data, metadata = await source.receive()
await sink.send(data, topic="/hotglue/" + metadata.get("topic", "serial"))
asyncio.run(main())
```
All sources and sinks come with customizable "unpack" and "pack" functions, making it simple to work with various data formats.
It is also possible to do configuration via yaml files, placed at `$HOME/.config/heisskleber` and named according to the protocol in question.
All sources and sinks come with customizable "unpacker" and "packer" functions, making it simple to work with various data formats.
See the [documentation][read the docs] for detailed usage.
## Development
1. Install poetry
```
curl -sSL https://install.python-poetry.org | python3 -
```
2. clone repository
```
git clone https://github.com/flucto-gmbh/heisskleber.git
cd heisskleber
```
3. setup
```
make install
```
## License
Distributed under the terms of the [MIT license][license],

bin/zmq_broker.py (new file)

@@ -0,0 +1,52 @@
# /// script
# dependencies = [
# "pyzmq",
# "heisskleber"
# ]
# ///
import argparse
import logging
import sys
import zmq
from heisskleber.zmq import ZmqConf
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - ZmqBroker - %(levelname)s - %(message)s")
def main() -> None:
"""Run ZMQ broker."""
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", type=str, required=True, help="ZMQ configuration file (yaml or json)")
args = parser.parse_args()
config = ZmqConf.from_file(args.config)
try:
ctx = zmq.Context()
logger.info("Creating XPUB socket")
xpub = ctx.socket(zmq.XPUB)
logger.info("Creating XSUB socket")
xsub = ctx.socket(zmq.XSUB)
logger.info("Connecting XPUB socket to %(addr)s", {"addr": config.subscriber_address})
xpub.bind(config.subscriber_address)
logger.info("Connecting XSUB socket to %(addr)s", {"addr": config.publisher_address})
xsub.bind(config.publisher_address)
logger.info("Starting proxy...")
zmq.proxy(xpub, xsub)
except Exception:
logger.exception("Oh no! ZMQ broker failed!")
sys.exit(-1)
if __name__ == "__main__":
main()


@@ -1,12 +1,74 @@
"""Sphinx configuration."""
from __future__ import annotations
import importlib.metadata
from typing import Any
project = "Heisskleber"
author = "Felix Weiler"
copyright = "2023, Felix Weiler"
author = "Felix Weiler-Detjen"
copyright = "2023, Flucto GmbH"
version = release = importlib.metadata.version("heisskleber")
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"myst_parser",
] # , "autodoc2"
# autodoc2_packages = ["../heisskleber"]
autodoc_typehints = "description"
"sphinx.ext.autodoc",
"sphinx.ext.intersphinx",
"sphinx.ext.mathjax",
"sphinx.ext.napoleon",
"sphinx_autodoc_typehints",
"sphinx_copybutton",
]
autodoc_typehints = "description" # or 'signature' or 'both'
autodoc_type_aliases = {
"T": "heisskleber.core.T",
"T_co": "heisskleber.core.T_co",
"T_contra": "heisskleber.core.T_contra",
}
# If you're using typing.TypeVar in your code:
nitpicky = True
nitpick_ignore = [
("py:class", "T"),
("py:class", "T_co"),
("py:class", "T_contra"),
("py:data", "typing.Any"),
("py:class", "_io.StringIO"),
("py:class", "_io.BytesIO"),
]
source_suffix = [".rst", ".md"]
exclude_patterns = [
"_build",
"**.ipynb_checkpoints",
"Thumbs.db",
".DS_Store",
".env",
".venv",
]
html_theme = "furo"
html_theme_options: dict[str, Any] = {
"footer_icons": [
{
"name": "GitHub",
"url": "https://github.com/flucto-gmbh/heisskleber",
"html": """
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
</svg>
""",
"class": "",
},
],
"source_repository": "https://github.com/flucto-gmbh/heisskleber",
"source_branch": "main",
"source_directory": "docs/",
}
always_document_param_types = True

docs/development.md (new file)

@@ -0,0 +1,7 @@
# How to contribute to development
1. Fork repository
2. Set up development environment
- Install uv, if you don't have it already: `curl -LsSf https://astral.sh/uv/install.sh | sh` (Or install from package manager, if applicable)
- Set the Python version (`uv venv --python 3.10`), or whichever version you like as long as it's 3.10+ (careful if you use pyenv)


@@ -5,6 +5,7 @@ end-before: <!-- github-only -->
```
[license]: license
[serializing]: packer_and_unpacker
```{toctree}
---
@@ -14,6 +15,8 @@ maxdepth: 1
yaml-config
reference
serialization
development
License <license>
Changelog <https://github.com/flucto-gmbh/heisskleber/releases>
```


@@ -1,45 +1,115 @@
# Reference
## Network
```{eval-rst}
.. automodule:: heisskleber.mqtt
:members:
.. automodule:: heisskleber.zmq
:members:
.. automodule:: heisskleber.serial
:members:
.. automodule:: heisskleber.udp
:members:
```
## Baseclasses
```{eval-rst}
.. automodule:: heisskleber.core.types
.. autoclass:: heisskleber.core.AsyncSink
:members:
.. autoclass:: heisskleber.core.AsyncSource
:members:
```
## Stream
## Serialization
Work on streaming data.
See <project:serialization.md> for a tutorial on how to implement custom packer and unpacker for (de-)serialization.
```{eval-rst}
.. automodule:: heisskleber.stream.filter
:members: __aiter__
.. autoclass:: heisskleber.core::Packer
.. automodule:: heisskleber.stream.butter
:members:
.. autoclass:: heisskleber.core::Unpacker
.. automodule:: heisskleber.stream.gh-filter
:members:
.. autoclass:: heisskleber.core.unpacker::JSONUnpacker
.. autoclass:: heisskleber.core.packer::JSONPacker
```
## Config
### Loading configs
### Errors
```{eval-rst}
.. automodule:: heisskleber.config
:members:
.. autoclass:: heisskleber.core::UnpackError
.. autoclass:: heisskleber.core::PackerError
```
## Implementations (Adapters)
### MQTT
```{eval-rst}
.. automodule:: heisskleber.mqtt
:no-members:
.. autoclass:: heisskleber.mqtt.MqttSink
:members: send
.. autoclass:: heisskleber.mqtt.MqttSource
:members: receive, subscribe
.. autoclass:: heisskleber.mqtt.MqttConf
:members:
```
### ZMQ
```{eval-rst}
.. autoclass:: heisskleber.zmq::ZmqConf
```
```{eval-rst}
.. autoclass:: heisskleber.zmq::ZmqSink
:members: send
```
```{eval-rst}
.. autoclass:: heisskleber.zmq::ZmqSource
:members: receive
```
### Serial
```{eval-rst}
.. autoclass:: heisskleber.serial::SerialConf
```
```{eval-rst}
.. autoclass:: heisskleber.serial::SerialSink
:members: send
```
```{eval-rst}
.. autoclass:: heisskleber.serial::SerialSource
:members: receive
```
### TCP
```{eval-rst}
.. autoclass:: heisskleber.tcp::TcpConf
```
```{eval-rst}
.. autoclass:: heisskleber.tcp::TcpSink
:members: send
```
```{eval-rst}
.. autoclass:: heisskleber.tcp::TcpSource
:members: receive
```
### UDP
```{eval-rst}
.. autoclass:: heisskleber.udp::UdpConf
```
```{eval-rst}
.. autoclass:: heisskleber.udp::UdpSink
:members: send
```
```{eval-rst}
.. autoclass:: heisskleber.udp::UdpSource
:members: receive
```


@@ -1,3 +1,3 @@
furo==2024.1.29
sphinx==7.2.6
myst_parser==2.0.0
furo==2024.8.6
sphinx==8.1.3
myst_parser==4.0.0

docs/serialization.md (new file)

@@ -0,0 +1,72 @@
# Serialization
## Implementing a custom Packer
The Packer class is defined in heisskleber.core.packer.py as a Protocol ([see PEP 544](https://peps.python.org/pep-0544/)).
```python
T = TypeVar("T", contravariant=True)

class Packer(Protocol[T]):
    def __call__(self, data: T) -> bytes:
        pass
```
Users can create custom Packers with arbitrary input data types, either as callable classes, subclasses of the Packer class, or plain functions.
Please note that, to satisfy type checkers, the argument must be named `data`; being Python, this is of course not enforced at runtime.
The AsyncSink's type is determined by the concrete packer implementation: if your Packer packs strings to bytes, the AsyncSink will be of type `AsyncSink[str]`,
indicating that the send function accepts strings only, as in the example below:
```python
from heisskleber import MqttSink, MqttConf

def string_packer(data: str) -> bytes:
    return data.encode("ascii")

async def main():
    sink = MqttSink(MqttConf(), packer=string_packer)
    await sink.send("Hi there!")      # This is fine
    await sink.send({"data": 3.14})   # Type checker will complain
```
Heisskleber comes with default packers, such as the JSONPacker, which can be imported as json_packer from heisskleber.core and is the default value for most Sinks.
## Implementing a custom Unpacker
The unpacker's responsibility is to create usable data from serialized byte strings.
This may be a serialized JSON string that is unpacked into a dictionary, but it could be anything the user defines.
The Unpacker protocol is defined in heisskleber.core.unpacker.py.
```python
class Unpacker(Protocol[T]):
    def __call__(self, payload: bytes) -> tuple[T, dict[str, Any]]:
        pass
```
Here, the payload is fixed to be of type bytes and the return type is a combination of a user-defined data type and a dictionary of meta-data.
```{eval-rst}
.. note::
   The extra dictionary may be updated by the Source, e.g. the MqttSource will add a "topic" field received from the MQTT node.
```
The receive function of an AsyncSource object will have its return type informed by the signature of the unpacker.
```python
import time
from typing import Any

from heisskleber import MqttSource, MqttConf

def csv_unpacker(payload: bytes) -> tuple[list[float], dict[str, Any]]:
    # Unpack a utf-8 encoded csv string, such as b'1,42,3.14,100.0', to [1.0, 42.0, 3.14, 100.0]
    # and add some exemplary meta data.
    return [float(chunk) for chunk in payload.decode().split(",")], {"processed_at": time.time()}

async def main():
    sub = MqttSource(MqttConf(), unpacker=csv_unpacker)
    data, extra = await sub.receive()
    assert isinstance(data, list)  # passes
```
## Error handling
To be implemented...
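
In the meantime, a plausible pattern (a speculative sketch, assuming `UnpackError` is importable from `heisskleber.core` and accepts a message argument) is to wrap low-level parsing failures in the library's error type:

```python
from typing import Any

from heisskleber.core import UnpackError

def strict_csv_unpacker(payload: bytes) -> tuple[list[float], dict[str, Any]]:
    try:
        return [float(chunk) for chunk in payload.decode().split(",")], {}
    except (UnicodeDecodeError, ValueError) as e:
        # Surface malformed payloads as the library's unpack error.
        raise UnpackError(f"could not parse payload: {payload!r}") from e
```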


@@ -11,10 +11,6 @@ The configuration parameters are host, port, ssl, user and password to establish
- 2: "At least once", where messages are assured to arrive but duplicates can occur.
- 3: "Exactly once", where messages are assured to arrive exactly once.
- **max_saved_messages**: maximum number of messages that will be saved in the buffer until connection is available.
- **packstyle**: key of the serialization technique to use. Currently only JSON is supported.
- **source_id**: id of the device that will be used to identify the MQTT messages to be used by clients to format the topic.
Suggested topic format is in the form of `f"/{measurement_type}/{source_id}"`, eg. "/temperature/box-01".
- **topics**: the topics that the mqtt forwarder will subscribe to.
```yaml
# Heisskleber config file for MqttConf
@@ -27,10 +23,4 @@ qos: 0 # quality of service, 0=at most once, 1=at least once, 2=exactly once
timeout_s: 60
retain: false # save last message
max_saved_messages: 100 # buffer messages in until connection available
packstyle: json
# configs only valid for mqtt forwarder
mapping: /deprecated/
source_id: box-01
topics: ["topic1", "topic2"]
```


@@ -1,19 +0,0 @@
"""Heisskleber."""
from .core.async_factories import get_async_sink, get_async_source
from .core.factories import get_publisher, get_sink, get_source, get_subscriber
from .core.types import AsyncSink, AsyncSource, Sink, Source
__all__ = [
"get_source",
"get_sink",
"get_publisher",
"get_subscriber",
"get_async_source",
"get_async_sink",
"Sink",
"Source",
"AsyncSink",
"AsyncSource",
]
__version__ = "0.5.7"


@@ -1,3 +0,0 @@
from .zmq_broker import zmq_broker as start_zmq_broker
__all__ = ["start_zmq_broker"]


@@ -1,60 +0,0 @@
import sys
import zmq
from zmq import Socket
from heisskleber.config import load_config
from heisskleber.zmq.config import ZmqConf as BrokerConf
class BrokerBindingError(Exception):
pass
def bind_socket(socket: Socket, address: str, socket_type: str, verbose=False) -> None:
"""Bind a ZMQ socket and handle errors."""
if verbose:
print(f"creating {socket_type} socket")
try:
socket.bind(address)
except Exception as err:
error_message = f"failed to bind to {socket_type}: {err}"
raise BrokerBindingError(error_message) from err
if verbose:
print(f"successfully bound to {socket_type} socket: {address}")
def create_proxy(xpub: Socket, xsub: Socket, verbose=False) -> None:
"""Create a ZMQ proxy to connect XPUB and XSUB sockets."""
if verbose:
print("creating proxy")
try:
zmq.proxy(xpub, xsub)
except Exception as err:
error_message = f"failed to create proxy: {err}"
raise BrokerBindingError(error_message) from err
# TODO reimplement as object?
def zmq_broker(config: BrokerConf) -> None:
"""Start a zmq broker.
Binds to a publisher and subscriber port, allowing many to many connections."""
ctx = zmq.Context()
xpub = ctx.socket(zmq.XPUB)
xsub = ctx.socket(zmq.XSUB)
try:
bind_socket(xpub, config.subscriber_address, "publisher", config.verbose)
bind_socket(xsub, config.publisher_address, "subscriber", config.verbose)
create_proxy(xpub, xsub, config.verbose)
except BrokerBindingError as e:
print(e)
sys.exit(-1)
def main() -> None:
"""Start a zmq broker, with a user specified configuration."""
broker_config = load_config(BrokerConf(), "zmq")
zmq_broker(broker_config)

View File

@@ -1,4 +0,0 @@
from .config import BaseConf, Config
from .parse import load_config
__all__ = ["load_config", "BaseConf", "Config"]


@@ -1,47 +0,0 @@
import argparse
class KeyValue(argparse.Action):
def __call__(self, parser, args, values, option_string=None) -> None:
try:
params = dict(x.split("=") for x in values)
except ValueError as ex:
raise argparse.ArgumentError(
self,
f'Could not parse argument "{values}" as k1=v1 k2=v2 ... format: {ex}',
) from ex
setattr(args, self.dest, params)
def get_cmdline(args=None) -> dict:
"""
get commandline arguments and return a dictionary of
the provided arguments.
available commandline arguments are:
--verbose: flag to toggle debugging output
--print-stdout: flag to toggle all data printed to stdout
--param key1=value1 key2=value2: allows to pass service specific
parameters
"""
arp = argparse.ArgumentParser()
arp.add_argument("--verbose", action="store_true", help="debug output flag")
arp.add_argument(
"--print-stdout",
action="store_true",
help="toggles output of all data to stdout",
)
arp.add_argument(
"--params",
nargs="*",
action=KeyValue,
)
args = arp.parse_args(args)
config = {}
if args.verbose:
config["verbose"] = args.verbose
if args.print_stdout:
config["print_stdout"] = args.print_stdout
if args.params:
config |= args.params
return config


@@ -1,35 +0,0 @@
import socket
import warnings
from dataclasses import dataclass
from typing import Any, TypeVar
@dataclass
class BaseConf:
"""
default configuration class for generic configuration info
"""
verbose: bool = False
print_stdout: bool = False
def __setitem__(self, key: str, value: Any) -> None:
if hasattr(self, key):
self.__setattr__(key, value)
else:
warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2)
def __getitem__(self, key: str) -> Any:
if hasattr(self, key):
return getattr(self, key)
else:
warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2)
@property
def serial_number(self) -> str:
return socket.gethostname().upper()
Config = TypeVar(
"Config", bound=BaseConf
) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar


@@ -1,75 +0,0 @@
import logging
from pathlib import Path
from typing import Any
import yaml
from heisskleber.config.cmdline import get_cmdline
from heisskleber.config.config import Config
log = logging.getLogger(__name__)
def get_config_dir() -> Path:
config_dir = Path.home() / ".config" / "heisskleber"
if not config_dir.is_dir():
log.error(f"no such directory: {config_dir}", stacklevel=2)
raise FileNotFoundError
return config_dir
def get_config_filepath(filename: str) -> Path:
config_filepath = get_config_dir() / filename
if not config_filepath.is_file():
log.error(f"no such file: {config_filepath}", stacklevel=2)
raise FileNotFoundError
return config_filepath
def read_yaml_config_file(config_fpath: Path) -> dict[str, Any]:
with config_fpath.open() as config_filehandle:
return yaml.safe_load(config_filehandle) # type: ignore [no-any-return]
def update_config(config: Config, config_dict: dict[str, Any]) -> Config:
for config_key, config_value in config_dict.items():
if not hasattr(config, config_key):
error_msg = f"no such configuration parameter: {config_key}, skipping"
log.info(error_msg, stacklevel=2)
continue
cast_func = type(config[config_key])
try:
config[config_key] = cast_func(config_value)
except Exception as e:
log.warning(
f"failed to cast {config_value} to {type(config[config_key])}: {e}. skipping",
stacklevel=2,
)
continue
return config
def load_config(config: Config, config_filename: str, read_commandline: bool = True) -> Config:
"""Load the config file and update the config object.
Parameters
----------
config : BaseConf
The config object to fill with values.
config_filename : str
The name of the config file in $HOME/.config
If the file does not have an extension the default extension .yaml is appended.
read_commandline : bool
Whether to read arguments from the command line. Optional. Defaults to True.
"""
config_filename = config_filename if "." in config_filename else config_filename + ".yaml"
config_filepath = get_config_filepath(config_filename)
config_dict = read_yaml_config_file(config_filepath)
config = update_config(config, config_dict)
if not read_commandline:
return config
config_dict = get_cmdline()
config = update_config(config, config_dict)
return config


@@ -1,55 +0,0 @@
import json
import time
from heisskleber.core.types import AsyncSink, Serializable, Sink
class ConsoleSink(Sink):
def __init__(self, pretty: bool = False, verbose: bool = False) -> None:
self.verbose = verbose
self.pretty = pretty
def send(self, data: dict[str, Serializable], topic: str) -> None:
verbose_topic = topic + ":\t" if self.verbose else ""
if self.pretty:
print(verbose_topic + json.dumps(data, indent=4))
else:
print(verbose_topic + str(data))
def __repr__(self) -> str:
return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})"
def start(self) -> None:
pass
def stop(self) -> None:
pass
class AsyncConsoleSink(AsyncSink):
def __init__(self, pretty: bool = False, verbose: bool = False) -> None:
self.verbose = verbose
self.pretty = pretty
async def send(self, data: dict[str, Serializable], topic: str) -> None:
verbose_topic = topic + ":\t" if self.verbose else ""
if self.pretty:
print(verbose_topic + json.dumps(data, indent=4))
else:
print(verbose_topic + str(data))
def __repr__(self) -> str:
return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})"
def start(self) -> None:
pass
def stop(self) -> None:
pass
if __name__ == "__main__":
sink = ConsoleSink()
while True:
sink.send({"test": "test"}, "test")
time.sleep(1)


@@ -1,98 +0,0 @@
import asyncio
import json
import sys
import time
from queue import SimpleQueue
from threading import Thread
from heisskleber.core.types import AsyncSource, Serializable, Source
class ConsoleSource(Source):
def __init__(self, topic: str = "console") -> None:
self.topic = topic
self.queue = SimpleQueue()
self.pack = json.loads
self.thread: Thread | None = None
def listener_task(self):
while True:
try:
data = sys.stdin.readline()
payload = self.pack(data)
self.queue.put(payload)
except json.decoder.JSONDecodeError:
print("Invalid JSON")
continue
except ValueError:
break
print("listener task finished")
def receive(self) -> tuple[str, dict[str, Serializable]]:
if not self.thread:
self.start()
data = self.queue.get()
return self.topic, data
def __repr__(self) -> str:
return f"{self.__class__.__name__}(topic={self.topic})"
def start(self) -> None:
self.thread = Thread(target=self.listener_task, daemon=True)
self.thread.start()
def stop(self) -> None:
if self.thread:
sys.stdin.close()
self.thread.join()
class AsyncConsoleSource(AsyncSource):
def __init__(self, topic: str = "console") -> None:
self.topic = topic
self.queue: asyncio.Queue[dict[str, Serializable]] = asyncio.Queue(maxsize=10)
self.pack = json.loads
self.task: asyncio.Task[None] | None = None
async def listener_task(self):
while True:
data = sys.stdin.readline()
payload = self.pack(data)
await self.queue.put(payload)
async def receive(self) -> tuple[str, dict[str, Serializable]]:
if not self.task:
self.start()
data = await self.queue.get()
return self.topic, data
def __repr__(self) -> str:
return f"{self.__class__.__name__}(topic={self.topic})"
def start(self) -> None:
self.task = asyncio.create_task(self.listener_task())
def stop(self) -> None:
if self.task:
self.task.cancel()
if __name__ == "__main__":
console_source = ConsoleSource()
console_source.start()
print("Listening to console input.")
count = 0
try:
while True:
print(console_source.receive())
time.sleep(1)
count += 1
print(count)
except KeyboardInterrupt:
print("Stopped")
sys.exit(0)


@@ -1,59 +0,0 @@
from heisskleber.config import BaseConf, load_config
from heisskleber.mqtt import AsyncMqttPublisher, AsyncMqttSubscriber, MqttConf
from heisskleber.udp import AsyncUdpSink, AsyncUdpSource, UdpConf
from heisskleber.zmq import ZmqAsyncPublisher, ZmqAsyncSubscriber, ZmqConf
from .types import AsyncSink, AsyncSource
_registered_async_sinks: dict[str, tuple[type[AsyncSink], type[BaseConf]]] = {
"mqtt": (AsyncMqttPublisher, MqttConf),
"zmq": (ZmqAsyncPublisher, ZmqConf),
"udp": (AsyncUdpSink, UdpConf),
}
_registered_async_sources: dict[str, tuple] = {
"mqtt": (AsyncMqttSubscriber, MqttConf),
"zmq": (ZmqAsyncSubscriber, ZmqConf),
"udp": (AsyncUdpSource, UdpConf),
}
def get_async_sink(name: str) -> AsyncSink:
"""
Factory function to create a sink object.
Parameters:
name: Name of the sink to create.
config: Configuration object to use for the sink.
"""
if name not in _registered_async_sinks:
error_message = f"{name} is not a registered Sink."
raise KeyError(error_message)
pub_cls, conf_cls = _registered_async_sinks[name]
config = load_config(conf_cls(), name, read_commandline=False)
return pub_cls(config)
def get_async_source(name: str, topic: str | list[str] | tuple[str]) -> AsyncSource:
"""
Factory function to create a source object.
Parameters:
name: Name of the source to create.
config: Configuration object to use for the source.
topic: Topic to subscribe to.
"""
if name not in _registered_async_sources:
error_message = f"{name} is not a registered Source."
raise KeyError(error_message)
sub_cls, conf_cls = _registered_async_sources[name]
config = load_config(conf_cls(), name, read_commandline=False)
return sub_cls(config, topic)


@@ -1,90 +0,0 @@
from heisskleber.config import BaseConf, load_config
from heisskleber.core.types import Sink, Source
from heisskleber.mqtt import MqttConf, MqttPublisher, MqttSubscriber
from heisskleber.serial import SerialConf, SerialPublisher, SerialSubscriber
from heisskleber.udp import UdpConf, UdpPublisher, UdpSubscriber
from heisskleber.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber
_registered_sinks: dict[str, tuple[type[Sink], type[BaseConf]]] = {
"zmq": (ZmqPublisher, ZmqConf),
"mqtt": (MqttPublisher, MqttConf),
"serial": (SerialPublisher, SerialConf),
"udp": (UdpPublisher, UdpConf),
}
_registered_sources: dict[str, tuple[type[Source], type[BaseConf]]] = {
"zmq": (ZmqSubscriber, ZmqConf),
"mqtt": (MqttSubscriber, MqttConf),
"serial": (SerialSubscriber, SerialConf),
"udp": (UdpSubscriber, UdpConf),
}
def get_sink(name: str) -> Sink:
"""
Factory function to create a sink object.
Parameters:
name: Name of the sink to create.
config: Configuration object to use for the sink.
"""
if name not in _registered_sinks:
error_message = f"{name} is not a registered Sink."
raise KeyError(error_message)
pub_cls, conf_cls = _registered_sinks[name]
print(f"loading {name} config")
config = load_config(conf_cls(), name, read_commandline=False)
return pub_cls(config)
def get_source(name: str, topic: str | list[str]) -> Source:
"""
Factory function to create a source object.
Parameters:
name: Name of the source to create.
config: Configuration object to use for the source.
topic: Topic to subscribe to.
"""
if name not in _registered_sinks:
error_message = f"{name} is not a registered Source."
raise KeyError(error_message)
sub_cls, conf_cls = _registered_sources[name]
print(f"loading {name} config")
config = load_config(conf_cls(), name, read_commandline=False)
return sub_cls(config, topic)
def get_subscriber(name: str, topic: str | list[str]) -> Source:
"""
Deprecated: Factory function to create a source object (formerly known as subscriber).
Parameters:
name: Name of the source to create.
config: Configuration object to use for the source.
topic: Topic to subscribe to.
"""
print("Deprecated: use get_source instead.")
return get_source(name, topic)
def get_publisher(name: str) -> Sink:
"""
Deprecated: Factory function to create a sink object (formerly known as publisher).
Parameters:
name: Name of the sink to create.
config: Configuration object to use for the sink.
"""
print("Deprecated: use get_sink instead.")
return get_sink(name)


@@ -1,44 +0,0 @@
"""Packer and unpacker for network data."""
import json
import pickle
from typing import Any, Callable
from .types import Serializable
def get_packer(style: str) -> Callable[[dict[str, Serializable]], str]:
"""Return a packer function for the given style.
Packer func serializes a given dict."""
if style in _packstyles:
return _packstyles[style]
else:
return _packstyles["default"]
def get_unpacker(style: str) -> Callable[[str], dict[str, Serializable]]:
"""Return an unpacker function for the given style.
Unpacker func deserializes a string."""
if style in _unpackstyles:
return _unpackstyles[style]
else:
return _unpackstyles["default"]
def serialpacker(data: dict[str, Any]) -> str:
return ",".join([str(v) for v in data.values()])
_packstyles: dict[str, Callable[[dict[str, Serializable]], str]] = {
"default": json.dumps,
"json": json.dumps,
"pickle": pickle.dumps, # type: ignore
"serial": serialpacker,
}
_unpackstyles: dict[str, Callable[[str], dict[str, Serializable]]] = {
"default": json.loads,
"json": json.loads,
"pickle": pickle.loads, # type: ignore
}


@@ -1,180 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from typing import Any, Callable, Union
from heisskleber.config import BaseConf
Serializable = Union[int, float]
class Sink(ABC):
"""
Sink interface to send() data to.
"""
pack: Callable[[dict[str, Serializable]], str]
@abstractmethod
def __init__(self, config: BaseConf) -> None:
"""
Initialize the publisher with a configuration object.
"""
pass
@abstractmethod
def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Send data via the implemented output stream.
"""
pass
@abstractmethod
def __repr__(self) -> str:
pass
@abstractmethod
def start(self) -> None:
"""
Start any background processes and tasks.
"""
pass
@abstractmethod
def stop(self) -> None:
"""
Stop any background processes and tasks.
"""
pass
class Source(ABC):
"""
Source interface that emits data via the receive() method.
"""
unpack: Callable[[str], dict[str, Serializable]]
def __iter__(self) -> Generator[tuple[str, dict[str, Serializable]], None, None]:
topic, data = self.receive()
yield topic, data
@abstractmethod
def __init__(self, config: BaseConf, topic: str | list[str]) -> None:
"""
Initialize the subscriber with a topic and a configuration object.
"""
pass
@abstractmethod
def receive(self) -> tuple[str, dict[str, Serializable]]:
"""
Blocking function to receive data from the implemented input stream.
Data is returned as a tuple of (topic, data).
"""
pass
@abstractmethod
def __repr__(self) -> str:
pass
@abstractmethod
def start(self) -> None:
"""
Start any background processes and tasks.
"""
pass
@abstractmethod
def stop(self) -> None:
"""
Stop any background processes and tasks.
"""
pass
class AsyncSource(ABC):
"""
AsyncSubscriber interface
"""
async def __aiter__(self) -> AsyncGenerator[tuple[str, dict[str, Serializable]], None]:
while True:
topic, data = await self.receive()
yield topic, data
@abstractmethod
def __init__(self, config: Any, topic: str | list[str]) -> None:
"""
Initialize the subscriber with a topic and a configuration object.
"""
pass
@abstractmethod
async def receive(self) -> tuple[str, dict[str, Serializable]]:
"""
Blocking function to receive data from the implemented input stream.
Data is returned as a tuple of (topic, data).
"""
pass
@abstractmethod
def __repr__(self) -> str:
pass
@abstractmethod
def start(self) -> None:
"""
Start any background processes and tasks.
"""
pass
@abstractmethod
def stop(self) -> None:
"""
Stop any background processes and tasks.
"""
pass
class AsyncSink(ABC):
"""
Sink interface to send() data to.
"""
pack: Callable[[dict[str, Serializable]], str]
@abstractmethod
def __init__(self, config: BaseConf) -> None:
"""
Initialize the publisher with a configuration object.
"""
pass
@abstractmethod
async def send(self, data: dict[str, Any], topic: str) -> None:
"""
Send data via the implemented output stream.
"""
pass
@abstractmethod
def __repr__(self) -> str:
pass
@abstractmethod
def start(self) -> None:
"""
Start any background processes and tasks.
"""
pass
@abstractmethod
def stop(self) -> None:
"""
Stop any background processes and tasks.
"""
pass


@@ -1,20 +0,0 @@
from dataclasses import dataclass
from heisskleber.config import BaseConf
@dataclass
class InfluxDBConf(BaseConf):
host: str = "localhost"
port: int = 8086
bucket: str = "test"
org: str = "test"
ssl: bool = False
read_token: str = ""
write_token: str = ""
all_access_token: str = ""
@property
def url(self) -> str:
protocol = "https" if self.ssl else "http"
return f"{protocol}://{self.host}:{self.port}"


@@ -1,75 +0,0 @@
import pandas as pd
from influxdb_client import InfluxDBClient
from heisskleber.core.types import Source
from .config import InfluxDBConf
def build_query(options: dict) -> str:
query = (
f'from(bucket:"{options["bucket"]}")'
+ f'|> range(start: {options["start"].isoformat("T")}, stop: {options["end"].isoformat("T")})'
+ f'|> filter(fn:(r) => r._measurement == "{options["measurement"]}")'
)
if options["filter"]:
for attribute, value in options["filter"].items():
if isinstance(value, list):
query += f'|> filter(fn:(r) => r.{attribute} == "{value[0]}"'
for vv in value[1:]:
query += f' or r.{attribute} == "{vv}"'
query += ")"
else:
query += f'|> filter(fn:(r) => r.{attribute} == "{value}")'
query += (
f'|> aggregateWindow(every: {options["resample"]}, fn: mean)'
+ '|> pivot(rowKey:["_time"], columnKey: ["_field"], valueColumn: "_value")'
)
return query
class Influx_Subscriber(Source):
def __init__(self, config: InfluxDBConf, query: str):
self.config = config
self.query = query
self.client: InfluxDBClient = InfluxDBClient(
url=self.config.url,
token=self.config.all_access_token or self.config.read_token,
org=self.config.org,
timeout=60_000,
)
self.reader = self.client.query_api()
self._run_query()
self.index = 0
def receive(self) -> tuple[str, dict]:
row = self.df.iloc[self.index].to_dict()
self.index += 1
return "influx", row
def _run_query(self):
self.df: pd.DataFrame = self.reader.query_data_frame(self.query, org=self.config.org)
self.df["epoch"] = pd.to_numeric(self.df["_time"]) / 1e9
self.df.drop(
columns=[
"result",
"table",
"_start",
"_stop",
"_measurement",
"_time",
"topic",
],
inplace=True,
)
def __iter__(self):
for _, row in self.df.iterrows():
yield "influx", row.to_dict()
def __next__(self):
return self.__iter__().__next__()


@@ -1,48 +0,0 @@
from influxdb_client import InfluxDBClient, WriteOptions
from config import InfluxDBConf
from heisskleber.config import load_config
class Influx_Writer:
def __init__(self, config: InfluxDBConf):
self.config = config
# self.write_options = SYNCHRONOUS
self.write_options = WriteOptions(
batch_size=500,
flush_interval=10_000,
jitter_interval=2_000,
retry_interval=5_000,
max_retries=5,
max_retry_delay=30_000,
exponential_base=2,
)
self.client = InfluxDBClient(url=self.config.url, token=self.config.token, org=self.config.org)
self.writer = self.client.write_api(
write_options=self.write_options,
)
def __del__(self):
self.writer.close()
self.client.close()
def write_line(self, line):
self.writer.write(bucket=self.config.bucket, record=line)
def write_from_generator(self, generator):
for line in generator:
self.writer.write(bucket=self.config.bucket, record=line)
def write_from_line_generator(self, generator):
with InfluxDBClient(
url=self.config.url, token=self.config.token, org=self.config.org
) as client, client.write_api(
write_options=self.write_options,
) as write_api:
for line in generator:
write_api.write(bucket=self.config.bucket, record=line)
def get_parsed_flux_writer():
config = load_config(InfluxDBConf(), "flux", read_commandline=False)
return Influx_Writer(config)


@@ -1,7 +0,0 @@
from .config import MqttConf
from .publisher import MqttPublisher
from .publisher_async import AsyncMqttPublisher
from .subscriber import MqttSubscriber
from .subscriber_async import AsyncMqttSubscriber
__all__ = ["MqttConf", "MqttPublisher", "MqttSubscriber", "AsyncMqttSubscriber", "AsyncMqttPublisher"]


@@ -1,24 +0,0 @@
from dataclasses import dataclass, field
from heisskleber.config import BaseConf
@dataclass
class MqttConf(BaseConf):
"""
MQTT configuration class.
"""
host: str = "localhost"
user: str = ""
password: str = ""
port: int = 1883
ssl: bool = False
qos: int = 0
retain: bool = False
topics: list[str] = field(default_factory=list)
mapping: str = "/deprecated/" # deprecated
packstyle: str = "json"
max_saved_messages: int = 100
timeout_s: int = 60
source_id: str = "box-01"


@@ -1,23 +0,0 @@
from heisskleber import get_publisher, get_subscriber
from heisskleber.config import load_config
from .config import MqttConf
def map_topic(zmq_topic: str, mapping: str) -> str:
return mapping + zmq_topic
def main() -> None:
config: MqttConf = load_config(MqttConf(), "mqtt")
sub = get_subscriber("zmq", config.topics)
pub = get_publisher("mqtt")
pub.pack = lambda x: x # type: ignore
sub.unpack = lambda x: x # type: ignore
while True:
(zmq_topic, data) = sub.receive()
mqtt_topic = map_topic(zmq_topic, config.mapping)
pub.send(data, mqtt_topic)


@@ -1,97 +0,0 @@
import ssl
import sys
import threading
from paho.mqtt.client import Client as mqtt_client
from .config import MqttConf
class ThreadDiedError(RuntimeError):
pass
_thread_died = threading.Event()
_default_excepthook = threading.excepthook
def _set_thread_died_excepthook(args, /):
_default_excepthook(args)
global _thread_died
_thread_died.set()
threading.excepthook = _set_thread_died_excepthook
class MqttBase:
"""
Wrapper around eclipse paho mqtt client.
Handles connection and callbacks.
Callbacks may be overwritten in subclasses.
"""
def __init__(self, config: MqttConf) -> None:
self.config = config
self.client = mqtt_client()
self.is_connected = False
def start(self) -> None:
if not self.is_connected:
self.connect()
def stop(self) -> None:
if self.client:
self.client.loop_stop()
self.is_connected = False
def connect(self) -> None:
self.client.username_pw_set(self.config.user, self.config.password)
# Add callbacks
self.client.on_connect = self._on_connect
self.client.on_disconnect = self._on_disconnect
self.client.on_publish = self._on_publish
self.client.on_message = self._on_message
if self.config.ssl:
# By default, on Python 2.7.9+ or 3.4+,
# the default certification authority of the system is used.
self.client.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT)
self.client.connect(self.config.host, self.config.port)
self.client.loop_start()
self.is_connected = True
@staticmethod
def _raise_if_thread_died() -> None:
global _thread_died
if _thread_died.is_set():
raise ThreadDiedError()
# MQTT callbacks
def _on_connect(self, client, userdata, flags, return_code) -> None:
if return_code == 0:
print(f"MQTT node connected to {self.config.host}:{self.config.port}")
else:
print("Connection failed!")
if self.config.verbose:
print(flags)
def _on_disconnect(self, client, userdata, return_code) -> None:
print(f"Disconnected from broker with return code {return_code}")
if return_code != 0:
print("Killing this service")
sys.exit(-1)
def _on_publish(self, client, userdata, message_id) -> None:
if self.config.verbose:
print(f"Published message with id {message_id}, qos={self.config.qos}")
def _on_message(self, client, userdata, message) -> None:
if self.config.verbose:
print(f"Received message: {message.payload!s}, topic: {message.topic}, qos: {message.qos}")
def __del__(self) -> None:
self.stop()


@@ -1,44 +0,0 @@
from __future__ import annotations
from heisskleber.core.packer import get_packer
from heisskleber.core.types import Serializable, Sink
from .config import MqttConf
from .mqtt_base import MqttBase
class MqttPublisher(MqttBase, Sink):
"""
MQTT publisher class.
Can be used everywhere that a flucto style publishing connection is required.
Network message loop is handled in a separated thread.
"""
def __init__(self, config: MqttConf) -> None:
super().__init__(config)
self.pack = get_packer(config.packstyle)
def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Takes python dictionary, serializes it according to the packstyle
and sends it to the broker.
Publishing is asynchronous
"""
if not self.is_connected:
self.start()
self._raise_if_thread_died()
payload = self.pack(data)
self.client.publish(topic, payload, qos=self.config.qos, retain=self.config.retain)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
def start(self) -> None:
super().start()
def stop(self) -> None:
super().stop()


@@ -1,70 +0,0 @@
from asyncio import Queue, Task, create_task, sleep
import aiomqtt
from heisskleber.core.packer import get_packer
from heisskleber.core.types import AsyncSink, Serializable
from .config import MqttConf
class AsyncMqttPublisher(AsyncSink):
"""
MQTT publisher class.
Can be used everywhere that a flucto style publishing connection is required.
Network message loop is handled in a separated thread.
"""
def __init__(self, config: MqttConf) -> None:
self.config = config
self.pack = get_packer(config.packstyle)
self._send_queue: Queue[tuple[dict[str, Serializable], str]] = Queue()
self._sender_task: Task[None] | None = None
async def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Takes python dictionary, serializes it according to the packstyle
and sends it to the broker.
Publishing is asynchronous
"""
if not self._sender_task:
self.start()
await self._send_queue.put((data, topic))
async def send_work(self) -> None:
"""
Takes python dictionary, serializes it according to the packstyle
and sends it to the broker.
Publishing is asynchronous
"""
while True:
try:
async with aiomqtt.Client(
hostname=self.config.host,
port=self.config.port,
username=self.config.user,
password=self.config.password,
timeout=float(self.config.timeout_s),
) as client:
while True:
data, topic = await self._send_queue.get()
payload = self.pack(data)
await client.publish(topic, payload)
except aiomqtt.MqttError:
print("Connection to MQTT broker failed. Retrying in 5 seconds")
await sleep(5)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
def start(self) -> None:
self._sender_task = create_task(self.send_work())
def stop(self) -> None:
if self._sender_task:
self._sender_task.cancel()
self._sender_task = None


@@ -1,81 +0,0 @@
from __future__ import annotations
from queue import SimpleQueue
from typing import Any
from paho.mqtt.client import MQTTMessage
from heisskleber.core.packer import get_unpacker
from heisskleber.core.types import Source
from .config import MqttConf
from .mqtt_base import MqttBase
class MqttSubscriber(MqttBase, Source):
"""
MQTT subscriber, wraps around eclipse's paho mqtt client.
Network message loop is handled in a separated thread.
Incoming messages are saved as a stack when not processed via the receive() function.
"""
def __init__(self, config: MqttConf, topics: str | list[str]) -> None:
super().__init__(config)
self.topics = topics
self._message_queue: SimpleQueue[MQTTMessage] = SimpleQueue()
self.unpack = get_unpacker(config.packstyle)
def subscribe(self, topics: str | list[str] | tuple[str]) -> None:
"""
Subscribe to one or multiple topics
"""
if not self.is_connected:
super().start()
self.client.on_message = self._on_message
if isinstance(topics, (list, tuple)):
# if subscribing to multiple topics, use a list of tuples
subscription_list = [(topic, self.config.qos) for topic in topics]
self.client.subscribe(subscription_list)
else:
self.client.subscribe(topics, self.config.qos)
if self.config.verbose:
print(f"Subscribed to: {topics}")
def receive(self) -> tuple[str, dict[str, Any]]:
"""
Reads a message from mqtt and returns it
Messages are saved in a stack, if no message is available, this function blocks.
Returns:
tuple(topic: str, message: dict): the message received
"""
if not self.client:
self.start()
self._raise_if_thread_died()
mqtt_message = self._message_queue.get(block=True, timeout=self.config.timeout_s)
message_returned = self.unpack(mqtt_message.payload.decode())
return (mqtt_message.topic, message_returned)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
def start(self) -> None:
super().start()
self.subscribe(self.topics)
self.client.on_message = self._on_message
def stop(self) -> None:
super().stop()
# callback to add incoming messages onto stack
def _on_message(self, client, userdata, message) -> None:
self._message_queue.put(message)
if self.config.verbose:
print(f"Topic: {message.topic}")
print(f"MQTT message: {message.payload.decode()}")


@@ -1,86 +0,0 @@
from asyncio import Queue, Task, create_task, sleep
from aiomqtt import Client, Message, MqttError
from heisskleber.core.packer import get_unpacker
from heisskleber.core.types import AsyncSource, Serializable
from heisskleber.mqtt import MqttConf
class AsyncMqttSubscriber(AsyncSource):
"""Asynchronous MQTT susbsciber based on aiomqtt.
Data is received by the `receive` method returns the newest message in the queue.
"""
def __init__(self, config: MqttConf, topic: str | list[str]) -> None:
self.config: MqttConf = config
self.client = Client(
hostname=self.config.host,
port=self.config.port,
username=self.config.user,
password=self.config.password,
)
self.topics = topic
self.unpack = get_unpacker(self.config.packstyle)
self.message_queue: Queue[Message] = Queue(self.config.max_saved_messages)
self._listener_task: Task[None] | None = None
def __repr__(self) -> str:
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
def start(self) -> None:
self._listener_task = create_task(self.run())
def stop(self) -> None:
if self._listener_task:
self._listener_task.cancel()
self._listener_task = None
async def receive(self) -> tuple[str, dict[str, Serializable]]:
"""
Await the newest message in the queue and return Tuple
"""
if not self._listener_task:
self.start()
mqtt_message = await self.message_queue.get()
return self._handle_message(mqtt_message)
async def run(self):
"""
Handle the connection to MQTT broker and run the message loop.
"""
while True:
try:
async with self.client:
await self._subscribe_topics()
await self._listen_mqtt_loop()
except MqttError as e:
print(f"MqttError: {e}")
print("Connection to MQTT failed. Retrying...")
await sleep(1)
async def _listen_mqtt_loop(self) -> None:
"""
Listen to incoming messages asynchronously and put them into a queue
"""
async with self.client.messages() as messages:
# async with self.client.filtered_messages(self.topics) as messages:
async for message in messages:
await self.message_queue.put(message)
def _handle_message(self, message: Message) -> tuple[str, dict[str, Serializable]]:
if not isinstance(message.payload, bytes):
error_msg = "Payload is not of type bytes."
raise TypeError(error_msg)
topic = str(message.topic)
message_returned = self.unpack(message.payload.decode())
return (topic, message_returned)
async def _subscribe_topics(self) -> None:
print(f"subscribing to {self.topics}")
if isinstance(self.topics, list):
await self.client.subscribe([(topic, self.config.qos) for topic in self.topics])
else:
await self.client.subscribe(self.topics, self.config.qos)


@@ -1,95 +0,0 @@
import argparse
import sys
from typing import Callable, Union
from heisskleber.config import load_config
from heisskleber.console.sink import ConsoleSink
from heisskleber.core.factories import _registered_sources
from heisskleber.mqtt import MqttSubscriber
from heisskleber.udp import UdpSubscriber
from heisskleber.zmq import ZmqSubscriber
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
prog="hkcli",
description="Heisskleber command line interface",
usage="%(prog)s [options]",
)
parser.add_argument(
"-t",
"--type",
type=str,
choices=["zmq", "mqtt", "serial", "udp"],
default="zmq",
)
parser.add_argument(
"-T",
"--topic",
type=str,
default="#",
help="Topic to subscribe to, valid for zmq and mqtt only.",
)
parser.add_argument(
"-H",
"--host",
type=str,
help="Host or broker for MQTT, zmq and UDP.",
)
parser.add_argument(
"-P",
"--port",
type=int,
help="Port or serial interface for MQTT, zmq and UDP.",
)
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument("-p", "--pretty", action="store_true", help="Pretty print JSON data.")
return parser.parse_args()
def keyboardexit(func) -> Callable:
def wrapper(*args, **kwargs) -> Union[None, int]:
try:
return func(*args, **kwargs)
except KeyboardInterrupt:
print("Exiting...")
sys.exit(0)
return wrapper
@keyboardexit
def main() -> None:
args = parse_args()
sink = ConsoleSink(pretty=args.pretty, verbose=args.verbose)
sub_cls, conf_cls = _registered_sources[args.type]
try:
config = load_config(conf_cls(), args.type, read_commandline=False)
except FileNotFoundError:
print(f"No config file found for {args.type}, using default values and user input.")
config = conf_cls()
source = sub_cls(config, args.topic)
if isinstance(source, (MqttSubscriber, UdpSubscriber)):
source.config.host = args.host or source.config.host
source.config.port = args.port or source.config.port
elif isinstance(source, ZmqSubscriber):
source.config.host = args.host or source.config.host
source.config.subscriber_port = args.port or source.config.subscriber_port
source.topic = "" if args.topic == "#" else args.topic
elif isinstance(source, UdpSubscriber):
source.config.port = args.port or source.config.port
source.start()
sink.start()
while True:
topic, data = source.receive()
sink.send(data, topic)
if __name__ == "__main__":
main()


@@ -1,12 +0,0 @@
from heisskleber.broker import start_zmq_broker
from heisskleber.config import load_config
from heisskleber.zmq.config import ZmqConf as BrokerConf
def main():
broker_config = load_config(BrokerConf(), "zmq")
start_zmq_broker(config=broker_config)
if __name__ == "__main__":
main()


@@ -1,5 +0,0 @@
from .config import SerialConf
from .publisher import SerialPublisher
from .subscriber import SerialSubscriber
__all__ = ["SerialConf", "SerialPublisher", "SerialSubscriber"]


@@ -1,11 +0,0 @@
from dataclasses import dataclass
from heisskleber.config import BaseConf
@dataclass
class SerialConf(BaseConf):
port: str = "/dev/serial0"
baudrate: int = 9600
bytesize: int = 8
encoding: str = "ascii"


@@ -1,31 +0,0 @@
from heisskleber.core.types import Source
from .publisher import SerialPublisher
class SerialForwarder:
def __init__(self, subscriber: Source, publisher: SerialPublisher) -> None:
self.sub = subscriber
self.pub = publisher
"""
Wait for message and forward
"""
def forward_message(self) -> None:
# collected = {}
# for sub in self.sub:
# topic, data = sub.receive()
# collected.update(data)
topic, data = self.sub.receive()
# We send the topic and let the publisher decide what to do with it
self.pub.send(data, topic)
"""
Enter loop and continuously forward messages
"""
def sub_pub_loop(self) -> None:
while True:
self.forward_message()


@@ -1,87 +0,0 @@
from __future__ import annotations
import sys
from typing import Callable, Optional
import serial
from heisskleber.core.packer import get_packer
from heisskleber.core.types import Serializable, Sink
from .config import SerialConf
class SerialPublisher(Sink):
"""
Publisher for serial devices.
Can be used everywhere that a flucto style publishing connection is required.
Parameters
----------
config : SerialConf
Configuration for the serial connection.
pack_func : FunctionType
Function to translate from a dict to a serialized string.
"""
serial_connection: serial.Serial
def __init__(
self,
config: SerialConf,
pack_func: Optional[Callable] = None, # noqa: UP007
):
self.config = config
self.pack = pack_func if pack_func else get_packer("serial")
self.is_connected = False
def start(self) -> None:
"""
Start the serial connection.
"""
try:
self.serial_connection = serial.Serial(
port=self.config.port,
baudrate=self.config.baudrate,
bytesize=self.config.bytesize,
parity=serial.PARITY_NONE,
stopbits=serial.STOPBITS_ONE,
)
except serial.SerialException:
print(f"Failed to connect to serial device at port {self.config.port}")
sys.exit(1)
print(f"Successfully connected to serial device at port {self.config.port}")
self.is_connected = True
def stop(self) -> None:
"""
Stop the serial connection.
"""
if hasattr(self, "serial_connection") and self.serial_connection.is_open:
self.serial_connection.flush()
self.serial_connection.close()
def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Takes python dictionary, serializes it according to the packstyle
and sends it to the broker.
Parameters
----------
message : dict
object to be serialized and sent via the serial connection. Usually a dict.
"""
if not self.is_connected:
self.start()
payload = self.pack(data)
self.serial_connection.write(payload.encode(self.config.encoding))
self.serial_connection.flush()
if self.config.verbose:
print(f"{topic}: {payload}")
def __repr__(self) -> str:
return f"SerialPublisher(port={self.config.port}, baudrate={self.config.baudrate}, bytezize={self.config.bytesize}, encoding={self.config.encoding})"
def __del__(self) -> None:
self.stop()


@@ -1,110 +0,0 @@
import sys
from collections.abc import Generator
from typing import Callable
import serial
from heisskleber.core.types import Source
from .config import SerialConf
class SerialSubscriber(Source):
"""
Subscriber for serial devices. Connects to a serial port and reads from it.
Parameters
----------
config : SerialConf
Configuration class for the serial connection.
topic :
Placeholder for topic. Not used.
custom_unpack : Callable
Function to translate from a serialized string to a dict.
"""
serial_connection: serial.Serial
def __init__(
self,
config: SerialConf,
topic: str | None = None,
custom_unpack: Callable | None = None,
):
self.config = config
self.topic = topic
self.unpack = custom_unpack if custom_unpack else lambda x: x # type: ignore
self.is_connected = False
def start(self) -> None:
"""
Start the serial connection.
"""
try:
self.serial_connection = serial.Serial(
port=self.config.port,
baudrate=self.config.baudrate,
bytesize=self.config.bytesize,
parity=serial.PARITY_NONE,
stopbits=serial.STOPBITS_ONE,
)
except serial.SerialException:
print(f"Failed to connect to serial device at port {self.config.port}")
sys.exit(1)
print(f"Successfully connected to serial device at port {self.config.port}")
self.is_connected = True
def stop(self) -> None:
"""
Stop the serial connection.
"""
if hasattr(self, "serial_connection") and self.serial_connection.is_open:
self.serial_connection.flush()
self.serial_connection.close()
def receive(self) -> tuple[str, dict]:
"""
Wait for data to arrive on the serial port and return it.
Returns
-------
:return: (topic, payload)
topic is a placeholder to adhere to the Subscriber interface
payload is a dictionary containing the data from the serial port
"""
if not self.is_connected:
self.start()
# message is a string
message = next(self.read_serial_port())
# payload is a dictionary
payload = self.unpack(message)
# port is a placeholder for topic
return self.config.port, payload
def read_serial_port(self) -> Generator[str, None, None]:
"""
Generator function reading from the serial port.
Returns
-------
:return: Generator[str, None, None]
Generator yielding strings read from the serial port
"""
buffer = ""
while True:
try:
buffer = self.serial_connection.readline().decode(self.config.encoding, "ignore")
yield buffer
except UnicodeError as e:
if self.config.verbose:
print(f"Could not decode: {buffer!r}")
print(e)
continue
def __repr__(self) -> str:
return f"SerialPublisher(port={self.config.port}, baudrate={self.config.baudrate}, bytezize={self.config.bytesize}, encoding={self.config.encoding})"
def __del__(self) -> None:
self.stop()


@@ -1,5 +0,0 @@
from .config import ResamplerConf
from .joint import Joint
from .resampler import Resampler
__all__ = ["Resampler", "ResamplerConf", "Joint"]


@@ -1,77 +0,0 @@
from collections import deque
import numpy as np
import scipy.signal # type: ignore [import-untyped]
from numpy.typing import NDArray
from heisskleber.core.types import AsyncSource, Serializable
from heisskleber.stream.filter import Filter
class LiveLFilter:
"""
Filter using standard difference equations.
Kudos to Sam Proell https://www.samproell.io/posts/yarppg/yarppg-live-digital-filter/
"""
def __init__(self, b: NDArray[np.float64], a: NDArray[np.float64], init_val: float = 0.0) -> None:
"""Initialize live filter based on difference equation.
Args:
b (array-like): numerator coefficients obtained from scipy.
a (array-like): denominator coefficients obtained from scipy.
"""
self.b = b
self.a = a
self._xs = deque([init_val] * len(b), maxlen=len(b))
self._ys = deque([init_val] * (len(a) - 1), maxlen=len(a) - 1)
def __call__(self, x: float) -> float:
"""Filter incoming data with standard difference equations."""
self._xs.appendleft(x)
y = np.dot(self.b, self._xs) - np.dot(self.a[1:], self._ys)
y = y / self.a[0]
self._ys.appendleft(y)
return y # type: ignore [no-any-return]
def __repr__(self) -> str:
return f"{self.__class__.__name__}(b={self.b}, a={self.a})"
class ButterFilter(Filter):
"""
Butterworth filter based on scipy.
Args:
source (AsyncSource): Source of data.
cutoff_freq (float): Cutoff frequency.
sampling_rate (float): Sampling rate of the input signal, i.e update frequency
btype (str): Type of filter, "high" or "low"
order (int): order of the filter to be applied
Example:
>>> source = get_async_source()
>>> filter = LowPassFilter(source, 0.1, 100, 3)
>>> async for topic, data in filter:
>>> print(topic, data)
"""
def __init__(
self, source: AsyncSource, cutoff_freq: float, sampling_rate: float, btype: str = "low", order: int = 3
) -> None:
self.source = source
nyquist_fq = sampling_rate / 2
Wn = cutoff_freq / nyquist_fq
self.b, self.a = scipy.signal.iirfilter(order, Wn=Wn, btype=btype, ftype="butter")  # Wn is already normalized to Nyquist, so fs must not be passed as well
self.filters: dict[str, LiveLFilter] = {}
def _filter(self, data: dict[str, Serializable]) -> dict[str, Serializable]:
if not self.filters:
for key in data:
self.filters[key] = LiveLFilter(a=self.a, b=self.b)
for key, value in data.items():
data[key] = self.filters[key](value)
return data
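For reference, LiveLFilter can be exercised on its own with scipy-generated coefficients; a minimal sketch with illustrative cutoff and sample-rate values (here Wn is given in Hz because fs is passed to iirfilter):

import scipy.signal

# 3rd-order Butterworth low-pass at 1 Hz for a 10 Hz signal (illustrative values)
b, a = scipy.signal.iirfilter(3, Wn=1.0, fs=10.0, btype="low", ftype="butter")
live = LiveLFilter(b=b, a=a)
smoothed = [live(x) for x in (0.0, 1.0, 1.0, 1.0)]  # settles towards 1.0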


@@ -1,6 +0,0 @@
from dataclasses import dataclass
@dataclass
class ResamplerConf:
resample_rate: int = 1000


@@ -1,19 +0,0 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any
from heisskleber.core.types import AsyncSource, Serializable
class Filter(ABC):
def __init__(self, source: AsyncSource):
self.source = source
async def __aiter__(self) -> AsyncGenerator[Any, None]:
async for topic, data in self.source:
data = self._filter(data)
yield topic, data
@abstractmethod
def _filter(self, data: dict[str, Serializable]) -> dict[str, Serializable]:
pass


@@ -1,56 +0,0 @@
from heisskleber.core.types import AsyncSource
from heisskleber.stream.filter import Filter
class GhFilter(Filter):
"""
G-H filter (also called alpha-beta, f-g filter), simplified observer for estimation and data smoothing.
Args:
source (AsyncSource): Source of data.
g (float): Correction gain for value
h (float): Correction gain for derivative
Example:
>>> source = get_async_source()
>>> filter = GhFilter(source, 0.008, 0.001)
>>> async for topic, data in filter:
>>> print(topic, data)
"""
def __init__(self, source: AsyncSource, g: float, h: float):
self.source = source
if not 0 < g < 1.0 or not 0 < h < 1.0:
msg = "g and h must be between 0 and 1.0"
raise ValueError(msg)
self.g = g
self.h = h
self.x: dict[str, float] = {}
self.dx: dict[str, float] = {}
def _filter(self, data: dict[str, float]) -> dict[str, float]:
if not self.x:
self.x = data
self.dx = {key: 0.0 for key in data}
return data
invalid_keys = []
ts = data.pop("epoch")
dt = ts - self.x["epoch"]
if abs(dt) <= 1e-4:
data["epoch"] = ts
return data
for key in data:
if not isinstance(data[key], float):
invalid_keys.append(key)
continue
x_pred = self.x[key] + dt * self.dx[key]
residual = data[key] - x_pred
self.dx[key] = self.dx[key] + self.h * residual / dt
self.x[key] = x_pred + self.g * residual
for key in invalid_keys:
self.x[key] = data[key]
self.x["epoch"] = ts
return self.x
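Stripped of the dictionary bookkeeping, the g-h update is a two-line predict/correct step; a scalar sketch with illustrative gains:

x, dx = 0.0, 0.0              # value and derivative estimates
g, h, dt = 0.008, 0.001, 0.1  # illustrative gains and timestep
z = 1.0                       # new measurement
x_pred = x + dt * dx          # predict forward
residual = z - x_pred         # measurement residual
dx = dx + h * residual / dt   # correct the derivative estimate
x = x_pred + g * residual     # correct the value estimate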


@@ -1,114 +0,0 @@
import asyncio
from typing import Any
from heisskleber.core.types import Serializable
from heisskleber.stream.resampler import Resampler, ResamplerConf
class Joint:
"""Joint that takes multiple async streams and synchronizes them based on their timestamps.
Note that synchronization is started by the sync() task created in __init__; joined output begins once it completes.
Parameters:
----------
conf : ResamplerConf
Configuration for the joint.
resamplers : list[Resampler]
List of resampled asynchronous sources to join.
"""
def __init__(self, conf: ResamplerConf, resamplers: list[Resampler]):
self.conf = conf
self.resamplers = resamplers
self.output_queue: asyncio.Queue[dict[str, Serializable]] = asyncio.Queue()
self.initialized = asyncio.Event()
self.initialize_task = asyncio.create_task(self.sync())
self.combined_dict: dict[str, Serializable] = {}
self.task: asyncio.Task[None] | None = None
def __repr__(self) -> str:
return f"""Joint(resample_rate={self.conf.resample_rate},
sources={len(self.resamplers)} of type(s): {{r.__class__.__name__ for r in self.resamplers}})"""
def start(self) -> None:
self.task = asyncio.create_task(self.output_work())
def stop(self) -> None:
if self.task:
self.task.cancel()
async def receive(self) -> dict[str, Any]:
"""
Main interaction coroutine: Get next value out of the queue.
"""
if not self.task:
self.start()
output = await self.output_queue.get()
return output
async def sync(self) -> None:
"""Synchronize the resamplers by pulling data from each until the timestamp is aligned. Retains first matching data."""
print("Starting sync")
datas = await asyncio.gather(*[source.receive() for source in self.resamplers])
print("Got data")
output_data = {}
data = {}
latest_timestamp: float = 0.0
timestamps = []
print("Syncing...")
for data in datas:
if not isinstance(data["epoch"], float):
error = "Timestamps must be floats"
raise TypeError(error)
ts = float(data["epoch"])
print(f"Syncing..., got {ts}")
timestamps.append(ts)
if ts > latest_timestamp:
latest_timestamp = ts
# only take the piece of the latest data
output_data = data
for resampler, ts in zip(self.resamplers, timestamps):
while ts < latest_timestamp:
data = await resampler.receive()
ts = float(data["epoch"])
output_data.update(data)
await self.output_queue.put(output_data)
print("Finished initalization")
self.initialized.set()
"""
Coroutine that waits for new queue data and updates dict.
"""
async def update_dict(self, resampler: Resampler) -> None:
data = await resampler.receive()
if self.combined_dict and self.combined_dict["epoch"] != data["epoch"]:
print("Oh shit, this is bad!")
self.combined_dict.update(data)
"""
Output worker: iterate through queues, read data and join into output queue.
"""
async def output_work(self) -> None:
print("Output worker waiting for intitialization")
await self.initialized.wait()
print("Output worker resuming")
while True:
self.combined_dict = {}
tasks = [asyncio.create_task(self.update_dict(res)) for res in self.resamplers]
await asyncio.gather(*tasks)
await self.output_queue.put(self.combined_dict)
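A sketch of how this (now removed) Joint was wired up, assuming two AsyncSource instances source_a and source_b; Joint schedules tasks in __init__, so it must be constructed inside a running event loop:

async def joined(source_a, source_b):
    # source_a and source_b are assumed AsyncSource instances
    conf = ResamplerConf(resample_rate=1000)
    joint = Joint(conf, [Resampler(conf, source_a), Resampler(conf, source_b)])
    merged = await joint.receive()  # one dict combining keys from both streams
    return merged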


@@ -1,195 +0,0 @@
import math
from asyncio import Queue, Task, create_task
from collections.abc import Generator
from datetime import datetime, timedelta
import numpy as np
from heisskleber.core.types import AsyncSource, Serializable
from .config import ResamplerConf
def floor_dt(dt: datetime, delta: timedelta) -> datetime:
"""Round a datetime object based on a delta timedelta."""
return datetime.min + math.floor((dt - datetime.min) / delta) * delta
def timestamp_generator(start_epoch: float, timedelta_in_ms: int) -> Generator[float, None, None]:
"""Generate increasing timestamps based on a start epoch and a delta in ms.
The generated timestamps are offset by half a delta, so that resample windows end up centered on the timestamps the resampler returns.
"""
timestamp_start = datetime.fromtimestamp(start_epoch)
delta = timedelta(milliseconds=timedelta_in_ms)
delta_half = timedelta(milliseconds=timedelta_in_ms // 2)
next_timestamp = floor_dt(timestamp_start, delta) + delta_half
while True:
yield datetime.timestamp(next_timestamp)
next_timestamp += delta
def interpolate(t1: float, y1: list[float], t2: float, y2: list[float], t_target: float) -> list[float]:
"""Perform linear interpolation between two data points."""
y1_array, y2_array = np.array(y1), np.array(y2)
fraction = (t_target - t1) / (t2 - t1)
interpolated_values = y1_array + fraction * (y2_array - y1_array)
return interpolated_values.tolist()
def check_dict(data: dict[str, Serializable]) -> None:
"""Check that only numeric types are in input data."""
for key, value in data.items():
if not isinstance(value, (int, float)):
error_msg = f"Value {value} for key {key} is not of type int or float"
raise TypeError(error_msg)
class Resampler:
"""
Async resample data based on a fixed rate. Can handle upsampling and downsampling.
Methods:
--------
start()
Start the resampler task.
stop()
Stop the resampler task.
receive()
Get next resampled dictionary from the resampler.
"""
def __init__(self, config: ResamplerConf, subscriber: AsyncSource) -> None:
"""
Parameters:
----------
config : ResamplerConf
Configuration for the resampler.
subscriber : AsyncSource
Asynchronous source to resample.
"""
self.config = config
self.subscriber = subscriber
self.resample_rate = self.config.resample_rate
self.delta_t = round(self.resample_rate / 1_000, 3)
self.message_queue: Queue[dict[str, float]] = Queue(maxsize=50)
self.resample_task: None | Task[None] = None
def start(self) -> None:
"""
Start the resampler task.
"""
self.resample_task = create_task(self.resample())
def stop(self) -> None:
"""
Stop the resampler task
"""
if self.resample_task:
self.resample_task.cancel()
async def receive(self) -> dict[str, float]:
"""
Get next resampled dictionary from the resampler.
Implicitly starts the resampler if not already running.
"""
if not self.resample_task:
self.start()
return await self.message_queue.get()
async def resample(self) -> None:
"""
Resample data based on a fixed rate.
Can handle upsampling and downsampling.
Data will always be centered around the output resample timestamp.
(i.e. for data returned for t = 1.0s, the data will be resampled for [0.5, 1.5]s)
"""
print("Starting resampler")
aggregated_data = []
aggregated_timestamps = []
# Get first element to determine timestamp
topic, data = await self.subscriber.receive()
check_dict(data)
timestamp, message = self._pack_data(data) # type: ignore [arg-type]
timestamps = timestamp_generator(timestamp, self.resample_rate)
print(f"Got first element {topic}: {data}")
# Set data keys to reconstruct dict later
self.data_keys = data.keys()
self.topic = topic
# step through interpolation timestamps
for next_timestamp in timestamps:
# await new data and append to buffer until the most recent data
# is newer than the next interpolation timestamp
while timestamp < next_timestamp:
aggregated_timestamps.append(timestamp)
aggregated_data.append(message)
topic, data = await self.subscriber.receive()
timestamp, message = self._pack_data(data) # type: ignore [arg-type]
return_timestamp = round(next_timestamp - self.delta_t / 2, 3)
# Only one new data point was received
if len(aggregated_data) == 1:
self._is_upsampling = False
# print("Only one data point")
last_timestamp, last_message = (
aggregated_timestamps[0],
aggregated_data[0],
)
# Case 2 Upsampling:
while timestamp - next_timestamp > self.delta_t:
self._is_upsampling = True
# print("Upsampling")
last_message = interpolate(
last_timestamp,
last_message,
timestamp,
message,
return_timestamp,
)
last_timestamp = return_timestamp
return_timestamp += self.delta_t
next_timestamp = next(timestamps)
await self.message_queue.put(self._unpack_data(last_timestamp, last_message))
if self._is_upsampling:
last_message = interpolate(
last_timestamp,
last_message,
timestamp,
message,
return_timestamp,
)
last_timestamp = return_timestamp
await self.message_queue.put(self._unpack_data(last_timestamp, last_message))
if len(aggregated_data) > 1:
# Downsampling: multiple data points were received during the resampling timeframe
mean_message = np.mean(np.array(aggregated_data), axis=0)
await self.message_queue.put(self._unpack_data(return_timestamp, mean_message))
# reset the aggregator
aggregated_data.clear()
aggregated_timestamps.clear()
def _pack_data(self, data: dict[str, float]) -> tuple[float, list[float]]:
# pack data from dict to tuple list
ts = data.pop("epoch")
return (ts, list(data.values()))
def _unpack_data(self, ts: float, values: list[float]) -> dict[str, float]:
# from tuple
return {"epoch": round(ts, 3), **dict(zip(self.data_keys, values))}


@@ -1,142 +0,0 @@
import math
from datetime import datetime, timedelta
from queue import Queue
import numpy as np
from heisskleber.mqtt import MqttSubscriber
def round_dt(dt, delta):
"""Round a datetime object based on a delta timedelta."""
return datetime.min + math.floor((dt - datetime.min) / delta) * delta
def timestamp_generator(start_epoch, timedelta_in_ms):
"""Generate increasing timestamps based on a start epoch and a delta in ms."""
timestamp_start = datetime.fromtimestamp(start_epoch)
delta = timedelta(milliseconds=timedelta_in_ms)
delta_half = timedelta(milliseconds=timedelta_in_ms // 2)
next_timestamp = round_dt(timestamp_start, delta) + delta_half
while True:
yield datetime.timestamp(next_timestamp)
next_timestamp += delta
def interpolate(t1, y1, t2, y2, t_target):
"""Perform linear interpolation between two data points."""
y1, y2 = np.array(y1), np.array(y2)
fraction = (t_target - t1) / (t2 - t1)
interpolated_values = y1 + fraction * (y2 - y1)
return interpolated_values.tolist()
class Resampler:
"""
Synchronously resample data based on a fixed rate. Can handle upsampling and downsampling.
Parameters:
----------
config : namedtuple
Configuration for the resampler.
subscriber : MqttSubscriber
Synchronous Subscriber
"""
def __init__(self, config, subscriber: MqttSubscriber):
self.config = config
self.subscriber = subscriber
self.buffer = Queue()
self.resample_rate = self.config.resample_rate
self.delta_t = round(self.resample_rate / 1_000, 3)
def run(self):
topic, message = self.subscriber.receive()
self.buffer.put(self._pack_data(message))
self.data_keys = message.keys()
while True:
topic, message = self.subscriber.receive()
self.buffer.put(self._pack_data(message))
def resample(self):
aggregated_data = []
aggregated_timestamps = []
# Get first element to determine timestamp
timestamp, message = self.buffer.get()
timestamps = timestamp_generator(timestamp, self.resample_rate)
# step through interpolation timestamps
for next_timestamp in timestamps:
# last_timestamp, last_message = timestamp, message
# append new data to buffer until the most recent data
# is newer than the next interpolation timestamp
while timestamp < next_timestamp:
aggregated_timestamps.append(timestamp)
aggregated_data.append(message)
timestamp, message = self.buffer.get()
return_timestamp = round(next_timestamp - self.delta_t / 2, 3)
# Case 1: Only one new data point was received
if len(aggregated_data) == 1:
last_timestamp, last_message = (
aggregated_timestamps[0],
aggregated_data[0],
)
# Case 1a Upsampling:
# The data point is not within our time interval
# We step through time intervals, yielding interpolated data points
while timestamp - next_timestamp > self.delta_t:
last_message = interpolate(
last_timestamp,
last_message,
timestamp,
message,
return_timestamp,
)
last_timestamp = return_timestamp
return_timestamp += self.delta_t
next_timestamp = next(timestamps)
yield self._unpack_data(last_timestamp, last_message)
# Case 1b: The data point is within our time interval
# We simply yield the data point
# Note, this will also be the case once we have advanced the time interval by upsampling
last_message = interpolate(
last_timestamp,
last_message,
timestamp,
message,
return_timestamp,
)
last_timestamp = return_timestamp
return_timestamp += self.delta_t
yield self._unpack_data(last_timestamp, last_message)
# Case 2 - downsampling: multiple data points were received during the resampling timeframe
# We simply yield the mean of the data points, which is more robust and performant than interpolation
if len(aggregated_data) > 1:
# yield self._handle_downsampling(return_timestamp, aggregated_data)
mean_message = np.mean(np.array(aggregated_data), axis=0)
yield self._unpack_data(return_timestamp, mean_message)
# reset the aggregator
aggregated_data.clear()
aggregated_timestamps.clear()
def _handle_downsampling(self, return_timestamp, aggregated_data) -> dict:
"""Handle the downsampling case."""
mean_message = np.mean(np.array(aggregated_data), axis=0)
return self._unpack_data(return_timestamp, mean_message)
def _pack_data(self, data) -> tuple[int, list]:
# pack data from dict to tuple list
ts = data.pop("epoch")
return (ts, list(data.values()))
def _unpack_data(self, ts, values) -> dict:
# from tuple
return {"epoch": round(ts, 3), **dict(zip(self.data_keys, values))}


@@ -1,5 +0,0 @@
from heisskleber.tcp.config import TcpConf
from heisskleber.tcp.sink import AsyncTcpSink
from heisskleber.tcp.source import AsyncTcpSource
__all__ = ["AsyncTcpSource", "AsyncTcpSink", "TcpConf"]


@@ -1,10 +0,0 @@
from dataclasses import dataclass
from heisskleber.config import BaseConf
@dataclass
class TcpConf(BaseConf):
host: str = "localhost"
port: int = 6000
timeout: int = 60


@@ -1,60 +0,0 @@
import asyncio
from typing import Callable
from heisskleber.core.types import AsyncSource, Serializable
from heisskleber.tcp.config import TcpConf
def bytes_csv_unpacker(data: bytes) -> tuple[str, dict[str, str]]:
vals = data.decode().rstrip().split(",")
keys = [f"key{i}" for i in range(len(vals))]
return ("tcp", dict(zip(keys, vals)))
class AsyncTcpSource(AsyncSource):
"""
Async TCP connection, connects to host:port and reads byte encoded strings.
Pass an unpack function like so:
Example
-------
def unpack(data: bytes) -> tuple[str, dict[str, float | int | str]]:
return ("topic", dict(zip(["key1", "key2"], data.decode().split(","))))
"""
def __init__(self, config: TcpConf, unpack: Callable[[bytes], tuple[str, dict[str, Serializable]]] | None) -> None:
self.config = config
self.is_connected = asyncio.Event()
self.unpack = unpack or bytes_csv_unpacker
self.timeout = config.timeout
self.start_task: asyncio.Task[None] | None = None
async def receive(self) -> tuple[str, dict[str, Serializable]]:
await self._check_connection()
data = await self.reader.readline()
topic, payload = self.unpack(data)
return (topic, payload) # type: ignore
def start(self) -> None:
self.start_task = asyncio.create_task(self._connect())
def stop(self) -> None:
if self.is_connected.is_set():  # is_connected is an asyncio.Event, so test the flag explicitly
print("stopping")
async def _check_connection(self) -> None:
if not self.start_task:
self.start()
await self.is_connected.wait()
async def _connect(self) -> None:
print(f"{self} waiting for connection.")
(self.reader, self.writer) = await asyncio.open_connection(self.config.host, self.config.port)
print(f"{self} connected successfully!")
self.is_connected.set()
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"


@@ -1,5 +0,0 @@
from .config import UdpConf
from .publisher import AsyncUdpSink, UdpPublisher
from .subscriber import AsyncUdpSource, UdpSubscriber
__all__ = ["AsyncUdpSource", "UdpSubscriber", "AsyncUdpSink", "UdpPublisher", "UdpConf"]


@@ -1,84 +0,0 @@
import asyncio
import socket
import sys
from heisskleber.core.packer import get_packer
from heisskleber.core.types import AsyncSink, Serializable, Sink
from heisskleber.udp.config import UdpConf
class UdpPublisher(Sink):
def __init__(self, config: UdpConf) -> None:
self.config = config
self.pack = get_packer(self.config.packer)
self.is_connected = False
def start(self) -> None:
try:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
except OSError as e:
print(f"failed to create socket: {e}")
sys.exit(-1)
else:
self.is_connected = True
def stop(self) -> None:
self.socket.close()
self.is_connected = False
def send(self, data: dict[str, Serializable], topic: str | None = None) -> None:
if not self.is_connected:
self.start()
if topic:
data["topic"] = topic
payload = self.pack(data).encode("utf-8")
self.socket.sendto(payload, (self.config.host, self.config.port))
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
class UdpProtocol(asyncio.DatagramProtocol):
def __init__(self, is_connected: bool) -> None:
super().__init__()
self.is_connected = is_connected
def connection_lost(self, exc: Exception | None) -> None:
print("Connection lost")
self.is_connected = False
class AsyncUdpSink(AsyncSink):
def __init__(self, config: UdpConf) -> None:
self.config = config
self.pack = get_packer(self.config.packer)
self.socket: asyncio.DatagramTransport | None = None
self.is_connected = False
def start(self) -> None:
# No background loop required
pass
def stop(self) -> None:
if self.socket is not None:
self.socket.close()
self.is_connected = False
async def _ensure_connection(self) -> None:
if not self.is_connected:
loop = asyncio.get_running_loop()
self.socket, _ = await loop.create_datagram_endpoint(
lambda: UdpProtocol(self.is_connected),
remote_addr=(self.config.host, self.config.port),
)
self.is_connected = True
async def send(self, data: dict[str, Serializable], topic: str | None = None) -> None:
await self._ensure_connection()
if topic:
data["topic"] = topic
payload = self.pack(data).encode(self.config.encoding)
self.socket.sendto(payload) # type: ignore
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"


@@ -1,117 +0,0 @@
import asyncio
import socket
import sys
import threading
from queue import Queue
from typing import Any
from heisskleber.core.packer import get_unpacker
from heisskleber.core.types import AsyncSource, Serializable, Source
from heisskleber.udp.config import UdpConf
class UdpSubscriber(Source):
def __init__(self, config: UdpConf, topic: str | None = None):
self.config = config
self.topic = topic
self.unpacker = get_unpacker(self.config.packer)
self._queue: Queue[tuple[str, dict[str, Serializable]]] = Queue(maxsize=self.config.max_queue_size)
self._running = threading.Event()
def start(self) -> None:
try:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
except OSError as e:
print(f"failed to create socket: {e}")
sys.exit(-1)
self.socket.bind((self.config.host, self.config.port))
self._running.set()
self._thread = threading.Thread(target=self._loop, daemon=True)
self._thread.start()
def stop(self) -> None:
self._running.clear()
# if self._thread is not None:
# self._thread.join()
self.socket.close()
def receive(self) -> tuple[str, dict[str, Serializable]]:
if not self._running.is_set():
self.start()
return self._queue.get()
def _loop(self) -> None:
while self._running.is_set():
try:
payload, _ = self.socket.recvfrom(1024)
data = self.unpacker(payload.decode("utf-8"))
topic: str = str(data.pop("topic")) if "topic" in data else ""
self._queue.put((topic, data))
except Exception as e:
error_message = f"Error in UDP listener loop: {e}"
print(error_message)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
class UdpProtocol(asyncio.DatagramProtocol):
def __init__(self, queue: asyncio.Queue[bytes]) -> None:
super().__init__()
self.queue = queue
def datagram_received(self, data: bytes, addr: tuple[str | Any, int]) -> None:
self.queue.put_nowait(data)
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
print("Connection made")
class AsyncUdpSource(AsyncSource):
"""
An asynchronous UDP subscriber based on asyncio.protocols.DatagramProtocol
"""
def __init__(self, config: UdpConf, topic: str = "udp"):
self.config = config
self.topic = topic
self.EOF = self.config.delimiter.encode(self.config.encoding)
self.unpacker = get_unpacker(self.config.packer)
self.queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.config.max_queue_size)
self.task: asyncio.Task[None] | None = None
self.is_connected = False
async def setup(self) -> None:
loop = asyncio.get_event_loop()
self.transport, self.protocol = await loop.create_datagram_endpoint(
lambda: UdpProtocol(self.queue),
local_addr=(self.config.host, self.config.port),
)
self.is_connected = True
print("Udp connection established.")
def start(self) -> None:
# Background loop not required, handled by Protocol
pass
def stop(self) -> None:
self.transport.close()
async def receive(self) -> tuple[str, dict[str, Serializable]]:
if not self.is_connected:
await self.setup()
data = await self.queue.get()
try:
payload = self.unpacker(data.decode(self.config.encoding, errors="ignore"))
# except UnicodeDecodeError: # this won't be thrown anymore, as the error flag is set to ignore!
# print(f"Could not decode data, is not {self.config.encoding}")
except Exception:
if self.config.verbose:
print(f"Could not deserialize data: {data!r}")
else:
return (self.topic, payload)
return await self.receive() # Try again
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"


@@ -1,5 +0,0 @@
from .config import ZmqConf
from .publisher import ZmqAsyncPublisher, ZmqPublisher
from .subscriber import ZmqAsyncSubscriber, ZmqSubscriber
__all__ = ["ZmqConf", "ZmqPublisher", "ZmqSubscriber", "ZmqAsyncPublisher", "ZmqAsyncSubscriber"]


@@ -1,129 +0,0 @@
import sys
from typing import Callable
import zmq
import zmq.asyncio
from heisskleber.core.packer import get_packer
from heisskleber.core.types import AsyncSink, Serializable, Sink
from .config import ZmqConf
class ZmqPublisher(Sink):
"""
Publisher that sends messages to a ZMQ PUB socket.
Attributes:
-----------
pack : Callable
The packer function to use for serializing the data.
Methods:
--------
send(data : dict, topic : str):
Send the data with the given topic.
start():
Connect to the socket.
stop():
Close the socket.
"""
def __init__(self, config: ZmqConf):
self.config = config
self.context = zmq.Context.instance()
self.socket = self.context.socket(zmq.PUB)
self.pack = get_packer(config.packstyle)
self.is_connected = False
def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Take the data as a dict, serialize it with the given packer and send it to the zmq socket.
"""
if not self.is_connected:
self.start()
payload = self.pack(data)
if self.config.verbose:
print(f"sending message {payload} to topic {topic}")
self.socket.send_multipart([topic.encode(), payload.encode()])
def start(self) -> None:
"""Connect to the zmq socket."""
try:
if self.config.verbose:
print(f"connecting to {self.config.publisher_address}")
self.socket.connect(self.config.publisher_address)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
else:
self.is_connected = True
def stop(self) -> None:
self.socket.close()
self.is_connected = False
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})"
class ZmqAsyncPublisher(AsyncSink):
"""
Async publisher that sends messages to a ZMQ PUB socket.
Attributes:
-----------
pack : Callable
The packer function to use for serializing the data.
Methods:
--------
send(data : dict, topic : str):
Send the data with the given topic.
start():
Connect to the socket.
stop():
Close the socket.
"""
def __init__(self, config: ZmqConf):
self.config = config
self.context = zmq.asyncio.Context.instance()
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB)
self.pack: Callable = get_packer(config.packstyle)
self.is_connected = False
async def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Take the data as a dict, serialize it with the given packer and send it to the zmq socket.
"""
if not self.is_connected:
self.start()
payload = self.pack(data)
if self.config.verbose:
print(f"sending message {payload} to topic {topic}")
await self.socket.send_multipart([topic.encode(), payload.encode()])
def start(self) -> None:
"""Connect to the zmq socket."""
try:
if self.config.verbose:
print(f"connecting to {self.config.publisher_address}")
self.socket.connect(self.config.publisher_address)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
else:
self.is_connected = True
def stop(self) -> None:
"""Close the zmq socket."""
self.socket.close()
self.is_connected = False
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})"


@@ -1,181 +0,0 @@
from __future__ import annotations
import sys
import zmq
import zmq.asyncio
from heisskleber.core.packer import get_unpacker
from heisskleber.core.types import AsyncSource, Source
from .config import ZmqConf
class ZmqSubscriber(Source):
"""
Source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function.
Attributes:
-----------
unpack : Callable
The unpacker function to use for deserializing the data.
Methods:
--------
receive() -> tuple[str, dict]:
Send the data with the given topic.
start():
Connect to the socket.
stop():
Close the socket.
"""
def __init__(self, config: ZmqConf, topic: str | list[str]):
"""
Constructs new ZmqAsyncSubscriber instance.
Parameters:
-----------
config : ZmqConf
The configuration dataclass object for the zmq connection.
topic : str
The topic or list of topics to subscribe to.
"""
self.config = config
self.topic = topic
self.context = zmq.Context.instance()
self.socket = self.context.socket(zmq.SUB)
self.unpack = get_unpacker(config.packstyle)
self.is_connected = False
def receive(self) -> tuple[str, dict]:
"""
reads a message from the zmq bus and returns it
Returns:
tuple(topic: str, message: dict): the message received
"""
if not self.is_connected:
self.start()
(topic, payload) = self.socket.recv_multipart()
message = self.unpack(payload.decode())
topic = topic.decode()
return (topic, message)
def start(self):
try:
self.socket.connect(self.config.subscriber_address)
self.subscribe(self.topic)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
else:
self.is_connected = True
def stop(self):
self.socket.close()
self.is_connected = False
def subscribe(self, topic: str | list[str] | tuple[str]):
# Accepts single topic or list of topics
if isinstance(topic, (list, tuple)):
for t in topic:
self._subscribe_single_topic(t)
else:
self._subscribe_single_topic(topic)
def _subscribe_single_topic(self, topic: str):
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})"
class ZmqAsyncSubscriber(AsyncSource):
"""
Async source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function.
Attributes:
-----------
unpack : Callable
The unpacker function to use for deserializing the data.
Methods:
--------
receive() -> tuple[str, dict]:
Send the data with the given topic.
start():
Connect to the socket.
stop():
Close the socket.
"""
def __init__(self, config: ZmqConf, topic: str | list[str]):
"""
Constructs new ZmqAsyncSubscriber instance.
Parameters:
-----------
config : ZmqConf
The configuration dataclass object for the zmq connection.
topic : str
The topic or list of topics to subscribe to.
"""
self.config = config
self.topic = topic
self.context = zmq.asyncio.Context.instance()
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB)
self.unpack = get_unpacker(config.packstyle)
self.is_connected = False
async def receive(self) -> tuple[str, dict]:
"""
reads a message from the zmq bus and returns it
Returns:
tuple(topic: str, message: dict): the message received
"""
if not self.is_connected:
self.start()
(topic, payload) = await self.socket.recv_multipart()
message = self.unpack(payload.decode())
topic = topic.decode()
return (topic, message)
def start(self):
"""Connect to the zmq socket."""
try:
self.socket.connect(self.config.subscriber_address)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
else:
self.is_connected = True
self.subscribe(self.topic)
def stop(self):
"""Close the zmq socket."""
self.socket.close()
self.is_connected = False
def subscribe(self, topic: str | list[str] | tuple[str]):
"""
Subscribes to the given topic(s) on the zmq socket.
Accepts single topic or list of topics.
"""
if isinstance(topic, (list, tuple)):
for t in topic:
self._subscribe_single_topic(t)
else:
self._subscribe_single_topic(topic)
def _subscribe_single_topic(self, topic: str):
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})"

noxfile.py Normal file

@@ -0,0 +1,52 @@
from __future__ import annotations
import argparse
from pathlib import Path
import nox
DIR = Path(__file__).parent.resolve()
nox.needs_version = ">=2024.3.2"
nox.options.sessions = ["lint", "tests", "check"]
nox.options.default_venv_backend = "uv|virtualenv"
@nox.session
def lint(session: nox.Session) -> None:
"""Run the linter."""
session.install("pre-commit")
session.run("pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs)
@nox.session
def tests(session: nox.Session) -> None:
"""Run the unit and regular tests."""
session.install(".[test]")
session.run("pytest", *session.posargs)
@nox.session(reuse_venv=True)
def docs(session: nox.Session) -> None:
"""Build the docs. Pass --non-interactive to avoid serving. First positional argument is the target directory."""
parser = argparse.ArgumentParser()
parser.add_argument("-b", dest="builder", default="html", help="Build target (default: html)")
parser.add_argument("output", nargs="?", help="Output directory")
args, posargs = parser.parse_known_args(session.posargs)
serve = args.builder == "html" and session.interactive
session.install("-e.[docs]", "sphinx-autobuild")
shared_args = (
"-n", # nitpicky mode
"-T", # full tracebacks
f"-b={args.builder}",
"docs",
args.output or f"docs/_build/{args.builder}",
*posargs,
)
if serve:
session.run("sphinx-autobuild", "--open-browser", *shared_args)
else:
session.run("sphinx-build", "--keep-going", *shared_args)

poetry.lock generated
File diff suppressed because it is too large


@@ -1,57 +1,66 @@
[tool.poetry]
[build-system]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"
[project]
name = "heisskleber"
version = "0.5.7"
description = "Heisskleber"
authors = ["Felix Weiler <felix@flucto.tech>"]
license = "MIT"
authors = [
{ name = "Felix Weiler-Detjen", email = "felix@flucto.tech" },
]
license = {file = "LICENSE"}
readme = "README.md"
homepage = "https://github.com/flucto-gmbh/heisskleber"
repository = "https://github.com/flucto-gmbh/heisskleber"
documentation = "https://heisskleber.readthedocs.io"
[tool.poetry.urls]
Changelog = "https://github.com/flucto-gmbh/heisskleber/releases"
requires-python = ">=3.10"
dynamic = ["version"]
dependencies= [
"aiomqtt>=2.3.0",
"pyserial>=3.5",
"pyyaml>=6.0.2",
"pyzmq>=26.2.0",
]
[tool.poetry.dependencies]
python = "^3.9"
paho-mqtt = "^1.6.1"
pyserial = "^3.5"
pyyaml = "^6.0.1"
pyzmq = "^25.1.1"
aiomqtt = "^1.2.1"
[tool.poetry.group.dev.dependencies]
black = ">=21.10b0"
coverage = { extras = ["toml"], version = ">=6.2" }
furo = ">=2021.11.12"
mypy = ">=0.930"
pre-commit = ">=2.16.0"
pre-commit-hooks = ">=4.1.0"
pytest = ">=6.2.5"
pytest-cov = "^4.1.0"
pytest-mock = "^3.11.1"
ruff = "^0.0.292"
pyupgrade = ">=2.29.1"
safety = ">=1.10.3"
sphinx = ">=4.3.2"
sphinx-autobuild = ">=2021.3.14"
sphinx-autodoc-typehints = "^1.24.0"
sphinx-rtd-theme = "^1.3.0"
typeguard = ">=2.13.3"
xdoctest = { extras = ["colors"], version = ">=0.15.10" }
myst-parser = { version = ">=0.16.1" }
pytest-asyncio = "^0.21.1"
termcolor = "^2.4.0"
codecov = "^2.1.13"
[project.urls]
Homepage = "https://github.com/flucto-gmbh/heisskleber"
Repository = "https://github.com/flucto-gmbh/heisskleber"
Documentation = "https://heisskleber.readthedocs.io"
[tool.poetry.group.types.dependencies]
pandas-stubs = "^2.1.1.230928"
types-pyyaml = "^6.0.12.12"
types-paho-mqtt = "^1.6.0.7"
[tool.poetry.scripts]
hkcli = "heisskleber.run.cli:main"
zmqbroker = "heisskleber.run.zmqbroker:main"
[project.optional-dependencies]
test = [
"pytest>=8.3.3",
"pytest-cov>=5.0.0",
"coverage[toml]>=7.6.1",
"xdoctest>=1.2.0",
"pytest-asyncio>=0.24.0",
]
docs = [
"furo>=2024.8.6",
"myst-parser>=4.0.0",
"sphinx>=8.0.2",
"sphinx-autobuild>=2024.9.19",
"sphinx-rtd-theme>=0.5.1",
"sphinx_copybutton",
"sphinx_autodoc_typehints",
]
filter = [
"numpy>=2.1.1",
"scipy>=1.14.1",
]
[tool.uv]
dev-dependencies = [
"deptry>=0.20.0",
"mypy>=1.11.2",
"ruff>=0.6.8",
"xdoctest>=1.2.0",
"nox>=2024.4.15",
"pytest>=8.3.3",
"pytest-cov>=5.0.0",
"coverage[toml]>=7.6.1",
"pytest-asyncio>=0.24.0",
]
package = true
[tool.coverage.paths]
source = ["heisskleber", "*/site-packages"]
@@ -59,7 +68,7 @@ tests = ["tests", "*/tests"]
[tool.coverage.run]
branch = true
source = ["heisskleber"]
source = ["src/heisskleber"]
omit = ["tests/*"]
[tool.coverage.report]
@@ -75,39 +84,48 @@ show_error_context = true
exclude = ["tests/*", "^test_*\\.py"]
[tool.ruff]
ignore-init-module-imports = true
target-version = "py39"
line-length = 120
fix = true
select = [
"YTT", # flake8-2020
"S", # flake8-bandit
"B", # flake8-bugbear
"A", # flake8-builtins
"C4", # flake8-comprehensions
"T10", # flake8-debugger
"SIM", # flake8-simplify
"I", # isort
"C90", # mccabe
"E",
"W", # pycodestyle
"F", # pyflakes
"PGH", # pygrep-hooks
"UP", # pyupgrade
"RUF", # ruff
"TRY", # tryceratops
]
[tool.ruff.lint]
select = ["ALL"]
ignore = [
"E501", # LineTooLong
"E731", # DoNotAssignLambda
"A001", #
"PGH003", # Use specific rules when ignoring type issues
"D100", # Missing module docstring
"D104", # Missing package docstring
"D107", # Missing __init__ docstring
"ANN101", # Deprecated and stupid self annotation
"ANN401", # Dynamically typed annotation
"FA102", # Missing from __future__ import annotations
"FBT001", # boolean style argument in function definition
"FBT002", # boolean style argument in function definition
"ARG002", # Unused kwargs
"TD002",
"TD003",
"FIX002",
"COM812",
"ISC001",
"ARG001",
"INP001"
]
[tool.ruff.per-file-ignores]
"tests/*" = ["S101", "D"]
"tests/test_import.py" = ["F401"]
[tool.ruff.lint.pydocstyle]
convention = "google"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["S101", "D", "T201", "PLR2", "SLF001", "ANN"]
"bin/*" = [
"ERA001", # Found commented-out code
]
[tool.hatch]
version.source = "vcs"
version.path = "src/heisskleber/__init__.py"
[tool.hatch.envs.default]
features = ["test"]
scripts.test = "pytest {args}"


@@ -1,17 +0,0 @@
from dataclasses import asdict
import yaml
from heisskleber.mqtt.config import MqttConf
from heisskleber.serial.config import SerialConf
from heisskleber.tcp.config import TcpConf
from heisskleber.udp.config import UdpConf
from heisskleber.zmq.config import ZmqConf
configs = {"mqtt": MqttConf(), "zmq": ZmqConf(), "udp": UdpConf(), "tcp": TcpConf(), "serial": SerialConf()}
for name, config in configs.items():
with open(f"./config/heisskleber/{name}.yaml", "w") as file:
file.write(f"# Heisskleber config file for {config.__class__.__name__}\n")
file.write(yaml.dump(asdict(config)))


@@ -1,14 +0,0 @@
from heisskleber.udp import UdpConf, UdpSubscriber
def main() -> None:
conf = UdpConf(host="192.168.137.1", port=6600)
subscriber = UdpSubscriber(conf)
while True:
topic, data = subscriber.receive()
print(topic, data)
if __name__ == "__main__":
main()


@@ -0,0 +1,38 @@
"""Heisskleber."""
from heisskleber.console import ConsoleReceiver, ConsoleSender
from heisskleber.core import Receiver, Sender
from heisskleber.mqtt import MqttConf, MqttReceiver, MqttSender
from heisskleber.serial import SerialConf, SerialReceiver, SerialSender
from heisskleber.tcp import TcpConf, TcpReceiver, TcpSender
from heisskleber.udp import UdpConf, UdpReceiver, UdpSender
from heisskleber.zmq import ZmqConf, ZmqReceiver, ZmqSender
__all__ = [
"Sender",
"Receiver",
# mqtt
"MqttConf",
"MqttSender",
"MqttReceiver",
# zmq
"ZmqConf",
"ZmqSender",
"ZmqReceiver",
# udp
"UdpConf",
"UdpSender",
"UdpReceiver",
# tcp
"TcpConf",
"TcpSender",
"TcpReceiver",
# serial
"SerialConf",
"SerialSender",
"SerialReceiver",
# console
"ConsoleSender",
"ConsoleReceiver",
]
__version__ = "1.0.0"


@@ -0,0 +1,4 @@
from heisskleber.console.receiver import ConsoleReceiver
from heisskleber.console.sender import ConsoleSender
__all__ = ["ConsoleReceiver", "ConsoleSender"]


@@ -0,0 +1,46 @@
import asyncio
import sys
from typing import Any, TypeVar
from heisskleber.core import Receiver, Unpacker, json_unpacker
T = TypeVar("T")
class ConsoleReceiver(Receiver[T]):
"""Read stdin from console and create data of type T."""
def __init__(
self,
unpacker: Unpacker[T] = json_unpacker, # type: ignore[assignment]
) -> None:
self.queue: asyncio.Queue[tuple[T, dict[str, Any]]] = asyncio.Queue(maxsize=10)
self.unpack = unpacker
self.task: asyncio.Task[None] | None = None
async def _listener_task(self) -> None:
while True:
payload = sys.stdin.readline().encode() # blocking read; kept for now to adhere to the bytes-based Receiver interface
data, extra = self.unpack(payload)
await self.queue.put((data, extra))
async def receive(self) -> tuple[T, dict[str, Any]]:
"""Receive the next message from the console input."""
if not self.task:
self.task = asyncio.create_task(self._listener_task())
data, extra = await self.queue.get()
return data, extra
def __repr__(self) -> str:
"""Return string representation of ConsoleSource."""
return f"{self.__class__.__name__}"
async def start(self) -> None:
"""Start ConsoleSource."""
self.task = asyncio.create_task(self._listener_task())
async def stop(self) -> None:
"""Stop ConsoleSource."""
if self.task:
self.task.cancel()
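A minimal usage sketch, reading one JSON line from stdin via the async context manager:

import asyncio

async def main() -> None:
    async with ConsoleReceiver() as receiver:
        data, meta = await receiver.receive()
        print(data, meta)

asyncio.run(main())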


@@ -0,0 +1,35 @@
from typing import Any, TypeVar
from heisskleber.core import Packer, Sender, json_packer
T = TypeVar("T")
class ConsoleSender(Sender[T]):
"""Send data to console out."""
def __init__(
self,
pretty: bool = False,
verbose: bool = False,
packer: Packer[T] = json_packer, # type: ignore[assignment]
) -> None:
self.verbose = verbose
self.pretty = pretty
self.packer = packer
async def send(self, data: T, topic: str | None = None, **kwargs: Any) -> None:
"""Serialize data and write to console output."""
serialized = self.packer(data)
output = f"{topic}:\t{serialized.decode()}" if topic else serialized.decode()
print(output) # noqa: T201
def __repr__(self) -> str:
"""Return string reprensentation of ConsoleSink."""
return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})"
async def start(self) -> None:
"""Not implemented."""
async def stop(self) -> None:
"""Not implemented."""


@@ -0,0 +1,23 @@
"""Core classes of the heisskleber library."""
from .config import BaseConf, ConfigType
from .packer import JSONPacker, Packer, PackerError
from .receiver import Receiver
from .sender import Sender
from .unpacker import JSONUnpacker, Unpacker, UnpackError
json_packer = JSONPacker()
json_unpacker = JSONUnpacker()
__all__ = [
"Packer",
"Unpacker",
"Sender",
"Receiver",
"json_packer",
"json_unpacker",
"BaseConf",
"ConfigType",
"PackerError",
"UnpackError",
]


@@ -0,0 +1,127 @@
"""Configuration baseclass."""
import logging
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, TextIO, TypeVar, Union
import yaml # type: ignore[import-untyped]
logger = logging.getLogger("heisskleber")
ConfigType = TypeVar(
"ConfigType",
bound="BaseConf",
) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar
def _parse_yaml(file: TextIO) -> dict[str, Any]:
try:
return dict(yaml.safe_load(file))
except yaml.YAMLError as e:
msg = "Failed to parse config file!"
logger.exception(msg)
raise ValueError(msg) from e
def _parse_json(file: TextIO) -> dict[str, Any]:
import json
try:
return dict(json.load(file))
except json.JSONDecodeError as e:
msg = "Failed to parse config file!"
logger.exception(msg)
raise ValueError(msg) from e
def _parser(path: Path) -> dict[str, Any]:
suffix = path.suffix.lower()
with path.open() as f:
if suffix in [".yaml", ".yml"]:
return _parse_yaml(f)
if suffix == ".json":
return _parse_json(f)
msg = f"Unsupported file format {suffix}."
logger.error(msg)
raise ValueError(msg)
@dataclass
class BaseConf:
"""Default configuration class for generic configuration info."""
def __post_init__(self) -> None:
"""Check if all attributes are the same type as the original defition of the dataclass."""
for field in fields(self):
value = getattr(self, field.name)
if value is None: # Allow optional fields
continue
if hasattr(field.type, "__origin__") and field.type.__origin__ is Union:
if not any(isinstance(value, t) for t in field.type.__args__): # Failed Union comparison
raise TypeError
elif not isinstance(value, field.type): # Failed field comparison; checked second, since isinstance() raises on subscripted Union types
raise TypeError
@classmethod
def from_dict(cls: type[ConfigType], config_dict: dict[str, Any]) -> ConfigType:
"""Create a config instance from a dictionary, including only fields defined in the dataclass.
Arguments:
config_dict: Dictionary containing configuration values.
Keys should match dataclass field names.
Returns:
An instance of the configuration class with values from the dictionary.
Raises:
TypeError: If provided values don't match field types.
Example:
>>> from dataclasses import dataclass
>>> @dataclass
... class ServerConfig(BaseConf):
... host: str
... port: int
... debug: bool = False
>>>
>>> config = ServerConfig.from_dict({
... "host": "localhost",
... "port": 8080,
... "debug": True,
... "invalid_key": "ignored" # Will be filtered out
... })
>>> config.host
'localhost'
>>> config.port
8080
>>> config.debug
True
>>> hasattr(config, "invalid_key") # Extra keys are ignored
False
>>>
>>> # Type validation
>>> try:
... ServerConfig.from_dict({"host": "localhost", "port": "8080"}) # Wrong type
... except TypeError as e:
... print("TypeError raised as expected")
TypeError raised as expected
"""
valid_fields = {f.name for f in fields(cls)}
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_fields}
return cls(**filtered_dict)
@classmethod
def from_file(cls: type[ConfigType], file_path: str | Path) -> ConfigType:
"""Create a config instance from a file - accepts yaml or json."""
path = Path(file_path)
if not path.exists():
logger.exception("Config file not found: %(path)s", {"path": path})
raise FileNotFoundError
return cls.from_dict(_parser(path))
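A sketch of from_file round-tripping the hypothetical ServerConfig from the docstring above through a YAML file:

from pathlib import Path

Path("server.yaml").write_text("host: localhost\nport: 8080\n")
config = ServerConfig.from_file("server.yaml")
assert config.host == "localhost" and config.port == 8080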


@@ -0,0 +1,80 @@
"""Packer and unpacker for network data."""
import json
from abc import abstractmethod
from typing import Any, Protocol, TypeVar
T_contra = TypeVar("T_contra", contravariant=True)
class PackerError(Exception):
"""Raised when unpacking operations fail.
This exception wraps underlying errors that may occur during unpacking,
providing a consistent interface for error handling.
Arguments:
data: The data object that caused the PackerError
"""
PREVIEW_LENGTH = 100
def __init__(self, data: Any) -> None:
"""Initialize the error with the failed payload and cause."""
message = "Failed to pack data."
super().__init__(message)
class Packer(Protocol[T_contra]):
"""Packer Interface.
This class defines a protocol for packing data.
It takes data and converts it into a bytes payload.
Attributes:
None
"""
@abstractmethod
def __call__(self, data: T_contra) -> bytes:
"""Packs the data dictionary into a bytes payload.
Arguments:
data (T_contra): The input data dictionary to be packed.
Returns:
bytes: The packed payload.
Raises:
PackerError: The data dictionary could not be packed.
"""
class JSONPacker(Packer[dict[str, Any]]):
"""Converts a dictionary into JSON-formatted bytes.
Arguments:
data: A dictionary with string keys and arbitrary values to be serialized into JSON format.
Returns:
bytes: The JSON-encoded data as a bytes object.
Raises:
PackerError: If the data cannot be serialized to JSON.
Example:
>>> packer = JSONPacker()
>>> result = packer({"key": "value"})
b'{"key": "value"}'
"""
def __call__(self, data: dict[str, Any]) -> bytes:
"""Pack the data."""
try:
return json.dumps(data).encode()
except (UnicodeEncodeError, TypeError) as err:
raise PackerError(data) from err
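Since Packer is a Protocol, any callable with the right signature conforms; a minimal sketch of a second packer for plain strings (illustrative, not part of the library):

class StrPacker(Packer[str]):
    """Encode a plain string payload as UTF-8 bytes."""

    def __call__(self, data: str) -> bytes:
        try:
            return data.encode()
        except UnicodeEncodeError as err:
            raise PackerError(data) from err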


@@ -0,0 +1,101 @@
"""Asynchronous data source interface."""
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from types import TracebackType
from typing import Any, Generic
from .unpacker import T_co, Unpacker
class Receiver(ABC, Generic[T_co]):
"""Abstract interface for asynchronous data sources.
This class defines a protocol for receiving data from various input streams
asynchronously. It supports both async iteration and context manager patterns,
and ensures proper resource management.
The source is covariant in its type parameter, allowing for type-safe subtyping
relationships.
Attributes:
unpacker: Component responsible for deserializing incoming data into type T_co.
Example:
>>> async with CustomSource(unpacker) as source:
... async for data, metadata in source:
... print(f"Received: {data}, metadata: {metadata}")
"""
unpacker: Unpacker[T_co]
@abstractmethod
async def receive(self) -> tuple[T_co, dict[str, Any]]:
"""Receive data from the implemented input stream.
Returns:
tuple[T_co, dict[str, Any]]: A tuple containing:
- The received and unpacked data of type T_co
- A dictionary of metadata associated with the received data
Raises:
Any implementation-specific exceptions that might occur during receiving.
"""
@abstractmethod
async def start(self) -> None:
"""Initialize and start any background processes and tasks of the source."""
@abstractmethod
async def stop(self) -> None:
"""Stop any background processes and tasks."""
@abstractmethod
def __repr__(self) -> str:
"""A string reprensatiion of the source."""
async def __aiter__(self) -> AsyncGenerator[tuple[T_co, dict[str, Any]], None]:
"""Implement async iteration over the source's data stream.
Yields:
tuple[T_co, dict[str, Any]]: Each data item and its associated metadata
as returned by receive().
Raises:
Any exceptions that might occur during receive().
"""
while True:
data, meta = await self.receive()
yield data, meta
async def __aenter__(self) -> "Receiver[T_co]":
"""Initialize the source for use in an async context manager.
Returns:
Receiver[T_co]: The initialized source instance.
Raises:
Any exceptions that might occur during start().
"""
await self.start()
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Cleanup the source when exiting an async context manager.
Arguments:
exc_type: The type of the exception that was raised, if any.
exc_value: The instance of the exception that was raised, if any.
traceback: The traceback of the exception that was raised, if any.
"""
await self.stop()
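# Illustrative sketch (not part of the original file): a minimal concrete
# Receiver that drains an in-memory queue, e.g. for unit tests. QueueReceiver
# and its constructor are hypothetical.
import asyncio

class QueueReceiver(Receiver[T_co]):
    def __init__(self, queue: "asyncio.Queue[bytes]", unpacker: Unpacker[T_co]) -> None:
        self.queue = queue
        self.unpacker = unpacker

    async def receive(self) -> tuple[T_co, dict[str, Any]]:
        # Block until a payload arrives, then delegate to the unpacker.
        payload = await self.queue.get()
        return self.unpacker(payload)

    async def start(self) -> None:
        """No background tasks to start."""

    async def stop(self) -> None:
        """Nothing to clean up."""

    def __repr__(self) -> str:
        return "QueueReceiver()"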

View File

@@ -0,0 +1,79 @@
"""Asyncronous data sink interface."""
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, Generic, TypeVar
from .packer import Packer
T = TypeVar("T")
class Sender(ABC, Generic[T]):
"""Abstract interface for asynchronous data sinks.
This class defines a protocol for sending data to various output streams
asynchronously. It supports context manager usage and ensures proper
resource management.
Attributes:
packer: Component responsible for serializing type T data before sending.
"""
packer: Packer[T]
@abstractmethod
async def send(self, data: T, **kwargs: Any) -> None:
"""Send data through the implemented output stream.
Arguments:
data: The data to be sent, of type T.
**kwargs: Additional implementation-specific arguments.
"""
@abstractmethod
async def start(self) -> None:
"""Initialize and start the sink's background processes and tasks."""
@abstractmethod
async def stop(self) -> None:
"""Stop and cleanup the sink's background processes and tasks.
This method should be called when the sink is no longer needed.
It should handle cleanup of any resources initialized in start().
"""
@abstractmethod
def __repr__(self) -> str:
"""A string representation of the sink."""
async def __aenter__(self) -> "Sender[T]":
"""Initialize the sink for use in an async context manager.
Returns:
Sender[T]: The initialized sink instance.
Raises:
Any exceptions that might occur during start().
"""
await self.start()
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Cleanup the sink when exiting an async context manager.
Arguments:
exc_type: The type of the exception that was raised, if any.
exc_value: The instance of the exception that was raised, if any.
traceback: The traceback of the exception that was raised, if any.
"""
await self.stop()
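# Illustrative sketch (not part of the original file): a minimal Sender that
# writes packed payloads to stdout, useful for debugging a pipeline.
# StdoutSender is hypothetical.
import sys

class StdoutSender(Sender[T]):
    def __init__(self, packer: Packer[T]) -> None:
        self.packer = packer

    async def send(self, data: T, **kwargs: Any) -> None:
        # Pack and write one payload per line.
        sys.stdout.buffer.write(self.packer(data) + b"\n")

    async def start(self) -> None:
        """No background tasks to start."""

    async def stop(self) -> None:
        """Nothing to clean up."""

    def __repr__(self) -> str:
        return "StdoutSender()"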

View File

@@ -0,0 +1,84 @@
"""Unpacker protocol definition and example implemetation."""
import json
from abc import abstractmethod
from typing import Any, Protocol, TypeVar
T_co = TypeVar("T_co", covariant=True)
class UnpackError(Exception):
"""Raised when unpacking operations fail.
This exception wraps underlying errors that may occur during unpacking,
providing a consistent interface for error handling.
Arguments:
payload: The bytes payload that failed to unpack.
"""
PREVIEW_LENGTH = 100
def __init__(self, payload: bytes) -> None:
"""Initialize the error with the failed payload and cause."""
self.payload = payload
preview = payload[: self.PREVIEW_LENGTH] + b"..." if len(payload) > self.PREVIEW_LENGTH else payload
message = f"Failed to unpack payload: {preview!r}. "
super().__init__(message)
class Unpacker(Protocol[T_co]):
"""Unpacker Interface.
This protocol defines an interface for unpacking byte payloads.
It takes a payload of bytes and returns a tuple containing the unpacked
data and a dictionary of associated metadata.
"""
@abstractmethod
def __call__(self, payload: bytes) -> tuple[T_co, dict[str, Any]]:
"""Unpacks the payload into a data object and optional meta-data dictionary.
Args:
payload (bytes): The input payload to be unpacked.
Returns:
tuple[T_co, dict[str, Any]]: A tuple containing:
- T_co: The data object generated from the input data, e.g. dict or dataclass
- dict[str, Any]: The metadata associated with the unpack operation, such as topic, timestamp or errors
Raises:
UnpackError: The payload could not be unpacked.
"""
class JSONUnpacker(Unpacker[dict[str, Any]]):
"""Deserializes JSON-formatted bytes into dictionaries.
Arguments:
payload: JSON-formatted bytes to deserialize.
Returns:
tuple[dict[str, Any], dict[str, Any]]: A tuple containing:
- The deserialized JSON data as a dictionary
- An empty dictionary for metadata (not used in JSON unpacking)
Raises:
UnpackError: If the payload cannot be decoded as valid JSON.
Example:
>>> unpacker = JSONUnpacker()
>>> data, metadata = unpacker(b'{"hotglue": "very_nais"}')
>>> print(data)
{'hotglue': 'very_nais'}
"""
def __call__(self, payload: bytes) -> tuple[dict[str, Any], dict[str, Any]]:
"""Unpack the payload."""
try:
return json.loads(payload), {}
except json.JSONDecodeError as e:
raise UnpackError(payload) from e
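# Illustrative sketch (not part of the original file): because Unpacker is a
# Protocol, any matching callable conforms without inheriting. TextUnpacker is
# hypothetical; it reports the payload size as metadata.
class TextUnpacker:
    def __call__(self, payload: bytes) -> tuple[str, dict[str, Any]]:
        try:
            return payload.decode("utf-8").strip(), {"n_bytes": len(payload)}
        except UnicodeDecodeError as err:
            raise UnpackError(payload) from err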

View File

@@ -0,0 +1,44 @@
import asyncio
from collections.abc import Coroutine
from functools import wraps
from typing import Any, Callable, ParamSpec, TypeVar
P = ParamSpec("P")
T = TypeVar("T")
def retry(
every: int = 5,
strategy: str = "always",
catch: type[Exception] | tuple[type[Exception], ...] = Exception,
logger_fn: Callable[[str, dict[str, Any]], None] | None = None,
) -> Callable[[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]]:
"""Retry a coroutine."""
def decorator(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
retries = 0
while True:
try:
result = await func(*args, **kwargs)
break
except catch as e:
if logger_fn:
logger_fn(
"Error occurred: %(err). Retrying in %(seconds) seconds",
{"err": e, "seconds": every},
)
retries += 1
await asyncio.sleep(every)
except asyncio.CancelledError:
raise
if strategy != "always":
raise NotImplementedError
return result
return wrapper
return decorator
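# Illustrative usage sketch (not part of the original file): retry a flaky
# coroutine every 2 seconds until it stops raising. read_sensor is
# hypothetical.
import logging

example_logger = logging.getLogger("example")

@retry(every=2, catch=ConnectionError, logger_fn=example_logger.warning)
async def read_sensor() -> bytes:
    ...  # raises ConnectionError until the device is ready, then returns bytes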

View File

@@ -0,0 +1,13 @@
"""Async wrappers for mqtt functionality.
MQTT implementation is achieved via the `aiomqtt`_ package, which is an async wrapper around the `paho-mqtt`_ package.
.. _aiomqtt: https://github.com/empicano/aiomqtt
.. _paho-mqtt: https://github.com/eclipse/paho.mqtt.python
"""
from .config import MqttConf
from .receiver import MqttReceiver
from .sender import MqttSender
__all__ = ["MqttConf", "MqttReceiver", "MqttSender"]

View File

@@ -0,0 +1,50 @@
"""Mqtt config."""
from dataclasses import dataclass
from typing import Any
from aiomqtt import Will
from heisskleber.core import BaseConf
@dataclass
class WillConf(BaseConf):
"""MQTT Last Will and Testament message configuration."""
topic: str
payload: str | None = None
qos: int = 0
retain: bool = False
def to_aiomqtt_will(self) -> Will:
"""Create an aiomqtt style will."""
return Will(topic=self.topic, payload=self.payload, qos=self.qos, retain=self.retain, properties=None)
@dataclass
class MqttConf(BaseConf):
"""MQTT configuration class."""
# transport
host: str = "localhost"
port: int = 1883
ssl: bool = False
# mqtt
user: str = ""
password: str = ""
qos: int = 0
retain: bool = False
max_saved_messages: int = 100
timeout: int = 60
keep_alive: int = 60
will: Will | None = None
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "MqttConf":
"""Create a MqttConf object from a dictionary."""
if "will" in config_dict:
config_dict = config_dict.copy()
config_dict["will"] = WillConf.from_dict(config_dict["will"]).to_aiomqtt_will()
return super().from_dict(config_dict)
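# Illustrative usage sketch (not part of the original file): build a config
# including a last-will message from a plain dictionary; all values are
# hypothetical.
conf = MqttConf.from_dict(
    {
        "host": "broker.example.com",
        "port": 1883,
        "will": {"topic": "status/device1", "payload": "offline", "qos": 1, "retain": True},
    }
)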

View File

@@ -0,0 +1,133 @@
import asyncio
import logging
from asyncio import Queue, Task, create_task
from typing import Any, TypeVar
from aiomqtt import Client, Message, MqttError
from heisskleber.core import Receiver, Unpacker, json_unpacker
from heisskleber.core.utils import retry
from heisskleber.mqtt import MqttConf
T = TypeVar("T")
logger = logging.getLogger("heisskleber.mqtt")
class MqttReceiver(Receiver[T]):
"""Asynchronous MQTT subscriber based on aiomqtt.
This class implements an asynchronous MQTT subscriber that handles connection, subscription, and message reception from an MQTT broker. It uses aiomqtt as the underlying MQTT client implementation.
The subscriber maintains a queue of received messages which can be accessed through the `receive` method.
Attributes:
config (MqttConf): Stored configuration for MQTT connection.
topics (str | list[str]): Topics to subscribe to.
"""
def __init__(
self,
config: MqttConf,
topic: str | list[str],
unpacker: Unpacker[T] = json_unpacker, # type: ignore[assignment]
) -> None:
"""Initialize the MQTT source.
Args:
config: Configuration object containing:
- host (str): MQTT broker hostname
- port (int): MQTT broker port
- user (str): Username for authentication
- password (str): Password for authentication
- qos (int): Default Quality of Service level
- max_saved_messages (int): Maximum queue size
topic: Single topic string or list of topics to subscribe to
unpacker: Function to deserialize received messages, defaults to json_unpacker
"""
self.config = config
self.topics = topic if isinstance(topic, list) else [topic]
self.unpacker = unpacker
self._message_queue: Queue[Message] = Queue(self.config.max_saved_messages)
self._listener_task: Task[None] | None = None
async def receive(self) -> tuple[T, dict[str, Any]]:
"""Receive and process the next message from the queue.
Returns:
tuple[T, dict[str, Any]]
- The unpacked message data
- A dictionary with metadata including the message topic
Raises:
TypeError: If the message payload is not of type bytes.
UnpackError: If the message could not be unpacked with the unpacker protocol.
"""
if not self._listener_task:
await self.start()
message = await self._message_queue.get()
if not isinstance(message.payload, bytes):
error_msg = "Payload is not of type bytes."
raise TypeError(error_msg)
data, extra = self.unpacker(message.payload)
extra["topic"] = message.topic.value
return (data, extra)
def __repr__(self) -> str:
"""Return string representation of Mqtt Source class."""
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
async def start(self) -> None:
"""Start the MQTT listener task."""
self._listener_task = create_task(self._run())
async def stop(self) -> None:
"""Stop the MQTT listener task."""
if not self._listener_task:
return
self._listener_task.cancel()
try:
await self._listener_task
except asyncio.CancelledError:
# Raise if the stop task was cancelled
# kudos:https://superfastpython.com/asyncio-cancel-task-and-wait/
task = asyncio.current_task()
if task and task.cancelled():
raise
self._listener_task = None
async def subscribe(self, topic: str, qos: int | None = None) -> None:
"""Subscribe to an additional MQTT topic.
Args:
topic: The topic to subscribe to
qos: Quality of Service level, uses config.qos if None
"""
qos = qos if qos is not None else self.config.qos
self.topics.append(topic)
await self._client.subscribe(topic, qos)
@retry(every=1, catch=MqttError, logger_fn=logger.exception)
async def _run(self) -> None:
"""Background task for MQTT connection."""
async with Client(
hostname=self.config.host,
port=self.config.port,
username=self.config.user,
password=self.config.password,
timeout=self.config.timeout,
keepalive=self.config.keep_alive,
will=self.config.will,
) as client:
self._client = client
logger.info("subscribing to %(topics)s", {"topics": self.topics})
await client.subscribe([(topic, self.config.qos) for topic in self.topics])
async for message in client.messages:
await self._message_queue.put(message)
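# Illustrative usage sketch (not part of the original file): consume messages
# via the async context manager and iteration protocol; broker address and
# topic are hypothetical.
async def consume() -> None:
    config = MqttConf(host="localhost", port=1883)
    async with MqttReceiver(config, topic="sensors/#") as receiver:
        async for data, extra in receiver:
            print(extra["topic"], data)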

View File

@@ -0,0 +1,99 @@
"""Async mqtt sink implementation."""
import asyncio
import logging
from asyncio import CancelledError, create_task
from typing import Any, TypeVar
import aiomqtt
from heisskleber.core import Packer, Sender, json_packer
from heisskleber.core.utils import retry
from .config import MqttConf
T = TypeVar("T")
logger = logging.getLogger("heisskleber.mqtt")
class MqttSender(Sender[T]):
"""MQTT publisher with queued message handling.
This sink implementation provides asynchronous MQTT publishing capabilities with automatic connection management and message queueing.
Network operations are handled in a separate task.
Attributes:
config: MQTT configuration in a dataclass.
packer: Callable to pack data from type T to bytes for transport.
"""
def __init__(self, config: MqttConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment]
self.config = config
self.packer = packer
self._send_queue: asyncio.Queue[tuple[T, str, int, bool]] = asyncio.Queue()
self._sender_task: asyncio.Task[None] | None = None
async def send(self, data: T, topic: str = "mqtt", qos: int = 0, retain: bool = False, **kwargs: Any) -> None:
"""Queue data for asynchronous publication to the mqtt broker.
Arguments:
data: The data to be published.
topic: The mqtt topic to publish to.
qos: MQTT QoS level (0, 1, or 2). Defaults to 0.
retain: Whether to set the MQTT retain flag. Defaults to False.
**kwargs: Not implemented.
"""
if not self._sender_task:
await self.start()
await self._send_queue.put((data, topic, qos, retain))
@retry(every=5, catch=aiomqtt.MqttError, logger_fn=logger.exception)
async def _send_work(self) -> None:
async with aiomqtt.Client(
hostname=self.config.host,
port=self.config.port,
username=self.config.user,
password=self.config.password,
timeout=float(self.config.timeout),
keepalive=self.config.keep_alive,
will=self.config.will,
) as client:
try:
while True:
data, topic, qos, retain = await self._send_queue.get()
payload = self.packer(data)
await client.publish(topic=topic, payload=payload, qos=qos, retain=retain)
except CancelledError:
logger.info("MqttSink background loop cancelled. Emptying queue...")
while not self._send_queue.empty():
_ = self._send_queue.get_nowait()
raise
def __repr__(self) -> str:
"""Return string representation of the MQTT sink object."""
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
async def start(self) -> None:
"""Start the send queue in a separate task.
The task will retry connections every 5 seconds on failure.
"""
self._sender_task = create_task(self._send_work())
async def stop(self) -> None:
"""Stop the background task."""
if not self._sender_task:
return
self._sender_task.cancel()
try:
await self._sender_task
except asyncio.CancelledError:
# If the stop task was cancelled, we raise.
task = asyncio.current_task()
if task and task.cancelled():
raise
self._sender_task = None
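# Illustrative usage sketch (not part of the original file): publish a
# dictionary with the default json packer; broker address and topic are
# hypothetical.
async def publish() -> None:
    config = MqttConf(host="localhost", port=1883)
    async with MqttSender(config) as sender:
        await sender.send({"temperature": 21.3}, topic="sensors/indoor")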

View File

@@ -0,0 +1,7 @@
"""Asyncronous implementations to read and write to a serial interface."""
from .config import SerialConf
from .receiver import SerialReceiver
from .sender import SerialSender
__all__ = ["SerialConf", "SerialSender", "SerialReceiver"]

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
from typing import Literal
from heisskleber.core.config import BaseConf
@dataclass
class SerialConf(BaseConf):
"""Serial Config class.
Attributes:
port: The port to connect to. Defaults to /dev/serial0.
baudrate: The baudrate of the serial connection. Defaults to 9600.
bytesize: The bytesize of the messages. Defaults to 8.
encoding: The string encoding of the messages. Defaults to ascii.
parity: The parity checking value. One of "N" for none, "E" for even, "O" for odd. Defaults to "N".
stopbits: Stopbits. One of 1, 2 or 1.5. Defaults to 1.
Note:
stopbits 1.5 is not yet implemented.
"""
port: str = "/dev/serial0"
baudrate: int = 9600
bytesize: int = 8
encoding: str = "ascii"
parity: Literal["N", "O", "E"] = "N" # definitions from serial.PARTITY_'N'ONE / 'O'DD / 'E'VEN
stopbits: Literal[1, 2] = 1 # 1.5 not yet implemented

View File

@@ -0,0 +1,98 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, TypeVar
import serial # type: ignore[import-untyped]
from heisskleber.core import Receiver, Unpacker
from .config import SerialConf
T = TypeVar("T")
logger = logging.getLogger("heisskleber.serial")
class SerialReceiver(Receiver[T]):
"""An asynchronous source for reading data from a serial port.
This class implements the Receiver interface for reading data from a serial port.
It uses a thread pool executor to perform blocking I/O operations asynchronously.
Attributes:
config: Configuration for the serial port.
unpacker: Function to unpack received data.
"""
def __init__(self, config: SerialConf, unpack: Unpacker[T]) -> None:
self.config = config
self.unpacker = unpack
self._loop = asyncio.get_running_loop()
self._executor = ThreadPoolExecutor(max_workers=2)
self._lock = asyncio.Lock()
self._is_connected = False
self._cancel_read_timeout = 1
async def receive(self) -> tuple[T, dict[str, Any]]:
"""Receive data from the serial port.
This method reads a line from the serial port, unpacks it, and returns the data.
If the serial port is not connected, it will attempt to connect first.
Returns:
tuple[T, dict[str, Any]]: A tuple containing the unpacked data and any extra information.
Raises:
UnpackError: If the data could not be unpacked with the provided unpacker.
"""
if not self._is_connected:
await self.start()
try:
payload = await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.readline, -1)
except asyncio.CancelledError:
await asyncio.shield(self._cancel_read())
raise
data, extra = self.unpacker(payload=payload)
logger.debug(
"SerialSource(%(port)s): Unpacked: %(data)s, extra information: %(extra)s",
{"port": self.config.port, "data": data, "extra": extra},
)
return (data, extra)
async def _cancel_read(self) -> None:
if not hasattr(self, "_ser"):
return
logger.warning(
"SerialSource(%(port)s).read() cancelled, waiting for %(timeout)s",
{"port": self.config.port, "timeout": self._cancel_read_timeout},
)
await asyncio.wait_for(
asyncio.get_running_loop().run_in_executor(self._executor, self._ser.cancel_read),
self._cancel_read_timeout,
)
async def start(self) -> None:
"""Open serial device."""
if hasattr(self, "_ser"):
return
self._ser = serial.Serial(
port=self.config.port,
baudrate=self.config.baudrate,
bytesize=self.config.bytesize,
parity=self.config.parity,
stopbits=self.config.stopbits,
)
self._is_connected = True
async def stop(self) -> None:
"""Close serial connection."""
await self._cancel_read()
self._ser.close()
def __repr__(self) -> str:
"""Return string representation of SerialReceiver."""
return f"{self.__class__.__name__}({self.config.port}, baudrate={self.config.baudrate})"

View File

@@ -0,0 +1,91 @@
"""Asynchronous sink implementation for sending data via serial port."""
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, TypeVar
import serial # type: ignore[import-untyped]
from heisskleber.core import Packer, Sender
from .config import SerialConf
T = TypeVar("T")
class SerialSender(Sender[T]):
"""An asynchronous sink for writing data to a serial port.
This class implements the Sender interface for writing data to a serial port.
It uses a thread pool executor to perform blocking I/O operations asynchronously.
Attributes:
config: Configuration for the serial port.
packer: Function to pack data for sending.
"""
def __init__(self, config: SerialConf, pack: Packer[T]) -> None:
"""SerialSink constructor."""
self.config = config
self.packer = pack
self._loop = asyncio.get_running_loop()
self._executor = ThreadPoolExecutor(max_workers=2)
self._lock = asyncio.Lock()
self._is_connected = False
self._cancel_write_timeout = 1
async def send(self, data: T, **kwargs: dict[str, Any]) -> None:
"""Send data to the serial port.
This method packs the data, writes it to the serial port, and then flushes the port.
Arguments:
data: The data to be sent.
**kwargs: Not implemented.
Raises:
PackerError: If data could not be packed to bytes with the provided packer.
Note:
If the serial port is not connected, it will implicitly attempt to connect first.
"""
if not self._is_connected:
await self.start()
payload = self.packer(data)
try:
await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.write, payload)
await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.flush)
except asyncio.CancelledError:
await asyncio.shield(self._cancel_write())
raise
async def _cancel_write(self) -> None:
if not hasattr(self, "_ser"):
return
await asyncio.wait_for(
asyncio.get_running_loop().run_in_executor(self._executor, self._ser.cancel_write),
self._cancel_write_timeout,
)
async def start(self) -> None:
"""Open serial connection."""
if hasattr(self, "_ser"):
return
self._ser = serial.Serial(
port=self.config.port,
baudrate=self.config.baudrate,
bytesize=self.config.bytesize,
parity=self.config.parity,
stopbits=self.config.stopbits,
)
self._is_connected = True
async def stop(self) -> None:
"""Close serial connection."""
self._ser.close()
self._is_connected = False
def __repr__(self) -> str:
"""Return string representation of SerialSink."""
return f"SerialSink({self.config.port}, baudrate={self.config.baudrate})"

View File

@@ -0,0 +1,5 @@
from .config import TcpConf
from .receiver import TcpReceiver
from .sender import TcpSender
__all__ = ["TcpReceiver", "TcpSender", "TcpConf"]

View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass
from enum import Enum
from heisskleber.core import BaseConf
@dataclass
class TcpConf(BaseConf):
"""Configuration dataclass for TCP connections."""
class RestartBehavior(Enum):
"""The three types of restart behaviour."""
NEVER = 0 # Never restart on failure
ONCE = 1 # Restart once
ALWAYS = 2 # Restart until the connection succeeds
host: str = "localhost"
port: int = 6000
timeout: int = 60
retry_delay: float = 0.5
restart_behavior: RestartBehavior = RestartBehavior.ALWAYS
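# Illustrative usage sketch (not part of the original file): connect to a
# hypothetical instrument server and retry the connection only once on failure.
conf = TcpConf(host="192.168.0.10", port=5000, restart_behavior=TcpConf.RestartBehavior.ONCE)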

View File

@@ -0,0 +1,107 @@
"""Async TCP Source - get data from arbitrary TCP server."""
import asyncio
import logging
from typing import Any, TypeVar
from heisskleber.core import Receiver, Unpacker, json_unpacker
from heisskleber.tcp.config import TcpConf
T = TypeVar("T")
logger = logging.getLogger("heisskleber.tcp")
class TcpReceiver(Receiver[T]):
"""Async TCP connection, connects to host:port and reads byte encoded strings."""
def __init__(self, config: TcpConf, unpacker: Unpacker[T] = json_unpacker) -> None: # type: ignore [assignment]
self.config = config
self.unpack = unpacker
self.is_connected = False
self.timeout = config.timeout
self._start_task: asyncio.Task[None] | None = None
self.reader: asyncio.StreamReader | None = None
self.writer: asyncio.StreamWriter | None = None
async def receive(self) -> tuple[T, dict[str, Any]]:
"""Receive data from a connection.
Attempt to read data from the connection and handle the process of re-establishing the connection if necessary.
Returns:
tuple[T, dict[str, Any]]
- The unpacked message data
- A dictionary with metadata including the message topic
Raises:
TypeError: If the message payload is not of type bytes.
UnpackError: If the message could not be unpacked with the unpacker protocol.
"""
data = b""
retry_delay = self.config.retry_delay
while not data:
await self._ensure_connected()
data = await self.reader.readline() # type: ignore [union-attr]
if not data:
self.is_connected = False
logger.warning(
"%(self)s nothing received, retrying connect in %(seconds)s",
{"self": self, "seconds": retry_delay},
)
await asyncio.sleep(retry_delay)
retry_delay = min(self.config.timeout, retry_delay * 2)
payload, extra = self.unpack(data)
return payload, extra
async def start(self) -> None:
"""Start TcpSource."""
await self._connect()
async def stop(self) -> None:
    """Stop TcpReceiver and close the connection."""
    if self.is_connected:
        logger.info("%(self)s stopping", {"self": self})
    if self.writer is not None:
        self.writer.close()
    self.is_connected = False
async def _ensure_connected(self) -> None:
if self.is_connected:
return
# Not connected, try to (re-)connect
if not self._start_task:
# Possibly multiple reconnects, so can't just await once
self._start_task = asyncio.create_task(self._connect())
try:
await self._start_task
finally:
self._start_task = None
async def _connect(self) -> None:
logger.info("%(self)s waiting for connection.", {"self": self})
num_attempts = 0
while True:
try:
self.reader, self.writer = await asyncio.wait_for(
asyncio.open_connection(self.config.host, self.config.port),
timeout=self.timeout,
)
logger.info("%(self)s connected successfully!", {"self": self})
break
except ConnectionRefusedError as e:
logger.exception("%(self)s: %(error_type)s", {"self": self, "error_type": type(e).__name__})
if self.config.restart_behavior == TcpConf.RestartBehavior.NEVER:
raise
num_attempts += 1
if self.config.restart_behavior == TcpConf.RestartBehavior.ONCE and num_attempts > 1:
raise
# otherwise retry indefinitely
self.is_connected = True
def __repr__(self) -> str:
"""Return string representation of TcpSource."""
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
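# Illustrative usage sketch (not part of the original file): stream
# newline-delimited JSON from a TCP server; the address is hypothetical.
async def stream() -> None:
    config = TcpConf(host="192.168.0.10", port=5000)
    async with TcpReceiver(config) as receiver:
        async for data, extra in receiver:
            print(data)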

View File

@@ -0,0 +1,41 @@
from typing import Any, TypeVar
from heisskleber.core import Sender
from heisskleber.core.packer import Packer
from .config import TcpConf
T = TypeVar("T")
class TcpSender(Sender[T]):
"""Async TCP Sink.
Attributes:
config: The TcpConf configuration object.
packer: The packer protocol to serialize data before sending.
"""
def __init__(self, config: TcpConf, packer: Packer[T]) -> None:
self.config = config
self.packer = packer
async def send(self, data: T, **kwargs: dict[str, Any]) -> None:
"""Send data via tcp connection.
Arguments:
data: The data to be sent.
kwargs: Not implemented.
"""
def __repr__(self) -> str:
"""Return string representation of TcpSink."""
return f"TcpSink({self.config.host}:{self.config.port})"
async def start(self) -> None:
"""Start TcpSink."""
async def stop(self) -> None:
"""Stop TcpSink."""

View File

@@ -0,0 +1,5 @@
from .config import UdpConf
from .receiver import UdpReceiver
from .sender import UdpSender
__all__ = ["UdpReceiver", "UdpSender", "UdpConf"]

View File

@@ -1,17 +1,14 @@
from dataclasses import dataclass
from heisskleber.config import BaseConf
from heisskleber.core import BaseConf
@dataclass
class UdpConf(BaseConf):
"""
UDP configuration.
"""
"""UDP configuration."""
port: int = 1234
host: str = "127.0.0.1"
packer: str = "json"
max_queue_size: int = 1000
encoding: str = "utf-8"
delimiter: str = "\r\n"

View File

@@ -0,0 +1,86 @@
import asyncio
import logging
from typing import Any, TypeVar
from heisskleber.core import Receiver, Unpacker, json_unpacker
from heisskleber.udp.config import UdpConf
logger = logging.getLogger("heisskleber.udp")
T = TypeVar("T")
class UdpProtocol(asyncio.DatagramProtocol):
"""Protocol for udp connection.
Arguments:
queue: The asyncio.Queue to put messages into.
"""
def __init__(self, queue: asyncio.Queue[bytes]) -> None:
super().__init__()
self.queue = queue
def datagram_received(self, data: bytes, addr: tuple[str | Any, int]) -> None:
"""Handle received udp message."""
self.queue.put_nowait(data)
def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore[override]
"""Log successful connection."""
logger.info("UdpSource: Connection made")
class UdpReceiver(Receiver[T]):
"""An asynchronous UDP subscriber based on asyncio.protocols.DatagramProtocol."""
def __init__(self, config: UdpConf, unpacker: Unpacker[T] = json_unpacker) -> None: # type: ignore[assignment]
self.config = config
self.EOF = self.config.delimiter.encode(self.config.encoding)
self.unpacker = unpacker
self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.config.max_queue_size)
self._task: asyncio.Task[None] | None = None
self._is_connected = False
self._transport: asyncio.DatagramTransport | None = None
self._protocol: asyncio.DatagramProtocol | None = None
async def start(self) -> None:
"""Start udp connection."""
loop = asyncio.get_running_loop()
self._transport, self._protocol = await loop.create_datagram_endpoint(
lambda: UdpProtocol(self._queue),
local_addr=(self.config.host, self.config.port),
)
self._is_connected = True
logger.info("Udp connection established.")
async def stop(self) -> None:
"""Stop the udp connection."""
if self._transport is not None:
self._transport.close()
self._transport = None
self._is_connected = False
async def receive(self) -> tuple[T, dict[str, Any]]:
"""Get the next message from the udp connection.
Returns:
tuple[T, dict[str, Any]]
- The data as returned by the unpacker.
- A dictionary containing extra information.
Raises:
UnpackError: If the received message could not be unpacked.
"""
if not self._is_connected:
await self.start()
data = await self._queue.get()
payload, extra = self.unpacker(data)
return (payload, extra)
def __repr__(self) -> str:
"""Return string representation of UdpSource."""
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
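# Illustrative usage sketch (not part of the original file): listen for JSON
# datagrams on a local port; values are hypothetical.
async def listen() -> None:
    config = UdpConf(host="0.0.0.0", port=1234)
    async with UdpReceiver(config) as receiver:
        data, extra = await receiver.receive()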

View File

@@ -0,0 +1,88 @@
import asyncio
import logging
from typing import Any, TypeVar
from heisskleber.core import Packer, Sender, json_packer
from heisskleber.udp.config import UdpConf
logger = logging.getLogger("heisskleber.udp")
T = TypeVar("T")
class UdpProtocol(asyncio.DatagramProtocol):
"""UDP protocol handler that tracks connection state.
Arguments:
is_connected: Flag tracking if protocol is connected
"""
def __init__(self, is_connected: bool) -> None:
super().__init__()
self.is_connected = is_connected
def connection_lost(self, exc: Exception | None) -> None:
"""Update state and log a lost connection."""
logger.info("UDP Connection lost")
self.is_connected = False
class UdpSender(Sender[T]):
"""UDP sink for sending data via UDP protocol.
Arguments:
config: UDP configuration parameters
packer: Function to serialize data, defaults to JSON packing
"""
def __init__(self, config: UdpConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment]
self.config = config
self.pack = packer
self.is_connected = False
self._transport: asyncio.DatagramTransport | None = None
self._protocol: UdpProtocol | None = None
async def start(self) -> None:
"""Connect the UdpSink."""
await self._ensure_connection()
async def stop(self) -> None:
"""Disconnect the UdpSink connection."""
if self._transport is not None:
self._transport.close()
self.is_connected = False
self._transport = None
self._protocol = None
async def _ensure_connection(self) -> None:
"""Create UDP endpoint if not connected.
Creates datagram endpoint using protocol handler if no connection exists.
Updates connected state on successful connection.
"""
if not self.is_connected or self._transport is None:
loop = asyncio.get_running_loop()
self._transport, _ = await loop.create_datagram_endpoint(
lambda: UdpProtocol(self.is_connected),
remote_addr=(self.config.host, self.config.port),
)
self.is_connected = True
async def send(self, data: T, **kwargs: dict[str, Any]) -> None:
"""Send data over UDP connection.
Arguments:
data: Data to send
**kwargs: Additional arguments passed to send
"""
await self._ensure_connection()  # guarantees that self._transport is initialized
payload = self.pack(data)
self._transport.sendto(payload) # type: ignore [union-attr]
def __repr__(self) -> str:
"""Return string representation of UdpSink."""
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
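# Illustrative usage sketch (not part of the original file): fire-and-forget a
# JSON datagram; values are hypothetical.
async def emit() -> None:
    config = UdpConf(host="127.0.0.1", port=1234)
    async with UdpSender(config) as sender:
        await sender.send({"lat": 54.3, "lon": 10.1})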

View File

@@ -0,0 +1,5 @@
from .config import ZmqConf
from .receiver import ZmqReceiver
from .sender import ZmqSender
__all__ = ["ZmqConf", "ZmqSender", "ZmqReceiver"]

View File

@@ -1,10 +1,12 @@
from dataclasses import dataclass
from heisskleber.config import BaseConf
from heisskleber.core import BaseConf
@dataclass
class ZmqConf(BaseConf):
"""ZMQ Configuration file."""
protocol: str = "tcp"
host: str = "127.0.0.1"
publisher_port: int = 5555
@@ -13,8 +15,10 @@ class ZmqConf(BaseConf):
@property
def publisher_address(self) -> str:
"""Return the full url to connect to the publisher port."""
return f"{self.protocol}://{self.host}:{self.publisher_port}"
@property
def subscriber_address(self) -> str:
"""Return the full url to connect to the subscriber port."""
return f"{self.protocol}://{self.host}:{self.subscriber_port}"

View File

@@ -0,0 +1,85 @@
import logging
from typing import Any, TypeVar
import zmq
import zmq.asyncio
from heisskleber.core import Receiver, Unpacker, json_unpacker
from heisskleber.zmq.config import ZmqConf
logger = logging.getLogger("heisskleber.zmq")
T = TypeVar("T")
class ZmqReceiver(Receiver[T]):
"""Async source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function.
Attributes:
config: The ZmqConf configuration object for the connection.
unpacker : The unpacker function to use for deserializing the data.
"""
def __init__(self, config: ZmqConf, topic: str | list[str], unpacker: Unpacker[T] = json_unpacker) -> None: # type: ignore [assignment]
self.config = config
self.topic = topic
self.context = zmq.asyncio.Context.instance()
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB)
self.unpack = unpacker
self.is_connected = False
async def receive(self) -> tuple[T, dict[str, Any]]:
"""Read a message from the zmq bus and return it.
Returns:
tuple[T, dict[str, Any]]: The unpacked data and a metadata dictionary including the message topic.
Raises:
UnpackError: If payload could not be unpacked with provided unpacker.
"""
if not self.is_connected:
await self.start()
(topic, payload) = await self.socket.recv_multipart()
data, extra = self.unpack(payload)
extra["topic"] = topic.decode()
return data, extra
async def start(self) -> None:
"""Connect to the zmq socket."""
try:
self.socket.connect(self.config.subscriber_address)
except Exception:
logger.exception("Failed to bind to zeromq socket")
else:
self.is_connected = True
self.subscribe(self.topic)
async def stop(self) -> None:
"""Close the zmq socket."""
self.socket.close()
self.is_connected = False
def subscribe(self, topic: str | list[str] | tuple[str]) -> None:
"""Subscribe to the given topic(s) on the zmq socket.
Arguments:
---------
topic: The topic or list of topics to subscribe to.
"""
if isinstance(topic, (list, tuple)):
for t in topic:
self._subscribe_single_topic(t)
else:
self._subscribe_single_topic(topic)
def _subscribe_single_topic(self, topic: str) -> None:
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
def __repr__(self) -> str:
"""Return string representation of ZmqSource."""
return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})"

View File

@@ -0,0 +1,54 @@
import logging
from typing import Any, TypeVar
import zmq
import zmq.asyncio
from heisskleber.core import Packer, Sender, json_packer
from .config import ZmqConf
logger = logging.getLogger("heisskleber.zmq")
T = TypeVar("T")
class ZmqSender(Sender[T]):
"""Async publisher that sends messages to a ZMQ PUB socket.
Attributes:
config: The ZmqConf configuration object for the connection.
packer : The packer strategy to use for serializing the data.
Defaults to json packer with utf-8 encoding.
"""
def __init__(self, config: ZmqConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment]
self.config = config
self.context = zmq.asyncio.Context.instance()
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB)
self.packer = packer
self.is_connected = False
async def send(self, data: T, topic: str = "zmq", **kwargs: Any) -> None:
"""Take the data as a dict, serialize it with the given packer and send it to the zmq socket."""
if not self.is_connected:
await self.start()
payload = self.packer(data)
logger.debug("sending payload %(payload)b to topic %(topic)s", {"payload": payload, "topic": topic})
await self.socket.send_multipart([topic.encode(), payload])
async def start(self) -> None:
"""Connect to the zmq socket."""
logger.info("Connecting to %(addr)s", {"addr": self.config.publisher_address})
self.socket.connect(self.config.publisher_address)
self.is_connected = True
async def stop(self) -> None:
"""Close the zmq socket."""
self.socket.close()
self.is_connected = False
def __repr__(self) -> str:
"""Return string representation of ZmqSink."""
return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})"

View File

@@ -1,2 +0,0 @@
verbose: True
print_stdout: False

Some files were not shown because too many files have changed in this diff.