mirror of
https://github.com/OMGeeky/flucto-heisskleber.git
synced 2025-12-26 16:07:50 +01:00
Refactor heisskleber core, remove synchronous implementations (#156)
* #129 AsyncTcpSource enhancements - retry connection on startup (behavior is configurable) - reconnect if data receiving fails (EOF received) - add Python logging - add unit tests * remove syncronous implementations. * WIP: Refactor packer/unpacker * Refactor type hints and topic handling in console sink. * Remove comma from tcp config enum definitions * Remove references to deleted synchronous classes. * Hopefully stable interface for Packer and Unpacker. * WIP: Working with protocols and generics * Finalized Sink, Source definition. * Rename mqtt source and sink files * Rename mqtt publisher and subscriber. * Fix start function to async. * Update documentation. * Remove recursion from udp source. * rename unpack to unpacker, stay consistent. * Renaming in tests. * Make MqttSource generic. * Configure pyproject.toml to move to uv * Add nox support. * Update documentation with myst-parser and sphinx. * Mess with autogeneration of __call__ signatures. * Add dynamic versioning to hatch * Asyncio wrapper for pyserial. * Add docstrings for serial sink and source. * Refactor config handling (#171) * Removes deprecated "verbose" and "print_std" parameters * Adds class methods for config generation from dictionary or file (yaml or json at this point) * Run-time type checking via __post_init__() function * Add serial dependency. * WIP * Move broker to bin/ * Update docs. * WIP: Need to update docstrings to make ruff happy. * Move source files to src/ * Fix tests for TcpSource. * WIP: Remove old tests. * Fix docstrings in mqtt classes. * Make default tcp unpacker json_unpacker. * No failed tests if there are no tests * Update test pipeline * Update ruff pre-commit * Updated ruff formatting * Format bin/ * Fix type hints * No type checking * Make stop() async * Only test on ubuntu for now * Don't be so strict about sphinx warnings. * Rename TestConf for pytest naming compability. * Install package in editable mode for ci tests. * Update dependencies for docs generation. * Add keepalive and will to mqtt, fixes #112. * Update readme to reflect changes in usage. * Requested fixes for console adapters. * Raise correct errors in unpacker and packer. * Correct logger name for mqtt sink. * Add config options for stopbits and parity to Serial. * Remove exception logging call from yaml parser. * Add comments to clear up very implicit test. * Rename Sink -> Sender, Source -> Receiver. * Rename sink and source in tests. * Fix tests. --------- Co-authored-by: Adrian Weiler <a.weiler@aldea.de>
This commit is contained in:
3
.git_archival.txt
Normal file
3
.git_archival.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
node: $Format:%H$
|
||||
node-date: $Format:%cI$
|
||||
describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$
|
||||
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1 +1,2 @@
|
||||
* text=auto eol=lf
|
||||
.git_archival.txt export-subst
|
||||
|
||||
93
.github/workflows/tests.yml
vendored
93
.github/workflows/tests.yml
vendored
@@ -1,51 +1,82 @@
|
||||
name: Tests
|
||||
name: CI
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
FORCE_COLOR: 3
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
pre-commit:
|
||||
name: Format
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out the repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5.0.0
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Cache poetry install
|
||||
uses: actions/cache@v4
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
path: ~/.local
|
||||
key: poetry-1.7.1-0
|
||||
|
||||
- name: Install poetry
|
||||
uses: snok/install-poetry@v1
|
||||
python-version: "3.x"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
version: 1.7.1
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
extra_args: --hook-stage manual --all-files
|
||||
- name: Run Ruff
|
||||
run: |
|
||||
pip install ruff
|
||||
ruff check .
|
||||
|
||||
- name: cache deps
|
||||
id: cache-deps
|
||||
uses: actions/cache@v4
|
||||
tests:
|
||||
name: Python ${{ matrix.python-version }} on ${{ matrix.runs-on }}
|
||||
runs-on: ${{ matrix.runs-on }}
|
||||
needs: [pre-commit]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
# python-version: ["3.11", "3.12", "3.13"]
|
||||
runs-on: [ubuntu-latest]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
path: .venv
|
||||
key: pydeps-${{ hashFiles('**/poetry.lock') }}
|
||||
fetch-depth: 0
|
||||
|
||||
- run: poetry install --no-interaction # --no-root
|
||||
if: steps.cache-deps.outputs.cache-hit != 'true'
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
allow-prereleases: true
|
||||
|
||||
# - run: poetry install --no-interaction
|
||||
- name: Install dependencies
|
||||
run: python -m pip install -e ".[test,docs]"
|
||||
|
||||
- run: poetry run pytest --cov --cov-report xml
|
||||
- name: Run tests
|
||||
run: python -m pytest -ra --cov --cov-report=xml --cov-report=term --durations=20
|
||||
|
||||
- name: Upload coverage roports to Codecov
|
||||
- name: Upload coverage
|
||||
uses: codecov/codecov-action@v4
|
||||
env:
|
||||
CODECOV_TOKEN: ${{secrets.CODECOV_TOKEN}}
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
fail_ci_if_error: true
|
||||
|
||||
docs:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [pre-commit]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Build docs
|
||||
run: |
|
||||
pip install ".[docs]"
|
||||
sphinx-build -b html docs docs/_build/html
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: "v4.4.0"
|
||||
rev: "v5.0.0"
|
||||
hooks:
|
||||
- id: check-case-conflict
|
||||
- id: check-merge-conflict
|
||||
@@ -10,10 +10,8 @@ repos:
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.1.5
|
||||
rev: "v0.7.4"
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
# Run the formatter.
|
||||
args: ["--fix", "--show-fixes"]
|
||||
- id: ruff-format
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
# Read the Docs configuration file
|
||||
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
|
||||
|
||||
version: 2
|
||||
|
||||
build:
|
||||
os: ubuntu-20.04
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.10"
|
||||
sphinx:
|
||||
configuration: docs/conf.py
|
||||
formats: all
|
||||
python:
|
||||
install:
|
||||
- requirements: docs/requirements.txt
|
||||
- path: .
|
||||
python: "3.12"
|
||||
commands:
|
||||
- asdf plugin add uv
|
||||
- asdf install uv latest
|
||||
- asdf global uv latest
|
||||
- uv venv
|
||||
- uv pip install .[docs]
|
||||
- .venv/bin/python -m sphinx -T -b html -d docs/_build/doctrees -D
|
||||
language=en docs $READTHEDOCS_OUTPUT/html
|
||||
|
||||
50
README.md
50
README.md
@@ -23,9 +23,8 @@ Heisskleber is a versatile library designed to seamlessly "glue" together variou
|
||||
|
||||
## Features
|
||||
|
||||
- Multiple Protocol Support: Easy integration with zmq, mqtt, udp, serial, influxdb, and cmdline. Future plans include REST API and file operations.
|
||||
- Custom Data Handling: Customizable "unpack" and "pack" functions allow for the translation of any data format (e.g., ascii encoded, comma-separated messages from a serial bus) into dictionaries for easy manipulation and transmission.
|
||||
- Synchronous & Asynchronous Versions: Cater to different programming needs and scenarios with both sync and async interfaces.
|
||||
- Multiple Protocol Support: Easy integration with zmq, mqtt, udp, serial, and cmdline. Future plans include REST API and file operations.
|
||||
- Custom Data Handling: Customizable "unpacker" and "packer" functions allow for the translation of any data format (e.g., ascii encoded, comma-separated messages from a serial bus) into dictionaries for easy manipulation and transmission.
|
||||
- Extensible: Designed for easy extension with additional protocols and data handling functions.
|
||||
|
||||
## Installation
|
||||
@@ -36,56 +35,35 @@ You can install _Heisskleber_ via [pip] from [PyPI]:
|
||||
$ pip install heisskleber
|
||||
```
|
||||
|
||||
Configuration files for zmq, mqtt and other heisskleber related settings should be placed in the user's config directory, usually `$HOME/.config`. Config file templates can be found in the `config`
|
||||
directory of the package.
|
||||
|
||||
## Quick Start
|
||||
|
||||
Here's a simple example to demonstrate how Heisskleber can be used to connect a zmq source to an mqtt sink:
|
||||
|
||||
```python
|
||||
"""
|
||||
A simple forwarder that takes messages from
|
||||
A simple forwarder that takes messages from a serial device and publishes them via MQTT.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from heisskleber.serial import SerialSubscriber, SerialConf
|
||||
from heisskleber.mqtt import MqttPublisher, MqttConf
|
||||
|
||||
source = SerialSubscriber(config=SerialConf(port="/dev/ACM0"))
|
||||
sink = MqttPublisher(config=MqttConf(host="127.0.0.1", port=1883, user="", password=""))
|
||||
|
||||
while True:
|
||||
topic, data = source.receive()
|
||||
sink.send(data, topic="/hostname/" + topic)
|
||||
async def main():
|
||||
source = SerialSubscriber(config=SerialConf(port="/dev/ACM0", baudrate=9600))
|
||||
sink = MqttPublisher(config=MqttConf(host="mqtt.example.com", port=1883, user="", password=""))
|
||||
|
||||
while True:
|
||||
data, metadata = await source.receive()
|
||||
await sink.send(data, topic="/hotglue/" + metadata.get("topic", "serial"))
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
All sources and sinks come with customizable "unpack" and "pack" functions, making it simple to work with various data formats.
|
||||
|
||||
It is also possible to do configuration via yaml files, placed at `$HOME/.config/heisskleber` and named according to the protocol in question.
|
||||
All sources and sinks come with customizable "unpacker" and "packer" functions, making it simple to work with various data formats.
|
||||
|
||||
See the [documentation][read the docs] for detailed usage.
|
||||
|
||||
## Development
|
||||
|
||||
1. Install poetry
|
||||
|
||||
```
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
```
|
||||
|
||||
2. clone repository
|
||||
|
||||
```
|
||||
git clone https://github.com/flucto-gmbh/heisskleber.git
|
||||
cd heisskleber
|
||||
```
|
||||
|
||||
3. setup
|
||||
|
||||
```
|
||||
make install
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Distributed under the terms of the [MIT license][license],
|
||||
|
||||
52
bin/zmq_broker.py
Normal file
52
bin/zmq_broker.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "pyzmq",
|
||||
# "heisskleber"
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import zmq
|
||||
|
||||
from heisskleber.zmq import ZmqConf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - ZmqBroker - %(levelname)s - %(message)s")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run ZMQ broker."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--config", type=str, required=True, help="ZMQ configuration file (yaml or json)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = ZmqConf.from_file(args.config)
|
||||
|
||||
try:
|
||||
ctx = zmq.Context()
|
||||
|
||||
logger.info("Creating XPUB socket")
|
||||
xpub = ctx.socket(zmq.XPUB)
|
||||
logger.info("Creating XSUB socket")
|
||||
xsub = ctx.socket(zmq.XSUB)
|
||||
|
||||
logger.info("Connecting XPUB socket to %(addr)s", {"addr": config.subscriber_address})
|
||||
xpub.bind(config.subscriber_address)
|
||||
|
||||
logger.info("Connecting XSUB socket to %(addr)s", {"addr": config.publisher_address})
|
||||
xsub.bind(config.publisher_address)
|
||||
|
||||
logger.info("Starting proxy...")
|
||||
zmq.proxy(xpub, xsub)
|
||||
except Exception:
|
||||
logger.exception("Oh no! ZMQ broker failed!")
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
76
docs/conf.py
76
docs/conf.py
@@ -1,12 +1,74 @@
|
||||
"""Sphinx configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata
|
||||
from typing import Any
|
||||
|
||||
project = "Heisskleber"
|
||||
author = "Felix Weiler"
|
||||
copyright = "2023, Felix Weiler"
|
||||
author = "Felix Weiler-Detjen"
|
||||
copyright = "2023, Flucto GmbH"
|
||||
|
||||
version = release = importlib.metadata.version("heisskleber")
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.napoleon",
|
||||
"myst_parser",
|
||||
] # , "autodoc2"
|
||||
# autodoc2_packages = ["../heisskleber"]
|
||||
autodoc_typehints = "description"
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinx.ext.mathjax",
|
||||
"sphinx.ext.napoleon",
|
||||
"sphinx_autodoc_typehints",
|
||||
"sphinx_copybutton",
|
||||
]
|
||||
|
||||
autodoc_typehints = "description" # or 'signature' or 'both'
|
||||
|
||||
autodoc_type_aliases = {
|
||||
"T": "heisskleber.core.T",
|
||||
"T_co": "heisskleber.core.T_co",
|
||||
"T_contra": "heisskleber.core.T_contra",
|
||||
}
|
||||
|
||||
# If you're using typing.TypeVar in your code:
|
||||
nitpicky = True
|
||||
nitpick_ignore = [
|
||||
("py:class", "T"),
|
||||
("py:class", "T_co"),
|
||||
("py:class", "T_contra"),
|
||||
("py:data", "typing.Any"),
|
||||
("py:class", "_io.StringIO"),
|
||||
("py:class", "_io.BytesIO"),
|
||||
]
|
||||
|
||||
source_suffix = [".rst", ".md"]
|
||||
|
||||
exclude_patterns = [
|
||||
"_build",
|
||||
"**.ipynb_checkpoints",
|
||||
"Thumbs.db",
|
||||
".DS_Store",
|
||||
".env",
|
||||
".venv",
|
||||
]
|
||||
|
||||
html_theme = "furo"
|
||||
|
||||
html_theme_options: dict[str, Any] = {
|
||||
"footer_icons": [
|
||||
{
|
||||
"name": "GitHub",
|
||||
"url": "https://github.com/flucto-gmbh/heisskleber",
|
||||
"html": """
|
||||
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
||||
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
||||
</svg>
|
||||
""",
|
||||
"class": "",
|
||||
},
|
||||
],
|
||||
"source_repository": "https://github.com/flucto-gmbh/heisskleber",
|
||||
"source_branch": "main",
|
||||
"source_directory": "docs/",
|
||||
}
|
||||
|
||||
always_document_param_types = True
|
||||
|
||||
7
docs/development.md
Normal file
7
docs/development.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# How to contribute to development
|
||||
|
||||
1. Fork repository
|
||||
2. Set up development environment
|
||||
|
||||
- Install uv, if you don't have it already: `curl -LsSf https://astral.sh/uv/install.sh | sh` (Or install from package manager, if applicable)
|
||||
- Set python version (`uv venv --python 3.10`) (or whatever version you would like, as long as it's 3.10+. Careful if you use pyenv.)
|
||||
@@ -5,6 +5,7 @@ end-before: <!-- github-only -->
|
||||
```
|
||||
|
||||
[license]: license
|
||||
[serializing]: packer_and_unpacker
|
||||
|
||||
```{toctree}
|
||||
---
|
||||
@@ -14,6 +15,8 @@ maxdepth: 1
|
||||
|
||||
yaml-config
|
||||
reference
|
||||
serialization
|
||||
development
|
||||
License <license>
|
||||
Changelog <https://github.com/flucto-gmbh/heisskleber/releases>
|
||||
```
|
||||
|
||||
@@ -1,45 +1,115 @@
|
||||
# Reference
|
||||
|
||||
## Network
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: heisskleber.mqtt
|
||||
:members:
|
||||
.. automodule:: heisskleber.zmq
|
||||
:members:
|
||||
.. automodule:: heisskleber.serial
|
||||
:members:
|
||||
.. automodule:: heisskleber.udp
|
||||
:members:
|
||||
```
|
||||
|
||||
## Baseclasses
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: heisskleber.core.types
|
||||
.. autoclass:: heisskleber.core.AsyncSink
|
||||
:members:
|
||||
|
||||
.. autoclass:: heisskleber.core.AsyncSource
|
||||
:members:
|
||||
```
|
||||
|
||||
## Stream
|
||||
## Serialization
|
||||
|
||||
Work on streaming data.
|
||||
See <project:serialization.md> for a tutorial on how to implement custom packer and unpacker for (de-)serialization.
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: heisskleber.stream.filter
|
||||
:members: __aiter__
|
||||
.. autoclass:: heisskleber.core::Packer
|
||||
|
||||
.. automodule:: heisskleber.stream.butter
|
||||
:members:
|
||||
.. autoclass:: heisskleber.core::Unpacker
|
||||
|
||||
.. automodule:: heisskleber.stream.gh-filter
|
||||
:members:
|
||||
.. autoclass:: heisskleber.core.unpacker::JSONUnpacker
|
||||
|
||||
.. autoclass:: heisskleber.core.packer::JSONPacker
|
||||
```
|
||||
|
||||
## Config
|
||||
|
||||
### Loading configs
|
||||
### Errors
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: heisskleber.config
|
||||
:members:
|
||||
.. autoclass:: heisskleber.core::UnpackError
|
||||
|
||||
.. autoclass:: heisskleber.core::PackerError
|
||||
```
|
||||
|
||||
## Implementations (Adapters)
|
||||
|
||||
### MQTT
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: heisskleber.mqtt
|
||||
:no-members:
|
||||
|
||||
.. autoclass:: heisskleber.mqtt.MqttSink
|
||||
:members: send
|
||||
|
||||
.. autoclass:: heisskleber.mqtt.MqttSource
|
||||
:members: receive, subscribe
|
||||
|
||||
.. autoclass:: heisskleber.mqtt.MqttConf
|
||||
:members:
|
||||
```
|
||||
|
||||
### ZMQ
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.zmq::ZmqConf
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.zmq::ZmqSink
|
||||
:members: send
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.zmq::ZmqSource
|
||||
:members: receive
|
||||
```
|
||||
|
||||
### Serial
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.serial::SerialConf
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.serial::SerialSink
|
||||
:members: send
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.serial::SerialSource
|
||||
:members: receive
|
||||
```
|
||||
|
||||
### TCP
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.tcp::TcpConf
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.tcp::TcpSink
|
||||
:members: send
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.tcp::TcpSource
|
||||
:members: receive
|
||||
```
|
||||
|
||||
### UDP
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.udp::UdpConf
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.udp::UdpSink
|
||||
:members: send
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: heisskleber.udp::UdpSource
|
||||
:members: receive
|
||||
```
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
furo==2024.1.29
|
||||
sphinx==7.2.6
|
||||
myst_parser==2.0.0
|
||||
furo==2024.8.6
|
||||
sphinx==8.1.3
|
||||
myst_parser==4.0.0
|
||||
|
||||
72
docs/serialization.md
Normal file
72
docs/serialization.md
Normal file
@@ -0,0 +1,72 @@
|
||||
# Serialization
|
||||
|
||||
## Implementing a custom Packer
|
||||
|
||||
The packer class is defined in heisskleber.core.packer.py as a Protocol [see PEP 544](https://peps.python.org/pep-0544/).
|
||||
|
||||
```python
|
||||
T = TypeVar("T", contravariant=True)
|
||||
|
||||
class Packer(Protocol[T]):
|
||||
def __call__(self, data: T) -> bytes:
|
||||
pass
|
||||
```
|
||||
|
||||
Users can create custom Packer classes with variable input data, either as callable classes, subclasses of the packer class or functions.
|
||||
Please note, that to satisfy type checking engines, the argument must be named `data`, but being Python, it's obviously not enforced at runtime.
|
||||
The AsyncSink's type is defined by the concrete packer implementation. So if your Packer packs strings to bytes, the AsyncSink will be of type `AsyncSink[str]`,
|
||||
indicating that the send function takes strings only, see example below:
|
||||
|
||||
```python
|
||||
from heisskleber import MqttSink, MqttConf
|
||||
|
||||
def string_packer(data: str) -> bytes:
|
||||
return data.encode("ascii")
|
||||
|
||||
async def main():
|
||||
sink = MqttSink(MqttConf(), packer = string_packer)
|
||||
await sink.send("Hi there!") # This is fine
|
||||
await sink.send({"data": 3.14}) # Type checker will complain
|
||||
```
|
||||
|
||||
Heisskleber comes with default packers, such as the JSON_Packer, which can be importet as json_packer from heisskleber.core and is the default value for most Sinks.
|
||||
|
||||
## Implementing a custom Unpacker
|
||||
|
||||
The unpacker's responsibility is creating usable data from serialized byte strings.
|
||||
This may be a serialized json string which is unpacked into a dictionary, but could be anything the user defines.
|
||||
In heisskleber.core.unpacker.py the Unpacker Protocol is defined.
|
||||
|
||||
```python
|
||||
class Unpacker(Protocol[T]):
|
||||
def __call__(self, payload: bytes) -> tuple[T, dict[str, Any]]:
|
||||
pass
|
||||
```
|
||||
|
||||
Here, the payload is fixed to be of type bytes and the return type is a combination of a user-defined data type and a dictionary of meta-data.
|
||||
|
||||
```{eval-rst}
|
||||
.. note::
|
||||
Please Note: The extra dictionary may be updated by the Source, e.g. the MqttSource will add a "topic" field, received from the mqtt node.
|
||||
```
|
||||
|
||||
The receive function of an AsyncSource object will have its return type informed by the signature of the unpacker.
|
||||
|
||||
```python
|
||||
from heisskleber import MqttSource, MqttConf
|
||||
import time
|
||||
|
||||
def csv_unpacker(payload: bytes) -> tuple[list[str], dict[str, Any]]:
|
||||
# Unpack a utf-8 encoded csv string, such as b'1,42,3.14,100.0' to [1.0, 42.0, 3.14, 100.0]
|
||||
# Adds some exemplary meta data
|
||||
return [float(chunk) for chunk in payload.decode().split(",")], {"processed_at": time.time()}
|
||||
|
||||
async def main():
|
||||
sub = MqttSource(MqttConf, unpacker = csv_unpacker)
|
||||
data, extra = await sub.receive()
|
||||
assert isinstance(data, list[str]) # passes
|
||||
```
|
||||
|
||||
## Error handling
|
||||
|
||||
To be implemented...
|
||||
@@ -11,10 +11,6 @@ The configuration parameters are host, port, ssl, user and password to establish
|
||||
- 2: "At least once", where messages are assured to arrive but duplicates can occur.
|
||||
- 3: "Exactly once", where messages are assured to arrive exactly once.
|
||||
- **max_saved_messages**: maximum number of messages that will be saved in the buffer until connection is available.
|
||||
- **packstyle**: key of the serialization technique to use. Currently only JSON is supported.
|
||||
- **source_id**: id of the device that will be used to identify the MQTT messages to be used by clients to format the topic.
|
||||
Suggested topic format is in the form of `f"/{measurement_type}/{source_id}"`, eg. "/temperature/box-01".
|
||||
- **topics**: the topics that the mqtt forwarder will subscribe to.
|
||||
|
||||
```yaml
|
||||
# Heisskleber config file for MqttConf
|
||||
@@ -27,10 +23,4 @@ qos: 0 # quality of service, 0=at most once, 1=at least once, 2=exactly once
|
||||
timeout_s: 60
|
||||
retain: false # save last message
|
||||
max_saved_messages: 100 # buffer messages in until connection available
|
||||
packstyle: json
|
||||
|
||||
# configs only valid for mqtt forwarder
|
||||
mapping: /deprecated/
|
||||
source_id: box-01
|
||||
topics: ["topic1", "topic2"]
|
||||
```
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Heisskleber."""
|
||||
|
||||
from .core.async_factories import get_async_sink, get_async_source
|
||||
from .core.factories import get_publisher, get_sink, get_source, get_subscriber
|
||||
from .core.types import AsyncSink, AsyncSource, Sink, Source
|
||||
|
||||
__all__ = [
|
||||
"get_source",
|
||||
"get_sink",
|
||||
"get_publisher",
|
||||
"get_subscriber",
|
||||
"get_async_source",
|
||||
"get_async_sink",
|
||||
"Sink",
|
||||
"Source",
|
||||
"AsyncSink",
|
||||
"AsyncSource",
|
||||
]
|
||||
__version__ = "0.5.7"
|
||||
@@ -1,3 +0,0 @@
|
||||
from .zmq_broker import zmq_broker as start_zmq_broker
|
||||
|
||||
__all__ = ["start_zmq_broker"]
|
||||
@@ -1,60 +0,0 @@
|
||||
import sys
|
||||
|
||||
import zmq
|
||||
from zmq import Socket
|
||||
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.zmq.config import ZmqConf as BrokerConf
|
||||
|
||||
|
||||
class BrokerBindingError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def bind_socket(socket: Socket, address: str, socket_type: str, verbose=False) -> None:
|
||||
"""Bind a ZMQ socket and handle errors."""
|
||||
if verbose:
|
||||
print(f"creating {socket_type} socket")
|
||||
try:
|
||||
socket.bind(address)
|
||||
except Exception as err:
|
||||
error_message = f"failed to bind to {socket_type}: {err}"
|
||||
raise BrokerBindingError(error_message) from err
|
||||
if verbose:
|
||||
print(f"successfully bound to {socket_type} socket: {address}")
|
||||
|
||||
|
||||
def create_proxy(xpub: Socket, xsub: Socket, verbose=False) -> None:
|
||||
"""Create a ZMQ proxy to connect XPUB and XSUB sockets."""
|
||||
if verbose:
|
||||
print("creating proxy")
|
||||
try:
|
||||
zmq.proxy(xpub, xsub)
|
||||
except Exception as err:
|
||||
error_message = f"failed to create proxy: {err}"
|
||||
raise BrokerBindingError(error_message) from err
|
||||
|
||||
|
||||
# TODO reimplement as object?
|
||||
def zmq_broker(config: BrokerConf) -> None:
|
||||
"""Start a zmq broker.
|
||||
|
||||
Binds to a publisher and subscriber port, allowing many to many connections."""
|
||||
ctx = zmq.Context()
|
||||
|
||||
xpub = ctx.socket(zmq.XPUB)
|
||||
xsub = ctx.socket(zmq.XSUB)
|
||||
|
||||
try:
|
||||
bind_socket(xpub, config.subscriber_address, "publisher", config.verbose)
|
||||
bind_socket(xsub, config.publisher_address, "subscriber", config.verbose)
|
||||
create_proxy(xpub, xsub, config.verbose)
|
||||
except BrokerBindingError as e:
|
||||
print(e)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Start a zmq broker, with a user specified configuration."""
|
||||
broker_config = load_config(BrokerConf(), "zmq")
|
||||
zmq_broker(broker_config)
|
||||
@@ -1,4 +0,0 @@
|
||||
from .config import BaseConf, Config
|
||||
from .parse import load_config
|
||||
|
||||
__all__ = ["load_config", "BaseConf", "Config"]
|
||||
@@ -1,47 +0,0 @@
|
||||
import argparse
|
||||
|
||||
|
||||
class KeyValue(argparse.Action):
|
||||
def __call__(self, parser, args, values, option_string=None) -> None:
|
||||
try:
|
||||
params = dict(x.split("=") for x in values)
|
||||
except ValueError as ex:
|
||||
raise argparse.ArgumentError(
|
||||
self,
|
||||
f'Could not parse argument "{values}" as k1=v1 k2=v2 ... format: {ex}',
|
||||
) from ex
|
||||
setattr(args, self.dest, params)
|
||||
|
||||
|
||||
def get_cmdline(args=None) -> dict:
|
||||
"""
|
||||
get commandline arguments and return a dictionary of
|
||||
the provided arguments.
|
||||
|
||||
available commandline arguments are:
|
||||
--verbose: flag to toggle debugging output
|
||||
--print-stdout: flag to toggle all data printed to stdout
|
||||
--param key1=value1 key2=value2: allows to pass service specific
|
||||
parameters
|
||||
"""
|
||||
arp = argparse.ArgumentParser()
|
||||
arp.add_argument("--verbose", action="store_true", help="debug output flag")
|
||||
arp.add_argument(
|
||||
"--print-stdout",
|
||||
action="store_true",
|
||||
help="toggles output of all data to stdout",
|
||||
)
|
||||
arp.add_argument(
|
||||
"--params",
|
||||
nargs="*",
|
||||
action=KeyValue,
|
||||
)
|
||||
args = arp.parse_args(args)
|
||||
config = {}
|
||||
if args.verbose:
|
||||
config["verbose"] = args.verbose
|
||||
if args.print_stdout:
|
||||
config["print_stdout"] = args.print_stdout
|
||||
if args.params:
|
||||
config |= args.params
|
||||
return config
|
||||
@@ -1,35 +0,0 @@
|
||||
import socket
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, TypeVar
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseConf:
|
||||
"""
|
||||
default configuration class for generic configuration info
|
||||
"""
|
||||
|
||||
verbose: bool = False
|
||||
print_stdout: bool = False
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
if hasattr(self, key):
|
||||
self.__setattr__(key, value)
|
||||
else:
|
||||
warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
if hasattr(self, key):
|
||||
return getattr(self, key)
|
||||
else:
|
||||
warnings.warn(UserWarning(f"no such class member: {key}"), stacklevel=2)
|
||||
|
||||
@property
|
||||
def serial_number(self) -> str:
|
||||
return socket.gethostname().upper()
|
||||
|
||||
|
||||
Config = TypeVar(
|
||||
"Config", bound=BaseConf
|
||||
) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar
|
||||
@@ -1,75 +0,0 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from heisskleber.config.cmdline import get_cmdline
|
||||
from heisskleber.config.config import Config
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_config_dir() -> Path:
|
||||
config_dir = Path.home() / ".config" / "heisskleber"
|
||||
if not config_dir.is_dir():
|
||||
log.error(f"no such directory: {config_dir}", stacklevel=2)
|
||||
raise FileNotFoundError
|
||||
return config_dir
|
||||
|
||||
|
||||
def get_config_filepath(filename: str) -> Path:
|
||||
config_filepath = get_config_dir() / filename
|
||||
if not config_filepath.is_file():
|
||||
log.error(f"no such file: {config_filepath}", stacklevel=2)
|
||||
raise FileNotFoundError
|
||||
return config_filepath
|
||||
|
||||
|
||||
def read_yaml_config_file(config_fpath: Path) -> dict[str, Any]:
|
||||
with config_fpath.open() as config_filehandle:
|
||||
return yaml.safe_load(config_filehandle) # type: ignore [no-any-return]
|
||||
|
||||
|
||||
def update_config(config: Config, config_dict: dict[str, Any]) -> Config:
|
||||
for config_key, config_value in config_dict.items():
|
||||
if not hasattr(config, config_key):
|
||||
error_msg = f"no such configuration parameter: {config_key}, skipping"
|
||||
log.info(error_msg, stacklevel=2)
|
||||
continue
|
||||
cast_func = type(config[config_key])
|
||||
try:
|
||||
config[config_key] = cast_func(config_value)
|
||||
except Exception as e:
|
||||
log.warning(
|
||||
f"failed to cast {config_value} to {type(config[config_key])}: {e}. skipping",
|
||||
stacklevel=2,
|
||||
)
|
||||
continue
|
||||
return config
|
||||
|
||||
|
||||
def load_config(config: Config, config_filename: str, read_commandline: bool = True) -> Config:
|
||||
"""Load the config file and update the config object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : BaseConf
|
||||
The config object to fill with values.
|
||||
config_filename : str
|
||||
The name of the config file in $HOME/.config
|
||||
If the file does not have an extension the default extension .yaml is appended.
|
||||
read_commandline : bool
|
||||
Whether to read arguments from the command line. Optional. Defaults to True.
|
||||
"""
|
||||
config_filename = config_filename if "." in config_filename else config_filename + ".yaml"
|
||||
config_filepath = get_config_filepath(config_filename)
|
||||
config_dict = read_yaml_config_file(config_filepath)
|
||||
config = update_config(config, config_dict)
|
||||
|
||||
if not read_commandline:
|
||||
return config
|
||||
|
||||
config_dict = get_cmdline()
|
||||
config = update_config(config, config_dict)
|
||||
return config
|
||||
@@ -1,55 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from heisskleber.core.types import AsyncSink, Serializable, Sink
|
||||
|
||||
|
||||
class ConsoleSink(Sink):
|
||||
def __init__(self, pretty: bool = False, verbose: bool = False) -> None:
|
||||
self.verbose = verbose
|
||||
self.pretty = pretty
|
||||
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
verbose_topic = topic + ":\t" if self.verbose else ""
|
||||
if self.pretty:
|
||||
print(verbose_topic + json.dumps(data, indent=4))
|
||||
else:
|
||||
print(verbose_topic + str(data))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})"
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncConsoleSink(AsyncSink):
|
||||
def __init__(self, pretty: bool = False, verbose: bool = False) -> None:
|
||||
self.verbose = verbose
|
||||
self.pretty = pretty
|
||||
|
||||
async def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
verbose_topic = topic + ":\t" if self.verbose else ""
|
||||
if self.pretty:
|
||||
print(verbose_topic + json.dumps(data, indent=4))
|
||||
else:
|
||||
print(verbose_topic + str(data))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})"
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sink = ConsoleSink()
|
||||
while True:
|
||||
sink.send({"test": "test"}, "test")
|
||||
time.sleep(1)
|
||||
@@ -1,98 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from queue import SimpleQueue
|
||||
from threading import Thread
|
||||
|
||||
from heisskleber.core.types import AsyncSource, Serializable, Source
|
||||
|
||||
|
||||
class ConsoleSource(Source):
|
||||
def __init__(self, topic: str = "console") -> None:
|
||||
self.topic = topic
|
||||
self.queue = SimpleQueue()
|
||||
self.pack = json.loads
|
||||
self.thread: Thread | None = None
|
||||
|
||||
def listener_task(self):
|
||||
while True:
|
||||
try:
|
||||
data = sys.stdin.readline()
|
||||
payload = self.pack(data)
|
||||
self.queue.put(payload)
|
||||
except json.decoder.JSONDecodeError:
|
||||
print("Invalid JSON")
|
||||
continue
|
||||
except ValueError:
|
||||
break
|
||||
print("listener task finished")
|
||||
|
||||
def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
if not self.thread:
|
||||
self.start()
|
||||
|
||||
data = self.queue.get()
|
||||
return self.topic, data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(topic={self.topic})"
|
||||
|
||||
def start(self) -> None:
|
||||
self.thread = Thread(target=self.listener_task, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.thread:
|
||||
sys.stdin.close()
|
||||
self.thread.join()
|
||||
|
||||
|
||||
class AsyncConsoleSource(AsyncSource):
|
||||
def __init__(self, topic: str = "console") -> None:
|
||||
self.topic = topic
|
||||
self.queue: asyncio.Queue[dict[str, Serializable]] = asyncio.Queue(maxsize=10)
|
||||
self.pack = json.loads
|
||||
self.task: asyncio.Task[None] | None = None
|
||||
|
||||
async def listener_task(self):
|
||||
while True:
|
||||
data = sys.stdin.readline()
|
||||
payload = self.pack(data)
|
||||
await self.queue.put(payload)
|
||||
|
||||
async def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
if not self.task:
|
||||
self.start()
|
||||
|
||||
data = await self.queue.get()
|
||||
return self.topic, data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(topic={self.topic})"
|
||||
|
||||
def start(self) -> None:
|
||||
self.task = asyncio.create_task(self.listener_task())
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
console_source = ConsoleSource()
|
||||
console_source.start()
|
||||
|
||||
print("Listening to console input.")
|
||||
|
||||
count = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
print(console_source.receive())
|
||||
time.sleep(1)
|
||||
count += 1
|
||||
print(count)
|
||||
except KeyboardInterrupt:
|
||||
print("Stopped")
|
||||
sys.exit(0)
|
||||
@@ -1,59 +0,0 @@
|
||||
from heisskleber.config import BaseConf, load_config
|
||||
from heisskleber.mqtt import AsyncMqttPublisher, AsyncMqttSubscriber, MqttConf
|
||||
from heisskleber.udp import AsyncUdpSink, AsyncUdpSource, UdpConf
|
||||
from heisskleber.zmq import ZmqAsyncPublisher, ZmqAsyncSubscriber, ZmqConf
|
||||
|
||||
from .types import AsyncSink, AsyncSource
|
||||
|
||||
_registered_async_sinks: dict[str, tuple[type[AsyncSink], type[BaseConf]]] = {
|
||||
"mqtt": (AsyncMqttPublisher, MqttConf),
|
||||
"zmq": (ZmqAsyncPublisher, ZmqConf),
|
||||
"udp": (AsyncUdpSink, UdpConf),
|
||||
}
|
||||
|
||||
_registered_async_sources: dict[str, tuple] = {
|
||||
"mqtt": (AsyncMqttSubscriber, MqttConf),
|
||||
"zmq": (ZmqAsyncSubscriber, ZmqConf),
|
||||
"udp": (AsyncUdpSource, UdpConf),
|
||||
}
|
||||
|
||||
|
||||
def get_async_sink(name: str) -> AsyncSink:
|
||||
"""
|
||||
Factory function to create a sink object.
|
||||
|
||||
Parameters:
|
||||
name: Name of the sink to create.
|
||||
config: Configuration object to use for the sink.
|
||||
"""
|
||||
|
||||
if name not in _registered_async_sinks:
|
||||
error_message = f"{name} is not a registered Sink."
|
||||
raise KeyError(error_message)
|
||||
|
||||
pub_cls, conf_cls = _registered_async_sinks[name]
|
||||
|
||||
config = load_config(conf_cls(), name, read_commandline=False)
|
||||
|
||||
return pub_cls(config)
|
||||
|
||||
|
||||
def get_async_source(name: str, topic: str | list[str] | tuple[str]) -> AsyncSource:
|
||||
"""
|
||||
Factory function to create a source object.
|
||||
|
||||
Parameters:
|
||||
name: Name of the source to create.
|
||||
config: Configuration object to use for the source.
|
||||
topic: Topic to subscribe to.
|
||||
"""
|
||||
|
||||
if name not in _registered_async_sources:
|
||||
error_message = f"{name} is not a registered Source."
|
||||
raise KeyError(error_message)
|
||||
|
||||
sub_cls, conf_cls = _registered_async_sources[name]
|
||||
|
||||
config = load_config(conf_cls(), name, read_commandline=False)
|
||||
|
||||
return sub_cls(config, topic)
|
||||
@@ -1,90 +0,0 @@
|
||||
from heisskleber.config import BaseConf, load_config
|
||||
from heisskleber.core.types import Sink, Source
|
||||
from heisskleber.mqtt import MqttConf, MqttPublisher, MqttSubscriber
|
||||
from heisskleber.serial import SerialConf, SerialPublisher, SerialSubscriber
|
||||
from heisskleber.udp import UdpConf, UdpPublisher, UdpSubscriber
|
||||
from heisskleber.zmq import ZmqConf, ZmqPublisher, ZmqSubscriber
|
||||
|
||||
_registered_sinks: dict[str, tuple[type[Sink], type[BaseConf]]] = {
|
||||
"zmq": (ZmqPublisher, ZmqConf),
|
||||
"mqtt": (MqttPublisher, MqttConf),
|
||||
"serial": (SerialPublisher, SerialConf),
|
||||
"udp": (UdpPublisher, UdpConf),
|
||||
}
|
||||
|
||||
_registered_sources: dict[str, tuple[type[Source], type[BaseConf]]] = {
|
||||
"zmq": (ZmqSubscriber, ZmqConf),
|
||||
"mqtt": (MqttSubscriber, MqttConf),
|
||||
"serial": (SerialSubscriber, SerialConf),
|
||||
"udp": (UdpSubscriber, UdpConf),
|
||||
}
|
||||
|
||||
|
||||
def get_sink(name: str) -> Sink:
|
||||
"""
|
||||
Factory function to create a sink object.
|
||||
|
||||
Parameters:
|
||||
name: Name of the sink to create.
|
||||
config: Configuration object to use for the sink.
|
||||
"""
|
||||
|
||||
if name not in _registered_sinks:
|
||||
error_message = f"{name} is not a registered Sink."
|
||||
raise KeyError(error_message)
|
||||
|
||||
pub_cls, conf_cls = _registered_sinks[name]
|
||||
|
||||
print(f"loading {name} config")
|
||||
config = load_config(conf_cls(), name, read_commandline=False)
|
||||
|
||||
return pub_cls(config)
|
||||
|
||||
|
||||
def get_source(name: str, topic: str | list[str]) -> Source:
|
||||
"""
|
||||
Factory function to create a source object.
|
||||
|
||||
Parameters:
|
||||
name: Name of the source to create.
|
||||
config: Configuration object to use for the source.
|
||||
topic: Topic to subscribe to.
|
||||
"""
|
||||
|
||||
if name not in _registered_sinks:
|
||||
error_message = f"{name} is not a registered Source."
|
||||
raise KeyError(error_message)
|
||||
|
||||
sub_cls, conf_cls = _registered_sources[name]
|
||||
|
||||
print(f"loading {name} config")
|
||||
config = load_config(conf_cls(), name, read_commandline=False)
|
||||
|
||||
return sub_cls(config, topic)
|
||||
|
||||
|
||||
def get_subscriber(name: str, topic: str | list[str]) -> Source:
|
||||
"""
|
||||
Deprecated: Factory function to create a source object (formerly known as subscriber).
|
||||
|
||||
Parameters:
|
||||
name: Name of the source to create.
|
||||
config: Configuration object to use for the source.
|
||||
topic: Topic to subscribe to.
|
||||
"""
|
||||
|
||||
print("Deprecated: use get_source instead.")
|
||||
return get_source(name, topic)
|
||||
|
||||
|
||||
def get_publisher(name: str) -> Sink:
|
||||
"""
|
||||
Deprecated: Factory function to create a sink object (formerly known as publisher).
|
||||
|
||||
Parameters:
|
||||
name: Name of the sink to create.
|
||||
config: Configuration object to use for the sink.
|
||||
"""
|
||||
|
||||
print("Deprecated: use get_sink instead.")
|
||||
return get_sink(name)
|
||||
@@ -1,44 +0,0 @@
|
||||
"""Packer and unpacker for network data."""
|
||||
import json
|
||||
import pickle
|
||||
from typing import Any, Callable
|
||||
|
||||
from .types import Serializable
|
||||
|
||||
|
||||
def get_packer(style: str) -> Callable[[dict[str, Serializable]], str]:
|
||||
"""Return a packer function for the given style.
|
||||
|
||||
Packer func serializes a given dict."""
|
||||
if style in _packstyles:
|
||||
return _packstyles[style]
|
||||
else:
|
||||
return _packstyles["default"]
|
||||
|
||||
|
||||
def get_unpacker(style: str) -> Callable[[str], dict[str, Serializable]]:
|
||||
"""Return an unpacker function for the given style.
|
||||
|
||||
Unpacker func deserializes a string."""
|
||||
if style in _unpackstyles:
|
||||
return _unpackstyles[style]
|
||||
else:
|
||||
return _unpackstyles["default"]
|
||||
|
||||
|
||||
def serialpacker(data: dict[str, Any]) -> str:
|
||||
return ",".join([str(v) for v in data.values()])
|
||||
|
||||
|
||||
_packstyles: dict[str, Callable[[dict[str, Serializable]], str]] = {
|
||||
"default": json.dumps,
|
||||
"json": json.dumps,
|
||||
"pickle": pickle.dumps, # type: ignore
|
||||
"serial": serialpacker,
|
||||
}
|
||||
|
||||
_unpackstyles: dict[str, Callable[[str], dict[str, Serializable]]] = {
|
||||
"default": json.loads,
|
||||
"json": json.loads,
|
||||
"pickle": pickle.loads, # type: ignore
|
||||
}
|
||||
@@ -1,180 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from heisskleber.config import BaseConf
|
||||
|
||||
Serializable = Union[int, float]
|
||||
|
||||
|
||||
class Sink(ABC):
|
||||
"""
|
||||
Sink interface to send() data to.
|
||||
"""
|
||||
|
||||
pack: Callable[[dict[str, Serializable]], str]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: BaseConf) -> None:
|
||||
"""
|
||||
Initialize the publisher with a configuration object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Send data via the implemented output stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Source(ABC):
|
||||
"""
|
||||
Source interface that emits data via the receive() method.
|
||||
"""
|
||||
|
||||
unpack: Callable[[str], dict[str, Serializable]]
|
||||
|
||||
def __iter__(self) -> Generator[tuple[str, dict[str, Serializable]], None, None]:
|
||||
topic, data = self.receive()
|
||||
yield topic, data
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: BaseConf, topic: str | list[str]) -> None:
|
||||
"""
|
||||
Initialize the subscriber with a topic and a configuration object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
"""
|
||||
Blocking function to receive data from the implemented input stream.
|
||||
|
||||
Data is returned as a tuple of (topic, data).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncSource(ABC):
|
||||
"""
|
||||
AsyncSubscriber interface
|
||||
"""
|
||||
|
||||
async def __aiter__(self) -> AsyncGenerator[tuple[str, dict[str, Serializable]], None]:
|
||||
while True:
|
||||
topic, data = await self.receive()
|
||||
yield topic, data
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: Any, topic: str | list[str]) -> None:
|
||||
"""
|
||||
Initialize the subscriber with a topic and a configuration object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
"""
|
||||
Blocking function to receive data from the implemented input stream.
|
||||
|
||||
Data is returned as a tuple of (topic, data).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncSink(ABC):
|
||||
"""
|
||||
Sink interface to send() data to.
|
||||
"""
|
||||
|
||||
pack: Callable[[dict[str, Serializable]], str]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: BaseConf) -> None:
|
||||
"""
|
||||
Initialize the publisher with a configuration object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, data: dict[str, Any], topic: str) -> None:
|
||||
"""
|
||||
Send data via the implemented output stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop any background processes and tasks.
|
||||
"""
|
||||
pass
|
||||
@@ -1,20 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from heisskleber.config import BaseConf
|
||||
|
||||
|
||||
@dataclass
|
||||
class InfluxDBConf(BaseConf):
|
||||
host: str = "localhost"
|
||||
port: int = 8086
|
||||
bucket: str = "test"
|
||||
org: str = "test"
|
||||
ssl: bool = False
|
||||
read_token: str = ""
|
||||
write_token: str = ""
|
||||
all_access_token: str = ""
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
protocol = "https" if self.ssl else "http"
|
||||
return f"{protocol}://{self.host}:{self.port}"
|
||||
@@ -1,75 +0,0 @@
|
||||
import pandas as pd
|
||||
from influxdb_client import InfluxDBClient
|
||||
|
||||
from heisskleber.core.types import Source
|
||||
|
||||
from .config import InfluxDBConf
|
||||
|
||||
|
||||
def build_query(options: dict) -> str:
|
||||
query = (
|
||||
f'from(bucket:"{options["bucket"]}")'
|
||||
+ f'|> range(start: {options["start"].isoformat("T")}, stop: {options["end"].isoformat("T")})'
|
||||
+ f'|> filter(fn:(r) => r._measurement == "{options["measurement"]}")'
|
||||
)
|
||||
if options["filter"]:
|
||||
for attribute, value in options["filter"].items():
|
||||
if isinstance(value, list):
|
||||
query += f'|> filter(fn:(r) => r.{attribute} == "{value[0]}"'
|
||||
for vv in value[1:]:
|
||||
query += f' or r.{attribute} == "{vv}"'
|
||||
query += ")"
|
||||
else:
|
||||
query += f'|> filter(fn:(r) => r.{attribute} == "{value}")'
|
||||
|
||||
query += (
|
||||
f'|> aggregateWindow(every: {options["resample"]}, fn: mean)'
|
||||
+ '|> pivot(rowKey:["_time"], columnKey: ["_field"], valueColumn: "_value")'
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
|
||||
class Influx_Subscriber(Source):
|
||||
def __init__(self, config: InfluxDBConf, query: str):
|
||||
self.config = config
|
||||
self.query = query
|
||||
|
||||
self.client: InfluxDBClient = InfluxDBClient(
|
||||
url=self.config.url,
|
||||
token=self.config.all_access_token or self.config.read_token,
|
||||
org=self.config.org,
|
||||
timeout=60_000,
|
||||
)
|
||||
self.reader = self.client.query_api()
|
||||
|
||||
self._run_query()
|
||||
self.index = 0
|
||||
|
||||
def receive(self) -> tuple[str, dict]:
|
||||
row = self.df.iloc[self.index].to_dict()
|
||||
self.index += 1
|
||||
return "influx", row
|
||||
|
||||
def _run_query(self):
|
||||
self.df: pd.DataFrame = self.reader.query_data_frame(self.query, org=self.config.org)
|
||||
self.df["epoch"] = pd.to_numeric(self.df["_time"]) / 1e9
|
||||
self.df.drop(
|
||||
columns=[
|
||||
"result",
|
||||
"table",
|
||||
"_start",
|
||||
"_stop",
|
||||
"_measurement",
|
||||
"_time",
|
||||
"topic",
|
||||
],
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
for _, row in self.df.iterrows():
|
||||
yield "influx", row.to_dict()
|
||||
|
||||
def __next__(self):
|
||||
return self.__iter__().__next__()
|
||||
@@ -1,48 +0,0 @@
|
||||
from influxdb_client import InfluxDBClient, WriteOptions
|
||||
|
||||
from config import InfluxDBConf
|
||||
from heisskleber.config import load_config
|
||||
|
||||
|
||||
class Influx_Writer:
|
||||
def __init__(self, config: InfluxDBConf):
|
||||
self.config = config
|
||||
# self.write_options = SYNCHRONOUS
|
||||
self.write_options = WriteOptions(
|
||||
batch_size=500,
|
||||
flush_interval=10_000,
|
||||
jitter_interval=2_000,
|
||||
retry_interval=5_000,
|
||||
max_retries=5,
|
||||
max_retry_delay=30_000,
|
||||
exponential_base=2,
|
||||
)
|
||||
self.client = InfluxDBClient(url=self.config.url, token=self.config.token, org=self.config.org)
|
||||
self.writer = self.client.write_api(
|
||||
write_options=self.write_options,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self.writer.close()
|
||||
self.client.close()
|
||||
|
||||
def write_line(self, line):
|
||||
self.writer.write(bucket=self.config.bucket, record=line)
|
||||
|
||||
def write_from_generator(self, generator):
|
||||
for line in generator:
|
||||
self.writer.write(bucket=self.config.bucket, record=line)
|
||||
|
||||
def write_from_line_generator(self, generator):
|
||||
with InfluxDBClient(
|
||||
url=self.config.url, token=self.config.token, org=self.config.org
|
||||
) as client, client.write_api(
|
||||
write_options=self.write_options,
|
||||
) as write_api:
|
||||
for line in generator:
|
||||
write_api.write(bucket=self.config.bucket, record=line)
|
||||
|
||||
|
||||
def get_parsed_flux_writer():
|
||||
config = load_config(InfluxDBConf(), "flux", read_commandline=False)
|
||||
return Influx_Writer(config)
|
||||
@@ -1,7 +0,0 @@
|
||||
from .config import MqttConf
|
||||
from .publisher import MqttPublisher
|
||||
from .publisher_async import AsyncMqttPublisher
|
||||
from .subscriber import MqttSubscriber
|
||||
from .subscriber_async import AsyncMqttSubscriber
|
||||
|
||||
__all__ = ["MqttConf", "MqttPublisher", "MqttSubscriber", "AsyncMqttSubscriber", "AsyncMqttPublisher"]
|
||||
@@ -1,24 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from heisskleber.config import BaseConf
|
||||
|
||||
|
||||
@dataclass
|
||||
class MqttConf(BaseConf):
|
||||
"""
|
||||
MQTT configuration class.
|
||||
"""
|
||||
|
||||
host: str = "localhost"
|
||||
user: str = ""
|
||||
password: str = ""
|
||||
port: int = 1883
|
||||
ssl: bool = False
|
||||
qos: int = 0
|
||||
retain: bool = False
|
||||
topics: list[str] = field(default_factory=list)
|
||||
mapping: str = "/deprecated/" # deprecated
|
||||
packstyle: str = "json"
|
||||
max_saved_messages: int = 100
|
||||
timeout_s: int = 60
|
||||
source_id: str = "box-01"
|
||||
@@ -1,23 +0,0 @@
|
||||
from heisskleber import get_publisher, get_subscriber
|
||||
from heisskleber.config import load_config
|
||||
|
||||
from .config import MqttConf
|
||||
|
||||
|
||||
def map_topic(zmq_topic: str, mapping: str) -> str:
|
||||
return mapping + zmq_topic
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config: MqttConf = load_config(MqttConf(), "mqtt")
|
||||
sub = get_subscriber("zmq", config.topics)
|
||||
pub = get_publisher("mqtt")
|
||||
|
||||
pub.pack = lambda x: x # type: ignore
|
||||
sub.unpack = lambda x: x # type: ignore
|
||||
|
||||
while True:
|
||||
(zmq_topic, data) = sub.receive()
|
||||
mqtt_topic = map_topic(zmq_topic, config.mapping)
|
||||
|
||||
pub.send(data, mqtt_topic)
|
||||
@@ -1,97 +0,0 @@
|
||||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from paho.mqtt.client import Client as mqtt_client
|
||||
|
||||
from .config import MqttConf
|
||||
|
||||
|
||||
class ThreadDiedError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
_thread_died = threading.Event()
|
||||
|
||||
_default_excepthook = threading.excepthook
|
||||
|
||||
|
||||
def _set_thread_died_excepthook(args, /):
|
||||
_default_excepthook(args)
|
||||
global _thread_died
|
||||
_thread_died.set()
|
||||
|
||||
|
||||
threading.excepthook = _set_thread_died_excepthook
|
||||
|
||||
|
||||
class MqttBase:
|
||||
"""
|
||||
Wrapper around eclipse paho mqtt client.
|
||||
Handles connection and callbacks.
|
||||
Callbacks may be overwritten in subclasses.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf) -> None:
|
||||
self.config = config
|
||||
self.client = mqtt_client()
|
||||
self.is_connected = False
|
||||
|
||||
def start(self) -> None:
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.client:
|
||||
self.client.loop_stop()
|
||||
self.is_connected = False
|
||||
|
||||
def connect(self) -> None:
|
||||
self.client.username_pw_set(self.config.user, self.config.password)
|
||||
|
||||
# Add callbacks
|
||||
self.client.on_connect = self._on_connect
|
||||
self.client.on_disconnect = self._on_disconnect
|
||||
self.client.on_publish = self._on_publish
|
||||
self.client.on_message = self._on_message
|
||||
|
||||
if self.config.ssl:
|
||||
# By default, on Python 2.7.9+ or 3.4+,
|
||||
# the default certification authority of the system is used.
|
||||
self.client.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT)
|
||||
|
||||
self.client.connect(self.config.host, self.config.port)
|
||||
self.client.loop_start()
|
||||
self.is_connected = True
|
||||
|
||||
@staticmethod
|
||||
def _raise_if_thread_died() -> None:
|
||||
global _thread_died
|
||||
if _thread_died.is_set():
|
||||
raise ThreadDiedError()
|
||||
|
||||
# MQTT callbacks
|
||||
def _on_connect(self, client, userdata, flags, return_code) -> None:
|
||||
if return_code == 0:
|
||||
print(f"MQTT node connected to {self.config.host}:{self.config.port}")
|
||||
else:
|
||||
print("Connection failed!")
|
||||
if self.config.verbose:
|
||||
print(flags)
|
||||
|
||||
def _on_disconnect(self, client, userdata, return_code) -> None:
|
||||
print(f"Disconnected from broker with return code {return_code}")
|
||||
if return_code != 0:
|
||||
print("Killing this service")
|
||||
sys.exit(-1)
|
||||
|
||||
def _on_publish(self, client, userdata, message_id) -> None:
|
||||
if self.config.verbose:
|
||||
print(f"Published message with id {message_id}, qos={self.config.qos}")
|
||||
|
||||
def _on_message(self, client, userdata, message) -> None:
|
||||
if self.config.verbose:
|
||||
print(f"Received message: {message.payload!s}, topic: {message.topic}, qos: {message.qos}")
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.stop()
|
||||
@@ -1,44 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import Serializable, Sink
|
||||
|
||||
from .config import MqttConf
|
||||
from .mqtt_base import MqttBase
|
||||
|
||||
|
||||
class MqttPublisher(MqttBase, Sink):
|
||||
"""
|
||||
MQTT publisher class.
|
||||
Can be used everywhere that a flucto style publishing connection is required.
|
||||
|
||||
Network message loop is handled in a separated thread.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf) -> None:
|
||||
super().__init__(config)
|
||||
self.pack = get_packer(config.packstyle)
|
||||
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
|
||||
Publishing is asynchronous
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
|
||||
self._raise_if_thread_died()
|
||||
|
||||
payload = self.pack(data)
|
||||
self.client.publish(topic, payload, qos=self.config.qos, retain=self.config.retain)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
|
||||
def start(self) -> None:
|
||||
super().start()
|
||||
|
||||
def stop(self) -> None:
|
||||
super().stop()
|
||||
@@ -1,70 +0,0 @@
|
||||
from asyncio import Queue, Task, create_task, sleep
|
||||
|
||||
import aiomqtt
|
||||
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import AsyncSink, Serializable
|
||||
|
||||
from .config import MqttConf
|
||||
|
||||
|
||||
class AsyncMqttPublisher(AsyncSink):
|
||||
"""
|
||||
MQTT publisher class.
|
||||
Can be used everywhere that a flucto style publishing connection is required.
|
||||
|
||||
Network message loop is handled in a separated thread.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf) -> None:
|
||||
self.config = config
|
||||
self.pack = get_packer(config.packstyle)
|
||||
self._send_queue: Queue[tuple[dict[str, Serializable], str]] = Queue()
|
||||
self._sender_task: Task[None] | None = None
|
||||
|
||||
async def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
|
||||
Publishing is asynchronous
|
||||
"""
|
||||
if not self._sender_task:
|
||||
self.start()
|
||||
|
||||
await self._send_queue.put((data, topic))
|
||||
|
||||
async def send_work(self) -> None:
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
|
||||
Publishing is asynchronous
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
async with aiomqtt.Client(
|
||||
hostname=self.config.host,
|
||||
port=self.config.port,
|
||||
username=self.config.user,
|
||||
password=self.config.password,
|
||||
timeout=float(self.config.timeout_s),
|
||||
) as client:
|
||||
while True:
|
||||
data, topic = await self._send_queue.get()
|
||||
payload = self.pack(data)
|
||||
await client.publish(topic, payload)
|
||||
except aiomqtt.MqttError:
|
||||
print("Connection to MQTT broker failed. Retrying in 5 seconds")
|
||||
await sleep(5)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
|
||||
|
||||
def start(self) -> None:
|
||||
self._sender_task = create_task(self.send_work())
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._sender_task:
|
||||
self._sender_task.cancel()
|
||||
self._sender_task = None
|
||||
@@ -1,81 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from queue import SimpleQueue
|
||||
from typing import Any
|
||||
|
||||
from paho.mqtt.client import MQTTMessage
|
||||
|
||||
from heisskleber.core.packer import get_unpacker
|
||||
from heisskleber.core.types import Source
|
||||
|
||||
from .config import MqttConf
|
||||
from .mqtt_base import MqttBase
|
||||
|
||||
|
||||
class MqttSubscriber(MqttBase, Source):
|
||||
"""
|
||||
MQTT subscriber, wraps around ecplipse's paho mqtt client.
|
||||
Network message loop is handled in a separated thread.
|
||||
|
||||
Incoming messages are saved as a stack when not processed via the receive() function.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf, topics: str | list[str]) -> None:
|
||||
super().__init__(config)
|
||||
self.topics = topics
|
||||
self._message_queue: SimpleQueue[MQTTMessage] = SimpleQueue()
|
||||
self.unpack = get_unpacker(config.packstyle)
|
||||
|
||||
def subscribe(self, topics: str | list[str] | tuple[str]) -> None:
|
||||
"""
|
||||
Subscribe to one or multiple topics
|
||||
"""
|
||||
if not self.is_connected:
|
||||
super().start()
|
||||
self.client.on_message = self._on_message
|
||||
|
||||
if isinstance(topics, (list, tuple)):
|
||||
# if subscribing to multiple topics, use a list of tuples
|
||||
subscription_list = [(topic, self.config.qos) for topic in topics]
|
||||
self.client.subscribe(subscription_list)
|
||||
else:
|
||||
self.client.subscribe(topics, self.config.qos)
|
||||
if self.config.verbose:
|
||||
print(f"Subscribed to: {topics}")
|
||||
|
||||
def receive(self) -> tuple[str, dict[str, Any]]:
|
||||
"""
|
||||
Reads a message from mqtt and returns it
|
||||
|
||||
Messages are saved in a stack, if no message is available, this function blocks.
|
||||
|
||||
Returns:
|
||||
tuple(topic: str, message: dict): the message received
|
||||
"""
|
||||
if not self.client:
|
||||
self.start()
|
||||
|
||||
self._raise_if_thread_died()
|
||||
mqtt_message = self._message_queue.get(block=True, timeout=self.config.timeout_s)
|
||||
|
||||
message_returned = self.unpack(mqtt_message.payload.decode())
|
||||
return (mqtt_message.topic, message_returned)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
|
||||
def start(self) -> None:
|
||||
super().start()
|
||||
self.subscribe(self.topics)
|
||||
self.client.on_message = self._on_message
|
||||
|
||||
def stop(self) -> None:
|
||||
super().stop()
|
||||
|
||||
# callback to add incoming messages onto stack
|
||||
def _on_message(self, client, userdata, message) -> None:
|
||||
self._message_queue.put(message)
|
||||
|
||||
if self.config.verbose:
|
||||
print(f"Topic: {message.topic}")
|
||||
print(f"MQTT message: {message.payload.decode()}")
|
||||
@@ -1,86 +0,0 @@
|
||||
from asyncio import Queue, Task, create_task, sleep
|
||||
|
||||
from aiomqtt import Client, Message, MqttError
|
||||
|
||||
from heisskleber.core.packer import get_unpacker
|
||||
from heisskleber.core.types import AsyncSource, Serializable
|
||||
from heisskleber.mqtt import MqttConf
|
||||
|
||||
|
||||
class AsyncMqttSubscriber(AsyncSource):
|
||||
"""Asynchronous MQTT susbsciber based on aiomqtt.
|
||||
|
||||
Data is received by the `receive` method returns the newest message in the queue.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf, topic: str | list[str]) -> None:
|
||||
self.config: MqttConf = config
|
||||
self.client = Client(
|
||||
hostname=self.config.host,
|
||||
port=self.config.port,
|
||||
username=self.config.user,
|
||||
password=self.config.password,
|
||||
)
|
||||
self.topics = topic
|
||||
self.unpack = get_unpacker(self.config.packstyle)
|
||||
self.message_queue: Queue[Message] = Queue(self.config.max_saved_messages)
|
||||
self._listener_task: Task[None] | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
|
||||
|
||||
def start(self) -> None:
|
||||
self._listener_task = create_task(self.run())
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._listener_task:
|
||||
self._listener_task.cancel()
|
||||
self._listener_task = None
|
||||
|
||||
async def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
"""
|
||||
Await the newest message in the queue and return Tuple
|
||||
"""
|
||||
if not self._listener_task:
|
||||
self.start()
|
||||
mqtt_message = await self.message_queue.get()
|
||||
return self._handle_message(mqtt_message)
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
Handle the connection to MQTT broker and run the message loop.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
async with self.client:
|
||||
await self._subscribe_topics()
|
||||
await self._listen_mqtt_loop()
|
||||
except MqttError as e:
|
||||
print(f"MqttError: {e}")
|
||||
print("Connection to MQTT failed. Retrying...")
|
||||
await sleep(1)
|
||||
|
||||
async def _listen_mqtt_loop(self) -> None:
|
||||
"""
|
||||
Listen to incoming messages asynchronously and put them into a queue
|
||||
"""
|
||||
async with self.client.messages() as messages:
|
||||
# async with self.client.filtered_messages(self.topics) as messages:
|
||||
async for message in messages:
|
||||
await self.message_queue.put(message)
|
||||
|
||||
def _handle_message(self, message: Message) -> tuple[str, dict[str, Serializable]]:
|
||||
if not isinstance(message.payload, bytes):
|
||||
error_msg = "Payload is not of type bytes."
|
||||
raise TypeError(error_msg)
|
||||
|
||||
topic = str(message.topic)
|
||||
message_returned = self.unpack(message.payload.decode())
|
||||
return (topic, message_returned)
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
print(f"subscribing to {self.topics}")
|
||||
if isinstance(self.topics, list):
|
||||
await self.client.subscribe([(topic, self.config.qos) for topic in self.topics])
|
||||
else:
|
||||
await self.client.subscribe(self.topics, self.config.qos)
|
||||
@@ -1,95 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
from typing import Callable, Union
|
||||
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.console.sink import ConsoleSink
|
||||
from heisskleber.core.factories import _registered_sources
|
||||
from heisskleber.mqtt import MqttSubscriber
|
||||
from heisskleber.udp import UdpSubscriber
|
||||
from heisskleber.zmq import ZmqSubscriber
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="hkcli",
|
||||
description="Heisskleber command line interface",
|
||||
usage="%(prog)s [options]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--type",
|
||||
type=str,
|
||||
choices=["zmq", "mqtt", "serial", "udp"],
|
||||
default="zmq",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-T",
|
||||
"--topic",
|
||||
type=str,
|
||||
default="#",
|
||||
help="Topic to subscribe to, valid for zmq and mqtt only.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-H",
|
||||
"--host",
|
||||
type=str,
|
||||
help="Host or broker for MQTT, zmq and UDP.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-P",
|
||||
"--port",
|
||||
type=int,
|
||||
help="Port or serial interface for MQTT, zmq and UDP.",
|
||||
)
|
||||
parser.add_argument("-v", "--verbose", action="store_true")
|
||||
parser.add_argument("-p", "--pretty", action="store_true", help="Pretty print JSON data.")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def keyboardexit(func) -> Callable:
|
||||
def wrapper(*args, **kwargs) -> Union[None, int]:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting...")
|
||||
sys.exit(0)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@keyboardexit
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
sink = ConsoleSink(pretty=args.pretty, verbose=args.verbose)
|
||||
|
||||
sub_cls, conf_cls = _registered_sources[args.type]
|
||||
|
||||
try:
|
||||
config = load_config(conf_cls(), args.type, read_commandline=False)
|
||||
except FileNotFoundError:
|
||||
print(f"No config file found for {args.type}, using default values and user input.")
|
||||
config = conf_cls()
|
||||
|
||||
source = sub_cls(config, args.topic)
|
||||
if isinstance(source, (MqttSubscriber, UdpSubscriber)):
|
||||
source.config.host = args.host or source.config.host
|
||||
source.config.port = args.port or source.config.port
|
||||
elif isinstance(source, ZmqSubscriber):
|
||||
source.config.host = args.host or source.config.host
|
||||
source.config.subscriber_port = args.port or source.config.subscriber_port
|
||||
source.topic = "" if args.topic == "#" else args.topic
|
||||
elif isinstance(source, UdpSubscriber):
|
||||
source.config.port = args.port or source.config.port
|
||||
|
||||
source.start()
|
||||
sink.start()
|
||||
|
||||
while True:
|
||||
topic, data = source.receive()
|
||||
sink.send(data, topic)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,12 +0,0 @@
|
||||
from heisskleber.broker import start_zmq_broker
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.zmq.config import ZmqConf as BrokerConf
|
||||
|
||||
|
||||
def main():
|
||||
broker_config = load_config(BrokerConf(), "zmq")
|
||||
start_zmq_broker(config=broker_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,5 +0,0 @@
|
||||
from .config import SerialConf
|
||||
from .publisher import SerialPublisher
|
||||
from .subscriber import SerialSubscriber
|
||||
|
||||
__all__ = ["SerialConf", "SerialPublisher", "SerialSubscriber"]
|
||||
@@ -1,11 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from heisskleber.config import BaseConf
|
||||
|
||||
|
||||
@dataclass
|
||||
class SerialConf(BaseConf):
|
||||
port: str = "/dev/serial0"
|
||||
baudrate: int = 9600
|
||||
bytesize: int = 8
|
||||
encoding: str = "ascii"
|
||||
@@ -1,31 +0,0 @@
|
||||
from heisskleber.core.types import Source
|
||||
|
||||
from .publisher import SerialPublisher
|
||||
|
||||
|
||||
class SerialForwarder:
|
||||
def __init__(self, subscriber: Source, publisher: SerialPublisher) -> None:
|
||||
self.sub = subscriber
|
||||
self.pub = publisher
|
||||
|
||||
"""
|
||||
Wait for message and forward
|
||||
"""
|
||||
|
||||
def forward_message(self) -> None:
|
||||
# collected = {}
|
||||
# for sub in self.sub:
|
||||
# topic, data = sub.receive()
|
||||
# collected.update(data)
|
||||
topic, data = self.sub.receive()
|
||||
|
||||
# We send the topic and let the publisher decide what to do with it
|
||||
self.pub.send(data, topic)
|
||||
|
||||
"""
|
||||
Enter loop and continuously forward messages
|
||||
"""
|
||||
|
||||
def sub_pub_loop(self) -> None:
|
||||
while True:
|
||||
self.forward_message()
|
||||
@@ -1,87 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Callable, Optional
|
||||
|
||||
import serial
|
||||
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import Serializable, Sink
|
||||
|
||||
from .config import SerialConf
|
||||
|
||||
|
||||
class SerialPublisher(Sink):
|
||||
serial_connection: serial.Serial
|
||||
"""
|
||||
Publisher for serial devices.
|
||||
Can be used everywhere that a flucto style publishing connection is required.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : SerialConf
|
||||
Configuration for the serial connection.
|
||||
pack_func : FunctionType
|
||||
Function to translate from a dict to a serialized string.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SerialConf,
|
||||
pack_func: Optional[Callable] = None, # noqa: UP007
|
||||
):
|
||||
self.config = config
|
||||
self.pack = pack_func if pack_func else get_packer("serial")
|
||||
self.is_connected = False
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start the serial connection.
|
||||
"""
|
||||
try:
|
||||
self.serial_connection = serial.Serial(
|
||||
port=self.config.port,
|
||||
baudrate=self.config.baudrate,
|
||||
bytesize=self.config.bytesize,
|
||||
parity=serial.PARITY_NONE,
|
||||
stopbits=serial.STOPBITS_ONE,
|
||||
)
|
||||
except serial.SerialException:
|
||||
print(f"Failed to connect to serial device at port {self.config.port}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Successfully connected to serial device at port {self.config.port}")
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop the serial connection.
|
||||
"""
|
||||
if hasattr(self, "serial_connection") and self.serial_connection.is_open:
|
||||
self.serial_connection.flush()
|
||||
self.serial_connection.close()
|
||||
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
message : dict
|
||||
object to be serialized and sent via the serial connection. Usually a dict.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
|
||||
payload = self.pack(data)
|
||||
self.serial_connection.write(payload.encode(self.config.encoding))
|
||||
self.serial_connection.flush()
|
||||
if self.config.verbose:
|
||||
print(f"{topic}: {payload}")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SerialPublisher(port={self.config.port}, baudrate={self.config.baudrate}, bytezize={self.config.bytesize}, encoding={self.config.encoding})"
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.stop()
|
||||
@@ -1,110 +0,0 @@
|
||||
import sys
|
||||
from collections.abc import Generator
|
||||
from typing import Callable
|
||||
|
||||
import serial
|
||||
|
||||
from heisskleber.core.types import Source
|
||||
|
||||
from .config import SerialConf
|
||||
|
||||
|
||||
class SerialSubscriber(Source):
|
||||
serial_connection: serial.Serial
|
||||
"""
|
||||
Subscriber for serial devices. Connects to a serial port and reads from it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
topics :
|
||||
Placeholder for topic. Not used.
|
||||
|
||||
config : SerialConf
|
||||
Configuration class for the serial connection.
|
||||
|
||||
unpack_func : FunctionType
|
||||
Function to translate from a serialized string to a dict.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SerialConf,
|
||||
topic: str | None = None,
|
||||
custom_unpack: Callable | None = None,
|
||||
):
|
||||
self.config = config
|
||||
self.topic = topic
|
||||
self.unpack = custom_unpack if custom_unpack else lambda x: x # types: ignore
|
||||
self.is_connected = False
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start the serial connection.
|
||||
"""
|
||||
try:
|
||||
self.serial_connection = serial.Serial(
|
||||
port=self.config.port,
|
||||
baudrate=self.config.baudrate,
|
||||
bytesize=self.config.bytesize,
|
||||
parity=serial.PARITY_NONE,
|
||||
stopbits=serial.STOPBITS_ONE,
|
||||
)
|
||||
except serial.SerialException:
|
||||
print(f"Failed to connect to serial device at port {self.config.port}")
|
||||
sys.exit(1)
|
||||
print(f"Successfully connected to serial device at port {self.config.port}")
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop the serial connection.
|
||||
"""
|
||||
if hasattr(self, "serial_connection") and self.serial_connection.is_open:
|
||||
self.serial_connection.flush()
|
||||
self.serial_connection.close()
|
||||
|
||||
def receive(self) -> tuple[str, dict]:
|
||||
"""
|
||||
Wait for data to arrive on the serial port and return it.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:return: (topic, payload)
|
||||
topic is a placeholder to adhere to the Subscriber interface
|
||||
payload is a dictionary containing the data from the serial port
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
|
||||
# message is a string
|
||||
message = next(self.read_serial_port())
|
||||
# payload is a dictionary
|
||||
payload = self.unpack(message)
|
||||
# port is a placeholder for topic
|
||||
return self.config.port, payload
|
||||
|
||||
def read_serial_port(self) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function reading from the serial port.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:return: Generator[str, None, None]
|
||||
Generator yielding strings read from the serial port
|
||||
"""
|
||||
buffer = ""
|
||||
while True:
|
||||
try:
|
||||
buffer = self.serial_connection.readline().decode(self.config.encoding, "ignore")
|
||||
yield buffer
|
||||
except UnicodeError as e:
|
||||
if self.config.verbose:
|
||||
print(f"Could not decode: {buffer!r}")
|
||||
print(e)
|
||||
continue
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SerialPublisher(port={self.config.port}, baudrate={self.config.baudrate}, bytezize={self.config.bytesize}, encoding={self.config.encoding})"
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.stop()
|
||||
@@ -1,5 +0,0 @@
|
||||
from .config import ResamplerConf
|
||||
from .joint import Joint
|
||||
from .resampler import Resampler
|
||||
|
||||
__all__ = ["Resampler", "ResamplerConf", "Joint"]
|
||||
@@ -1,77 +0,0 @@
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal # type: ignore [import-untyped]
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from heisskleber.core.types import AsyncSource, Serializable
|
||||
from heisskleber.stream.filter import Filter
|
||||
|
||||
|
||||
class LiveLFilter:
|
||||
"""
|
||||
Filter using standard difference equations.
|
||||
Kudos to Sam Proell https://www.samproell.io/posts/yarppg/yarppg-live-digital-filter/
|
||||
"""
|
||||
|
||||
def __init__(self, b: NDArray[np.float64], a: NDArray[np.float64], init_val: float = 0.0) -> None:
|
||||
"""Initialize live filter based on difference equation.
|
||||
|
||||
Args:
|
||||
b (array-like): numerator coefficients obtained from scipy.
|
||||
a (array-like): denominator coefficients obtained from scipy.
|
||||
"""
|
||||
self.b = b
|
||||
self.a = a
|
||||
self._xs = deque([init_val] * len(b), maxlen=len(b))
|
||||
self._ys = deque([init_val] * (len(a) - 1), maxlen=len(a) - 1)
|
||||
|
||||
def __call__(self, x: float) -> float:
|
||||
"""Filter incoming data with standard difference equations."""
|
||||
self._xs.appendleft(x)
|
||||
y = np.dot(self.b, self._xs) - np.dot(self.a[1:], self._ys)
|
||||
y = y / self.a[0]
|
||||
self._ys.appendleft(y)
|
||||
|
||||
return y # type: ignore [no-any-return]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(b={self.b}, a={self.a})"
|
||||
|
||||
|
||||
class ButterFilter(Filter):
|
||||
"""
|
||||
Butterworth filter based on scipy.
|
||||
|
||||
Args:
|
||||
source (AsyncSource): Source of data.
|
||||
cutoff_freq (float): Cutoff frequency.
|
||||
sampling_rate (float): Sampling rate of the input signal, i.e update frequency
|
||||
btype (str): Type of filter, "high" or "low"
|
||||
order (int): order of the filter to be applied
|
||||
|
||||
Example:
|
||||
>>> source = get_async_source()
|
||||
>>> filter = LowPassFilter(source, 0.1, 100, 3)
|
||||
>>> async for topic, data in filter:
|
||||
>>> print(topic, data)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, source: AsyncSource, cutoff_freq: float, sampling_rate: float, btype: str = "low", order: int = 3
|
||||
) -> None:
|
||||
self.source = source
|
||||
nyquist_fq = sampling_rate / 2
|
||||
Wn = cutoff_freq / nyquist_fq
|
||||
self.b, self.a = scipy.signal.iirfilter(order, Wn=Wn, fs=sampling_rate, btype=btype, ftype="butter")
|
||||
self.filters: dict[str, LiveLFilter] = {}
|
||||
|
||||
def _filter(self, data: dict[str, Serializable]) -> dict[str, Serializable]:
|
||||
if not self.filters:
|
||||
for key in data:
|
||||
self.filters[key] = LiveLFilter(a=self.a, b=self.b)
|
||||
|
||||
for key, value in data.items():
|
||||
data[key] = self.filters[key](value)
|
||||
|
||||
return data
|
||||
@@ -1,6 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResamplerConf:
|
||||
resample_rate: int = 1000
|
||||
@@ -1,19 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from heisskleber.core.types import AsyncSource, Serializable
|
||||
|
||||
|
||||
class Filter(ABC):
|
||||
def __init__(self, source: AsyncSource):
|
||||
self.source = source
|
||||
|
||||
async def __aiter__(self) -> AsyncGenerator[Any, None]:
|
||||
async for topic, data in self.source:
|
||||
data = self._filter(data)
|
||||
yield topic, data
|
||||
|
||||
@abstractmethod
|
||||
def _filter(self, data: dict[str, Serializable]) -> dict[str, Serializable]:
|
||||
pass
|
||||
@@ -1,56 +0,0 @@
|
||||
from heisskleber.core.types import AsyncSource
|
||||
from heisskleber.stream.filter import Filter
|
||||
|
||||
|
||||
class GhFilter(Filter):
|
||||
"""
|
||||
G-H filter (also called alpha-beta, f-g filter), simplified observer for estimation and data smoothing.
|
||||
|
||||
Args:
|
||||
source (AsyncSource): Source of data.
|
||||
g (float): Correction gain for value
|
||||
h (float): Correction gain for derivative
|
||||
|
||||
Example:
|
||||
>>> source = get_async_source()
|
||||
>>> filter = GhFilter(source, 0.008, 0.001)
|
||||
>>> async for topic, data in filter:
|
||||
>>> print(topic, data)
|
||||
"""
|
||||
|
||||
def __init__(self, source: AsyncSource, g: float, h: float):
|
||||
self.source = source
|
||||
if not 0 < g < 1.0 or not 0 < h < 1.0:
|
||||
msg = "g and h must be between 0 and 1.0"
|
||||
raise ValueError(msg)
|
||||
self.g = g
|
||||
self.h = h
|
||||
self.x: dict[str, float] = {}
|
||||
self.dx: dict[str, float] = {}
|
||||
|
||||
def _filter(self, data: dict[str, float]) -> dict[str, float]:
|
||||
if not self.x:
|
||||
self.x = data
|
||||
self.dx = {key: 0.0 for key in data}
|
||||
return data
|
||||
|
||||
invalid_keys = []
|
||||
ts = data.pop("epoch")
|
||||
dt = ts - self.x["epoch"]
|
||||
if abs(dt) <= 1e-4:
|
||||
data["epoch"] = ts
|
||||
return data
|
||||
|
||||
for key in data:
|
||||
if not isinstance(data[key], float):
|
||||
invalid_keys.append(key)
|
||||
continue
|
||||
x_pred = self.x[key] + dt * self.dx[key]
|
||||
residual = data[key] - x_pred
|
||||
self.dx[key] = self.dx[key] + self.h * residual / dt
|
||||
self.x[key] = x_pred + self.g * residual
|
||||
|
||||
for key in invalid_keys:
|
||||
self.x[key] = data[key]
|
||||
self.x["epoch"] = ts
|
||||
return self.x
|
||||
@@ -1,114 +0,0 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from heisskleber.core.types import Serializable
|
||||
from heisskleber.stream.resampler import Resampler, ResamplerConf
|
||||
|
||||
|
||||
class Joint:
|
||||
"""Joint that takes multiple async streams and synchronizes them based on their timestamps.
|
||||
|
||||
Note that you need to run the setup() function first to initialize the
|
||||
|
||||
Parameters:
|
||||
----------
|
||||
conf : ResamplerConf
|
||||
Configuration for the joint.
|
||||
subscribers : list[AsyncSubscriber]
|
||||
List of asynchronous subscribers.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, conf: ResamplerConf, resamplers: list[Resampler]):
|
||||
self.conf = conf
|
||||
self.resamplers = resamplers
|
||||
self.output_queue: asyncio.Queue[dict[str, Serializable]] = asyncio.Queue()
|
||||
self.initialized = asyncio.Event()
|
||||
self.initalize_task = asyncio.create_task(self.sync())
|
||||
self.combined_dict: dict[str, Serializable] = {}
|
||||
self.task: asyncio.Task[None] | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"""Joint(resample_rate={self.conf.resample_rate},
|
||||
sources={len(self.resamplers)} of type(s): {{r.__class__.__name__ for r in self.resamplers}})"""
|
||||
|
||||
def start(self) -> None:
|
||||
self.task = asyncio.create_task(self.output_work())
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
|
||||
async def receive(self) -> dict[str, Any]:
|
||||
"""
|
||||
Main interaction coroutine: Get next value out of the queue.
|
||||
"""
|
||||
if not self.task:
|
||||
self.start()
|
||||
output = await self.output_queue.get()
|
||||
return output
|
||||
|
||||
async def sync(self) -> None:
|
||||
"""Synchronize the resamplers by pulling data from each until the timestamp is aligned. Retains first matching data."""
|
||||
print("Starting sync")
|
||||
datas = await asyncio.gather(*[source.receive() for source in self.resamplers])
|
||||
print("Got data")
|
||||
output_data = {}
|
||||
data = {}
|
||||
|
||||
latest_timestamp: float = 0.0
|
||||
timestamps = []
|
||||
|
||||
print("Syncing...")
|
||||
for data in datas:
|
||||
if not isinstance(data["epoch"], float):
|
||||
error = "Timestamps must be floats"
|
||||
raise TypeError(error)
|
||||
|
||||
ts = float(data["epoch"])
|
||||
|
||||
print(f"Syncing..., got {ts}")
|
||||
|
||||
timestamps.append(ts)
|
||||
if ts > latest_timestamp:
|
||||
latest_timestamp = ts
|
||||
|
||||
# only take the piece of the latest data
|
||||
output_data = data
|
||||
|
||||
for resampler, ts in zip(self.resamplers, timestamps):
|
||||
while ts < latest_timestamp:
|
||||
data = await resampler.receive()
|
||||
ts = float(data["epoch"])
|
||||
|
||||
output_data.update(data)
|
||||
|
||||
await self.output_queue.put(output_data)
|
||||
|
||||
print("Finished initalization")
|
||||
self.initialized.set()
|
||||
|
||||
"""
|
||||
Coroutine that waits for new queue data and updates dict.
|
||||
"""
|
||||
|
||||
async def update_dict(self, resampler: Resampler) -> None:
|
||||
data = await resampler.receive()
|
||||
if self.combined_dict and self.combined_dict["epoch"] != data["epoch"]:
|
||||
print("Oh shit, this is bad!")
|
||||
self.combined_dict.update(data)
|
||||
|
||||
"""
|
||||
Output worker: iterate through queues, read data and join into output queue.
|
||||
"""
|
||||
|
||||
async def output_work(self) -> None:
|
||||
print("Output worker waiting for intitialization")
|
||||
await self.initialized.wait()
|
||||
print("Output worker resuming")
|
||||
|
||||
while True:
|
||||
self.combined_dict = {}
|
||||
tasks = [asyncio.create_task(self.update_dict(res)) for res in self.resamplers]
|
||||
await asyncio.gather(*tasks)
|
||||
await self.output_queue.put(self.combined_dict)
|
||||
@@ -1,195 +0,0 @@
|
||||
import math
|
||||
from asyncio import Queue, Task, create_task
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
|
||||
from heisskleber.core.types import AsyncSource, Serializable
|
||||
|
||||
from .config import ResamplerConf
|
||||
|
||||
|
||||
def floor_dt(dt: datetime, delta: timedelta) -> datetime:
|
||||
"""Round a datetime object based on a delta timedelta."""
|
||||
return datetime.min + math.floor((dt - datetime.min) / delta) * delta
|
||||
|
||||
|
||||
def timestamp_generator(start_epoch: float, timedelta_in_ms: int) -> Generator[float, None, None]:
|
||||
"""Generate increasing timestamps based on a start epoch and a delta in ms.
|
||||
The timestamps are meant to be used with the resampler and generator half delta offsets of the returned timetsamps.
|
||||
"""
|
||||
timestamp_start = datetime.fromtimestamp(start_epoch)
|
||||
delta = timedelta(milliseconds=timedelta_in_ms)
|
||||
delta_half = timedelta(milliseconds=timedelta_in_ms // 2)
|
||||
next_timestamp = floor_dt(timestamp_start, delta) + delta_half
|
||||
while True:
|
||||
yield datetime.timestamp(next_timestamp)
|
||||
next_timestamp += delta
|
||||
|
||||
|
||||
def interpolate(t1: float, y1: list[float], t2: float, y2: list[float], t_target: float) -> list[float]:
|
||||
"""Perform linear interpolation between two data points."""
|
||||
y1_array, y2_array = np.array(y1), np.array(y2)
|
||||
fraction = (t_target - t1) / (t2 - t1)
|
||||
interpolated_values = y1_array + fraction * (y2_array - y1_array)
|
||||
return interpolated_values.tolist()
|
||||
|
||||
|
||||
def check_dict(data: dict[str, Serializable]) -> None:
|
||||
"""Check that only numeric types are in input data."""
|
||||
for key, value in data.items():
|
||||
if not isinstance(value, (int, float)):
|
||||
error_msg = f"Value {value} for key {key} is not of type int or float"
|
||||
raise TypeError(error_msg)
|
||||
|
||||
|
||||
class Resampler:
|
||||
"""
|
||||
Async resample data based on a fixed rate. Can handle upsampling and downsampling.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
start()
|
||||
Start the resampler task.
|
||||
|
||||
stop()
|
||||
Stop the resampler task.
|
||||
|
||||
receive()
|
||||
Get next resampled dictonary from the resampler.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ResamplerConf, subscriber: AsyncSource) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
----------
|
||||
config : namedtuple
|
||||
Configuration for the resampler.
|
||||
subscriber : AsyncMQTTSubscriber
|
||||
Asynchronous Subscriber
|
||||
|
||||
"""
|
||||
self.config = config
|
||||
self.subscriber = subscriber
|
||||
self.resample_rate = self.config.resample_rate
|
||||
self.delta_t = round(self.resample_rate / 1_000, 3)
|
||||
self.message_queue: Queue[dict[str, float]] = Queue(maxsize=50)
|
||||
self.resample_task: None | Task[None] = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start the resampler task.
|
||||
"""
|
||||
self.resample_task = create_task(self.resample())
|
||||
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
Stop the resampler task
|
||||
"""
|
||||
if self.resample_task:
|
||||
self.resample_task.cancel()
|
||||
|
||||
async def receive(self) -> dict[str, float]:
|
||||
"""
|
||||
Get next resampled dictonary from the resampler.
|
||||
|
||||
Implicitly starts the resampler if not already running.
|
||||
"""
|
||||
if not self.resample_task:
|
||||
self.start()
|
||||
return await self.message_queue.get()
|
||||
|
||||
async def resample(self) -> None:
|
||||
"""
|
||||
Resample data based on a fixed rate.
|
||||
|
||||
Can handle upsampling and downsampling.
|
||||
Data will always be centered around the output resample timestamp.
|
||||
(i.e. for data returned for t = 1.0s, the data will be resampled for [0.5, 1.5]s)
|
||||
"""
|
||||
|
||||
print("Starting resampler")
|
||||
aggregated_data = []
|
||||
aggregated_timestamps = []
|
||||
|
||||
# Get first element to determine timestamp
|
||||
topic, data = await self.subscriber.receive()
|
||||
|
||||
check_dict(data)
|
||||
|
||||
timestamp, message = self._pack_data(data) # type: ignore [arg-type]
|
||||
timestamps = timestamp_generator(timestamp, self.resample_rate)
|
||||
print(f"Got first element {topic}: {data}")
|
||||
|
||||
# Set data keys to reconstruct dict later
|
||||
self.data_keys = data.keys()
|
||||
self.topic = topic
|
||||
|
||||
# step through interpolation timestamps
|
||||
for next_timestamp in timestamps:
|
||||
# await new data and append to buffer until the most recent data
|
||||
# is newer than the next interplation timestamp
|
||||
while timestamp < next_timestamp:
|
||||
aggregated_timestamps.append(timestamp)
|
||||
aggregated_data.append(message)
|
||||
|
||||
topic, data = await self.subscriber.receive()
|
||||
timestamp, message = self._pack_data(data) # type: ignore [arg-type]
|
||||
|
||||
return_timestamp = round(next_timestamp - self.delta_t / 2, 3)
|
||||
|
||||
# Only one new data point was received
|
||||
if len(aggregated_data) == 1:
|
||||
self._is_upsampling = False
|
||||
# print("Only one data point")
|
||||
last_timestamp, last_message = (
|
||||
aggregated_timestamps[0],
|
||||
aggregated_data[0],
|
||||
)
|
||||
|
||||
# Case 2 Upsampling:
|
||||
while timestamp - next_timestamp > self.delta_t:
|
||||
self._is_upsampling = True
|
||||
# print("Upsampling")
|
||||
last_message = interpolate(
|
||||
last_timestamp,
|
||||
last_message,
|
||||
timestamp,
|
||||
message,
|
||||
return_timestamp,
|
||||
)
|
||||
last_timestamp = return_timestamp
|
||||
return_timestamp += self.delta_t
|
||||
next_timestamp = next(timestamps)
|
||||
await self.message_queue.put(self._unpack_data(last_timestamp, last_message))
|
||||
|
||||
if self._is_upsampling:
|
||||
last_message = interpolate(
|
||||
last_timestamp,
|
||||
last_message,
|
||||
timestamp,
|
||||
message,
|
||||
return_timestamp,
|
||||
)
|
||||
last_timestamp = return_timestamp
|
||||
|
||||
await self.message_queue.put(self._unpack_data(last_timestamp, last_message))
|
||||
|
||||
if len(aggregated_data) > 1:
|
||||
# Case 4 - downsampling: Multiple data points were during the resampling timeframe
|
||||
mean_message = np.mean(np.array(aggregated_data), axis=0)
|
||||
await self.message_queue.put(self._unpack_data(return_timestamp, mean_message))
|
||||
|
||||
# reset the aggregator
|
||||
aggregated_data.clear()
|
||||
aggregated_timestamps.clear()
|
||||
|
||||
def _pack_data(self, data: dict[str, float]) -> tuple[float, list[float]]:
|
||||
# pack data from dict to tuple list
|
||||
ts = data.pop("epoch")
|
||||
return (ts, list(data.values()))
|
||||
|
||||
def _unpack_data(self, ts: float, values: list[float]) -> dict[str, float]:
|
||||
# from tuple
|
||||
return {"epoch": round(ts, 3), **dict(zip(self.data_keys, values))}
|
||||
@@ -1,142 +0,0 @@
|
||||
import math
|
||||
from datetime import datetime, timedelta
|
||||
from queue import Queue
|
||||
|
||||
import numpy as np
|
||||
|
||||
from heisskleber.mqtt import MqttSubscriber
|
||||
|
||||
|
||||
def round_dt(dt, delta):
|
||||
"""Round a datetime object based on a delta timedelta."""
|
||||
return datetime.min + math.floor((dt - datetime.min) / delta) * delta
|
||||
|
||||
|
||||
def timestamp_generator(start_epoch, timedelta_in_ms):
|
||||
"""Generate increasing timestamps based on a start epoch and a delta in ms."""
|
||||
timestamp_start = datetime.fromtimestamp(start_epoch)
|
||||
delta = timedelta(milliseconds=timedelta_in_ms)
|
||||
delta_half = timedelta(milliseconds=timedelta_in_ms // 2)
|
||||
next_timestamp = round_dt(timestamp_start, delta) + delta_half
|
||||
while True:
|
||||
yield datetime.timestamp(next_timestamp)
|
||||
next_timestamp += delta
|
||||
|
||||
|
||||
def interpolate(t1, y1, t2, y2, t_target):
|
||||
"""Perform linear interpolation between two data points."""
|
||||
y1, y2 = np.array(y1), np.array(y2)
|
||||
fraction = (t_target - t1) / (t2 - t1)
|
||||
interpolated_values = y1 + fraction * (y2 - y1)
|
||||
return interpolated_values.tolist()
|
||||
|
||||
|
||||
class Resampler:
|
||||
"""
|
||||
Synchronously resample data based on a fixed rate. Can handle upsampling and downsampling.
|
||||
|
||||
Parameters:
|
||||
----------
|
||||
config : namedtuple
|
||||
Configuration for the resampler.
|
||||
subscriber : MqttSubscriber
|
||||
Synchronous Subscriber
|
||||
"""
|
||||
|
||||
def __init__(self, config, subscriber: MqttSubscriber):
|
||||
self.config = config
|
||||
self.subscriber = subscriber
|
||||
self.buffer = Queue()
|
||||
self.resample_rate = self.config.resample_rate
|
||||
self.delta_t = round(self.resample_rate / 1_000, 3)
|
||||
|
||||
def run(self):
|
||||
topic, message = self.subscriber.receive()
|
||||
self.buffer.put(self._pack_data(message))
|
||||
self.data_keys = message.keys()
|
||||
|
||||
while True:
|
||||
topic, message = self.subscriber.receive()
|
||||
self.buffer.put(self._pack_data(message))
|
||||
|
||||
def resample(self):
|
||||
aggregated_data = []
|
||||
aggregated_timestamps = []
|
||||
# Get first element to determine timestamp
|
||||
timestamp, message = self.buffer.get()
|
||||
timestamps = timestamp_generator(timestamp, self.resample_rate)
|
||||
|
||||
# step through interpolation timestamps
|
||||
for next_timestamp in timestamps:
|
||||
# last_timestamp, last_message = timestamp, message
|
||||
|
||||
# append new data to buffer until the most recent data
|
||||
# is newer than the next interplation timestamp
|
||||
while timestamp < next_timestamp:
|
||||
aggregated_timestamps.append(timestamp)
|
||||
aggregated_data.append(message)
|
||||
timestamp, message = self.buffer.get()
|
||||
|
||||
return_timestamp = round(next_timestamp - self.delta_t / 2, 3)
|
||||
|
||||
# Case 1: Only one new data point was received
|
||||
if len(aggregated_data) == 1:
|
||||
last_timestamp, last_message = (
|
||||
aggregated_timestamps[0],
|
||||
aggregated_data[0],
|
||||
)
|
||||
|
||||
# Case 1a Upsampling:
|
||||
# The data point is not within our time interval
|
||||
# We step through time intervals, yielding interpolated data points
|
||||
while timestamp - next_timestamp > self.delta_t:
|
||||
last_message = interpolate(
|
||||
last_timestamp,
|
||||
last_message,
|
||||
timestamp,
|
||||
message,
|
||||
return_timestamp,
|
||||
)
|
||||
last_timestamp = return_timestamp
|
||||
return_timestamp += self.delta_t
|
||||
next_timestamp = next(timestamps)
|
||||
yield self._unpack_data(last_timestamp, last_message)
|
||||
|
||||
# Case 1b: The data point is within our time interval
|
||||
# We simply yield the data point
|
||||
# Note, this will also be the case once we have advanced the time interval by upsampling
|
||||
last_message = interpolate(
|
||||
last_timestamp,
|
||||
last_message,
|
||||
timestamp,
|
||||
message,
|
||||
return_timestamp,
|
||||
)
|
||||
last_timestamp = return_timestamp
|
||||
return_timestamp += self.delta_t
|
||||
yield self._unpack_data(last_timestamp, last_message)
|
||||
|
||||
# Case 2 - downsampling: Multiple data points were during the resampling timeframe
|
||||
# We simply yield the mean of the data points, which is more robust and performant than interpolation
|
||||
if len(aggregated_data) > 1:
|
||||
# yield self._handle_downsampling(return_timestamp, aggregated_data)
|
||||
mean_message = np.mean(np.array(aggregated_data), axis=0)
|
||||
yield self._unpack_data(return_timestamp, mean_message)
|
||||
|
||||
# reset the aggregator
|
||||
aggregated_data.clear()
|
||||
aggregated_timestamps.clear()
|
||||
|
||||
def _handle_downsampling(self, return_timestamp, aggregated_data) -> dict:
|
||||
"""Handle the downsampling case."""
|
||||
mean_message = np.mean(np.array(aggregated_data), axis=0)
|
||||
return self._unpack_data(return_timestamp, mean_message)
|
||||
|
||||
def _pack_data(self, data) -> tuple[int, list]:
|
||||
# pack data from dict to tuple list
|
||||
ts = data.pop("epoch")
|
||||
return (ts, list(data.values()))
|
||||
|
||||
def _unpack_data(self, ts, values) -> dict:
|
||||
# from tuple
|
||||
return {"epoch": round(ts, 3), **dict(zip(self.data_keys, values))}
|
||||
@@ -1,5 +0,0 @@
|
||||
from heisskleber.tcp.config import TcpConf
|
||||
from heisskleber.tcp.sink import AsyncTcpSink
|
||||
from heisskleber.tcp.source import AsyncTcpSource
|
||||
|
||||
__all__ = ["AsyncTcpSource", "AsyncTcpSink", "TcpConf"]
|
||||
@@ -1,10 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from heisskleber.config import BaseConf
|
||||
|
||||
|
||||
@dataclass
|
||||
class TcpConf(BaseConf):
|
||||
host: str = "localhost"
|
||||
port: int = 6000
|
||||
timeout: int = 60
|
||||
@@ -1,60 +0,0 @@
|
||||
import asyncio
|
||||
from typing import Callable
|
||||
|
||||
from heisskleber.core.types import AsyncSource, Serializable
|
||||
from heisskleber.tcp.config import TcpConf
|
||||
|
||||
|
||||
def bytes_csv_unpacker(data: bytes) -> tuple[str, dict[str, str]]:
|
||||
vals = data.decode().rstrip().split(",")
|
||||
keys = [f"key{i}" for i in range(len(vals))]
|
||||
return ("tcp", dict(zip(keys, vals)))
|
||||
|
||||
|
||||
class AsyncTcpSource(AsyncSource):
|
||||
"""
|
||||
Async TCP connection, connects to host:port and reads byte encoded strings.
|
||||
|
||||
|
||||
Pass an unpack function like so:
|
||||
|
||||
Example
|
||||
-------
|
||||
def unpack(data: bytes) -> tuple[str, dict[str, float | int | str]]:
|
||||
return dict(zip(["key1", "key2"], data.decode().split(","))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: TcpConf, unpack: Callable[[bytes], tuple[str, dict[str, Serializable]]] | None) -> None:
|
||||
self.config = config
|
||||
self.is_connected = asyncio.Event()
|
||||
self.unpack = unpack or bytes_csv_unpacker
|
||||
self.timeout = config.timeout
|
||||
self.start_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
await self._check_connection()
|
||||
data = await self.reader.readline()
|
||||
topic, payload = self.unpack(data)
|
||||
return (topic, payload) # type: ignore
|
||||
|
||||
def start(self) -> None:
|
||||
self.start_task = asyncio.create_task(self._connect())
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.is_connected:
|
||||
print("stopping")
|
||||
|
||||
async def _check_connection(self) -> None:
|
||||
if not self.start_task:
|
||||
self.start()
|
||||
await self.is_connected.wait()
|
||||
|
||||
async def _connect(self) -> None:
|
||||
print(f"{self} waiting for connection.")
|
||||
(self.reader, self.writer) = await asyncio.open_connection(self.config.host, self.config.port)
|
||||
print(f"{self} connected successfully!")
|
||||
self.is_connected.set()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
@@ -1,5 +0,0 @@
|
||||
from .config import UdpConf
|
||||
from .publisher import AsyncUdpSink, UdpPublisher
|
||||
from .subscriber import AsyncUdpSource, UdpSubscriber
|
||||
|
||||
__all__ = ["AsyncUdpSource", "UdpSubscriber", "AsyncUdpSink", "UdpPublisher", "UdpConf"]
|
||||
@@ -1,84 +0,0 @@
|
||||
import asyncio
|
||||
import socket
|
||||
import sys
|
||||
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import AsyncSink, Serializable, Sink
|
||||
from heisskleber.udp.config import UdpConf
|
||||
|
||||
|
||||
class UdpPublisher(Sink):
|
||||
def __init__(self, config: UdpConf) -> None:
|
||||
self.config = config
|
||||
self.pack = get_packer(self.config.packer)
|
||||
self.is_connected = False
|
||||
|
||||
def start(self) -> None:
|
||||
try:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
except OSError as e:
|
||||
print(f"failed to create socket: {e}")
|
||||
sys.exit(-1)
|
||||
else:
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self) -> None:
|
||||
self.socket.close()
|
||||
self.is_connected = True
|
||||
|
||||
def send(self, data: dict[str, Serializable], topic: str | None = None) -> None:
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
if topic:
|
||||
data["topic"] = topic
|
||||
payload = self.pack(data).encode("utf-8")
|
||||
self.socket.sendto(payload, (self.config.host, self.config.port))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
|
||||
|
||||
class UdpProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(self, is_connected: bool) -> None:
|
||||
super().__init__()
|
||||
self.is_connected = is_connected
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
print("Connection lost")
|
||||
self.is_connected = False
|
||||
|
||||
|
||||
class AsyncUdpSink(AsyncSink):
|
||||
def __init__(self, config: UdpConf) -> None:
|
||||
self.config = config
|
||||
self.pack = get_packer(self.config.packer)
|
||||
self.socket: asyncio.DatagramTransport | None = None
|
||||
self.is_connected = False
|
||||
|
||||
def start(self) -> None:
|
||||
# No background loop required
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.socket is not None:
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
async def _ensure_connection(self) -> None:
|
||||
if not self.is_connected:
|
||||
loop = asyncio.get_running_loop()
|
||||
self.socket, _ = await loop.create_datagram_endpoint(
|
||||
lambda: UdpProtocol(self.is_connected),
|
||||
remote_addr=(self.config.host, self.config.port),
|
||||
)
|
||||
self.is_connected = True
|
||||
|
||||
async def send(self, data: dict[str, Serializable], topic: str | None = None) -> None:
|
||||
await self._ensure_connection()
|
||||
if topic:
|
||||
data["topic"] = topic
|
||||
payload = self.pack(data).encode(self.config.encoding)
|
||||
self.socket.sendto(payload) # type: ignore
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
@@ -1,117 +0,0 @@
|
||||
import asyncio
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
from heisskleber.core.packer import get_unpacker
|
||||
from heisskleber.core.types import AsyncSource, Serializable, Source
|
||||
from heisskleber.udp.config import UdpConf
|
||||
|
||||
|
||||
class UdpSubscriber(Source):
|
||||
def __init__(self, config: UdpConf, topic: str | None = None):
|
||||
self.config = config
|
||||
self.topic = topic
|
||||
self.unpacker = get_unpacker(self.config.packer)
|
||||
self._queue: Queue[tuple[str, dict[str, Serializable]]] = Queue(maxsize=self.config.max_queue_size)
|
||||
self._running = threading.Event()
|
||||
|
||||
def start(self) -> None:
|
||||
try:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
except OSError as e:
|
||||
print(f"failed to create socket: {e}")
|
||||
sys.exit(-1)
|
||||
self.socket.bind((self.config.host, self.config.port))
|
||||
self._running.set()
|
||||
self._thread = threading.Thread(target=self._loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._running.clear()
|
||||
# if self._thread is not None:
|
||||
# self._thread.join()
|
||||
self.socket.close()
|
||||
|
||||
def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
if not self._running.is_set():
|
||||
self.start()
|
||||
return self._queue.get()
|
||||
|
||||
def _loop(self) -> None:
|
||||
while self._running.is_set():
|
||||
try:
|
||||
payload, _ = self.socket.recvfrom(1024)
|
||||
data = self.unpacker(payload.decode("utf-8"))
|
||||
topic: str = str(data.pop("topic")) if "topic" in data else ""
|
||||
self._queue.put((topic, data))
|
||||
except Exception as e:
|
||||
error_message = f"Error in UDP listener loop: {e}"
|
||||
print(error_message)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
|
||||
|
||||
class UdpProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(self, queue: asyncio.Queue[bytes]) -> None:
|
||||
super().__init__()
|
||||
self.queue = queue
|
||||
|
||||
def datagram_received(self, data: bytes, addr: tuple[str | Any, int]) -> None:
|
||||
self.queue.put_nowait(data)
|
||||
|
||||
def connection_made(self, transport: asyncio.DatagramTransport) -> None:
|
||||
print("Connection made")
|
||||
|
||||
|
||||
class AsyncUdpSource(AsyncSource):
|
||||
"""
|
||||
An asynchronous UDP subscriber based on asyncio.protocols.DatagramProtocol
|
||||
"""
|
||||
|
||||
def __init__(self, config: UdpConf, topic: str = "udp"):
|
||||
self.config = config
|
||||
self.topic = topic
|
||||
self.EOF = self.config.delimiter.encode(self.config.encoding)
|
||||
self.unpacker = get_unpacker(self.config.packer)
|
||||
self.queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.config.max_queue_size)
|
||||
self.task: asyncio.Task[None] | None = None
|
||||
self.is_connected = False
|
||||
|
||||
async def setup(self) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
self.transport, self.protocol = await loop.create_datagram_endpoint(
|
||||
lambda: UdpProtocol(self.queue),
|
||||
local_addr=(self.config.host, self.config.port),
|
||||
)
|
||||
self.is_connected = True
|
||||
print("Udp connection established.")
|
||||
|
||||
def start(self) -> None:
|
||||
# Background loop not required, handled by Protocol
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
self.transport.close()
|
||||
|
||||
async def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
if not self.is_connected:
|
||||
await self.setup()
|
||||
data = await self.queue.get()
|
||||
try:
|
||||
payload = self.unpacker(data.decode(self.config.encoding, errors="ignore"))
|
||||
# except UnicodeDecodeError: # this won't be thrown anymore, as the error flag is set to ignore!
|
||||
# print(f"Could not decode data, is not {self.config.encoding}")
|
||||
except Exception:
|
||||
if self.config.verbose:
|
||||
print(f"Could not deserialize data: {data!r}")
|
||||
else:
|
||||
return (self.topic, payload)
|
||||
|
||||
return await self.receive() # Try again
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
@@ -1,5 +0,0 @@
|
||||
from .config import ZmqConf
|
||||
from .publisher import ZmqAsyncPublisher, ZmqPublisher
|
||||
from .subscriber import ZmqAsyncSubscriber, ZmqSubscriber
|
||||
|
||||
__all__ = ["ZmqConf", "ZmqPublisher", "ZmqSubscriber", "ZmqAsyncPublisher", "ZmqAsyncSubscriber"]
|
||||
@@ -1,129 +0,0 @@
|
||||
import sys
|
||||
from typing import Callable
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import AsyncSink, Serializable, Sink
|
||||
|
||||
from .config import ZmqConf
|
||||
|
||||
|
||||
class ZmqPublisher(Sink):
|
||||
"""
|
||||
Publisher that sends messages to a ZMQ PUB socket.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
pack : Callable
|
||||
The packer function to use for serializing the data.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
send(data : dict, topic : str):
|
||||
Send the data with the given topic.
|
||||
|
||||
start():
|
||||
Connect to the socket.
|
||||
|
||||
stop():
|
||||
Close the socket.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZmqConf):
|
||||
self.config = config
|
||||
self.context = zmq.Context.instance()
|
||||
self.socket = self.context.socket(zmq.PUB)
|
||||
self.pack = get_packer(config.packstyle)
|
||||
self.is_connected = False
|
||||
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Take the data as a dict, serialize it with the given packer and send it to the zmq socket.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
payload = self.pack(data)
|
||||
if self.config.verbose:
|
||||
print(f"sending message {payload} to topic {topic}")
|
||||
self.socket.send_multipart([topic.encode(), payload.encode()])
|
||||
|
||||
def start(self) -> None:
|
||||
"""Connect to the zmq socket."""
|
||||
try:
|
||||
if self.config.verbose:
|
||||
print(f"connecting to {self.config.publisher_address}")
|
||||
self.socket.connect(self.config.publisher_address)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
else:
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self) -> None:
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})"
|
||||
|
||||
|
||||
class ZmqAsyncPublisher(AsyncSink):
|
||||
"""
|
||||
Async publisher that sends messages to a ZMQ PUB socket.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
pack : Callable
|
||||
The packer function to use for serializing the data.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
send(data : dict, topic : str):
|
||||
Send the data with the given topic.
|
||||
|
||||
start():
|
||||
Connect to the socket.
|
||||
|
||||
stop():
|
||||
Close the socket.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZmqConf):
|
||||
self.config = config
|
||||
self.context = zmq.asyncio.Context.instance()
|
||||
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB)
|
||||
self.pack: Callable = get_packer(config.packstyle)
|
||||
self.is_connected = False
|
||||
|
||||
async def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
"""
|
||||
Take the data as a dict, serialize it with the given packer and send it to the zmq socket.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
payload = self.pack(data)
|
||||
if self.config.verbose:
|
||||
print(f"sending message {payload} to topic {topic}")
|
||||
await self.socket.send_multipart([topic.encode(), payload.encode()])
|
||||
|
||||
def start(self) -> None:
|
||||
"""Connect to the zmq socket."""
|
||||
try:
|
||||
if self.config.verbose:
|
||||
print(f"connecting to {self.config.publisher_address}")
|
||||
self.socket.connect(self.config.publisher_address)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
else:
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Close the zmq socket."""
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})"
|
||||
@@ -1,181 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from heisskleber.core.packer import get_unpacker
|
||||
from heisskleber.core.types import AsyncSource, Source
|
||||
|
||||
from .config import ZmqConf
|
||||
|
||||
|
||||
class ZmqSubscriber(Source):
|
||||
"""
|
||||
Source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
unpack : Callable
|
||||
The unpacker function to use for deserializing the data.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
receive() -> tuple[str, dict]:
|
||||
Send the data with the given topic.
|
||||
|
||||
start():
|
||||
Connect to the socket.
|
||||
|
||||
stop():
|
||||
Close the socket.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZmqConf, topic: str | list[str]):
|
||||
"""
|
||||
Constructs new ZmqAsyncSubscriber instance.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
config : ZmqConf
|
||||
The configuration dataclass object for the zmq connection.
|
||||
topic : str
|
||||
The topic or list of topics to subscribe to.
|
||||
"""
|
||||
self.config = config
|
||||
self.topic = topic
|
||||
self.context = zmq.Context.instance()
|
||||
self.socket = self.context.socket(zmq.SUB)
|
||||
self.unpack = get_unpacker(config.packstyle)
|
||||
self.is_connected = False
|
||||
|
||||
def receive(self) -> tuple[str, dict]:
|
||||
"""
|
||||
reads a message from the zmq bus and returns it
|
||||
|
||||
Returns:
|
||||
tuple(topic: str, message: dict): the message received
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
(topic, payload) = self.socket.recv_multipart()
|
||||
message = self.unpack(payload.decode())
|
||||
topic = topic.decode()
|
||||
return (topic, message)
|
||||
|
||||
def start(self):
|
||||
try:
|
||||
self.socket.connect(self.config.subscriber_address)
|
||||
self.subscribe(self.topic)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
else:
|
||||
self.is_connected = True
|
||||
|
||||
def stop(self):
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
def subscribe(self, topic: str | list[str] | tuple[str]):
|
||||
# Accepts single topic or list of topics
|
||||
if isinstance(topic, (list, tuple)):
|
||||
for t in topic:
|
||||
self._subscribe_single_topic(t)
|
||||
else:
|
||||
self._subscribe_single_topic(topic)
|
||||
|
||||
def _subscribe_single_topic(self, topic: str):
|
||||
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})"
|
||||
|
||||
|
||||
class ZmqAsyncSubscriber(AsyncSource):
|
||||
"""
|
||||
Async source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
unpack : Callable
|
||||
The unpacker function to use for deserializing the data.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
receive() -> tuple[str, dict]:
|
||||
Send the data with the given topic.
|
||||
|
||||
start():
|
||||
Connect to the socket.
|
||||
|
||||
stop():
|
||||
Close the socket.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZmqConf, topic: str | list[str]):
|
||||
"""
|
||||
Constructs new ZmqAsyncSubscriber instance.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
config : ZmqConf
|
||||
The configuration dataclass object for the zmq connection.
|
||||
topic : str
|
||||
The topic or list of topics to subscribe to.
|
||||
"""
|
||||
self.config = config
|
||||
self.topic = topic
|
||||
self.context = zmq.asyncio.Context.instance()
|
||||
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB)
|
||||
self.unpack = get_unpacker(config.packstyle)
|
||||
self.is_connected = False
|
||||
|
||||
async def receive(self) -> tuple[str, dict]:
|
||||
"""
|
||||
reads a message from the zmq bus and returns it
|
||||
|
||||
Returns:
|
||||
tuple(topic: str, message: dict): the message received
|
||||
"""
|
||||
if not self.is_connected:
|
||||
self.start()
|
||||
(topic, payload) = await self.socket.recv_multipart()
|
||||
message = self.unpack(payload.decode())
|
||||
topic = topic.decode()
|
||||
return (topic, message)
|
||||
|
||||
def start(self):
|
||||
"""Connect to the zmq socket."""
|
||||
try:
|
||||
self.socket.connect(self.config.subscriber_address)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
else:
|
||||
self.is_connected = True
|
||||
self.subscribe(self.topic)
|
||||
|
||||
def stop(self):
|
||||
"""Close the zmq socket."""
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
def subscribe(self, topic: str | list[str] | tuple[str]):
|
||||
"""
|
||||
Subscribes to the given topic(s) on the zmq socket.
|
||||
|
||||
Accepts single topic or list of topics.
|
||||
"""
|
||||
if isinstance(topic, (list, tuple)):
|
||||
for t in topic:
|
||||
self._subscribe_single_topic(t)
|
||||
else:
|
||||
self._subscribe_single_topic(topic)
|
||||
|
||||
def _subscribe_single_topic(self, topic: str):
|
||||
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})"
|
||||
52
noxfile.py
Normal file
52
noxfile.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import nox
|
||||
|
||||
DIR = Path(__file__).parent.resolve()
|
||||
|
||||
nox.needs_version = ">=2024.3.2"
|
||||
nox.options.sessions = ["lint", "tests", "check"]
|
||||
nox.options.default_venv_backend = "uv|virtualenv"
|
||||
|
||||
|
||||
@nox.session
|
||||
def lint(session: nox.Session) -> None:
|
||||
"""Run the linter."""
|
||||
session.install("pre-commit")
|
||||
session.run("pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs)
|
||||
|
||||
|
||||
@nox.session
|
||||
def tests(session: nox.Session) -> None:
|
||||
"""Run the unit and regular tests."""
|
||||
session.install(".[test]")
|
||||
session.run("pytest", *session.posargs)
|
||||
|
||||
|
||||
@nox.session(reuse_venv=True)
|
||||
def docs(session: nox.Session) -> None:
|
||||
"""Build the docs. Pass --non-interactive to avoid serving. First positional argument is the target directory."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-b", dest="builder", default="html", help="Build target (default: html)")
|
||||
parser.add_argument("output", nargs="?", help="Output directory")
|
||||
args, posargs = parser.parse_known_args(session.posargs)
|
||||
serve = args.builder == "html" and session.interactive
|
||||
|
||||
session.install("-e.[docs]", "sphinx-autobuild")
|
||||
|
||||
shared_args = (
|
||||
"-n", # nitpicky mode
|
||||
"-T", # full tracebacks
|
||||
f"-b={args.builder}",
|
||||
"docs",
|
||||
args.output or f"docs/_build/{args.builder}",
|
||||
*posargs,
|
||||
)
|
||||
|
||||
if serve:
|
||||
session.run("sphinx-autobuild", "--open-browser", *shared_args)
|
||||
else:
|
||||
session.run("sphinx-build", "--keep-going", *shared_args)
|
||||
2083
poetry.lock
generated
2083
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
166
pyproject.toml
166
pyproject.toml
@@ -1,57 +1,66 @@
|
||||
[tool.poetry]
|
||||
[build-system]
|
||||
requires = ["hatchling", "hatch-vcs"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "heisskleber"
|
||||
version = "0.5.7"
|
||||
description = "Heisskleber"
|
||||
authors = ["Felix Weiler <felix@flucto.tech>"]
|
||||
license = "MIT"
|
||||
authors = [
|
||||
{ name = "Felix Weiler-Detjen", email = "felix@flucto.tech" },
|
||||
]
|
||||
license = {file = "LICENSE"}
|
||||
readme = "README.md"
|
||||
homepage = "https://github.com/flucto-gmbh/heisskleber"
|
||||
repository = "https://github.com/flucto-gmbh/heisskleber"
|
||||
documentation = "https://heisskleber.readthedocs.io"
|
||||
[tool.poetry.urls]
|
||||
Changelog = "https://github.com/flucto-gmbh/heisskleber/releases"
|
||||
requires-python = ">=3.10"
|
||||
dynamic = ["version"]
|
||||
dependencies= [
|
||||
"aiomqtt>=2.3.0",
|
||||
"pyserial>=3.5",
|
||||
"pyyaml>=6.0.2",
|
||||
"pyzmq>=26.2.0",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.9"
|
||||
paho-mqtt = "^1.6.1"
|
||||
pyserial = "^3.5"
|
||||
pyyaml = "^6.0.1"
|
||||
pyzmq = "^25.1.1"
|
||||
aiomqtt = "^1.2.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = ">=21.10b0"
|
||||
coverage = { extras = ["toml"], version = ">=6.2" }
|
||||
furo = ">=2021.11.12"
|
||||
mypy = ">=0.930"
|
||||
pre-commit = ">=2.16.0"
|
||||
pre-commit-hooks = ">=4.1.0"
|
||||
pytest = ">=6.2.5"
|
||||
pytest-cov = "^4.1.0"
|
||||
pytest-mock = "^3.11.1"
|
||||
ruff = "^0.0.292"
|
||||
pyupgrade = ">=2.29.1"
|
||||
safety = ">=1.10.3"
|
||||
sphinx = ">=4.3.2"
|
||||
sphinx-autobuild = ">=2021.3.14"
|
||||
sphinx-autodoc-typehints = "^1.24.0"
|
||||
sphinx-rtd-theme = "^1.3.0"
|
||||
typeguard = ">=2.13.3"
|
||||
xdoctest = { extras = ["colors"], version = ">=0.15.10" }
|
||||
myst-parser = { version = ">=0.16.1" }
|
||||
pytest-asyncio = "^0.21.1"
|
||||
termcolor = "^2.4.0"
|
||||
codecov = "^2.1.13"
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/flucto-gmbh/heisskleber"
|
||||
Repository = "https://github.com/flucto-gmbh/heisskleber"
|
||||
Documentation = "https://heisskleber.readthedocs.io"
|
||||
|
||||
|
||||
[tool.poetry.group.types.dependencies]
|
||||
pandas-stubs = "^2.1.1.230928"
|
||||
types-pyyaml = "^6.0.12.12"
|
||||
types-paho-mqtt = "^1.6.0.7"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
hkcli = "heisskleber.run.cli:main"
|
||||
zmqbroker = "heisskleber.run.zmqbroker:main"
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest>=8.3.3",
|
||||
"pytest-cov>=5.0.0",
|
||||
"coverage[toml]>=7.6.1",
|
||||
"xdoctest>=1.2.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
]
|
||||
docs = [
|
||||
"furo>=2024.8.6",
|
||||
"myst-parser>=4.0.0",
|
||||
"sphinx>=8.0.2",
|
||||
"sphinx-autobuild>=2024.9.19",
|
||||
"sphinx-rtd-theme>=0.5.1",
|
||||
"sphinx_copybutton",
|
||||
"sphinx_autodoc_typehints",
|
||||
]
|
||||
filter = [
|
||||
"numpy>=2.1.1",
|
||||
"scipy>=1.14.1",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"deptry>=0.20.0",
|
||||
"mypy>=1.11.2",
|
||||
"ruff>=0.6.8",
|
||||
"xdoctest>=1.2.0",
|
||||
"nox>=2024.4.15",
|
||||
"pytest>=8.3.3",
|
||||
"pytest-cov>=5.0.0",
|
||||
"coverage[toml]>=7.6.1",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
]
|
||||
package = true
|
||||
|
||||
[tool.coverage.paths]
|
||||
source = ["heisskleber", "*/site-packages"]
|
||||
@@ -59,7 +68,7 @@ tests = ["tests", "*/tests"]
|
||||
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
source = ["heisskleber"]
|
||||
source = ["src/heisskleber"]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[tool.coverage.report]
|
||||
@@ -75,39 +84,48 @@ show_error_context = true
|
||||
exclude = ["tests/*", "^test_*\\.py"]
|
||||
|
||||
[tool.ruff]
|
||||
ignore-init-module-imports = true
|
||||
target-version = "py39"
|
||||
line-length = 120
|
||||
fix = true
|
||||
select = [
|
||||
"YTT", # flake8-2020
|
||||
"S", # flake8-bandit
|
||||
"B", # flake8-bugbear
|
||||
"A", # flake8-builtins
|
||||
"C4", # flake8-comprehensions
|
||||
"T10", # flake8-debugger
|
||||
"SIM", # flake8-simplify
|
||||
"I", # isort
|
||||
"C90", # mccabe
|
||||
"E",
|
||||
"W", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"PGH", # pygrep-hooks
|
||||
"UP", # pyupgrade
|
||||
"RUF", # ruff
|
||||
"TRY", # tryceratops
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["ALL"]
|
||||
ignore = [
|
||||
"E501", # LineTooLong
|
||||
"E731", # DoNotAssignLambda
|
||||
"A001", #
|
||||
"PGH003", # Use specific rules when ignoring type issues
|
||||
"D100", # Missing module docstring
|
||||
"D104", # Missing package docstring
|
||||
"D107", # Missing __init__ docstring
|
||||
"ANN101", # Deprecated and stupid self annotation
|
||||
"ANN401", # Dynamically typed annotation
|
||||
"FA102", # Missing from __future__ import annotations
|
||||
"FBT001", # boolean style argument in function definition
|
||||
"FBT002", # boolean style argument in function definition
|
||||
"ARG002", # Unused kwargs
|
||||
"TD002",
|
||||
"TD003",
|
||||
"FIX002",
|
||||
"COM812",
|
||||
"ISC001",
|
||||
"ARG001",
|
||||
"INP001"
|
||||
]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"tests/*" = ["S101", "D"]
|
||||
"tests/test_import.py" = ["F401"]
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = ["S101", "D", "T201", "PLR2", "SLF001", "ANN"]
|
||||
"bin/*" = [
|
||||
"ERA001", # Found commented-out code
|
||||
]
|
||||
|
||||
|
||||
[tool.hatch]
|
||||
version.source = "vcs"
|
||||
version.path = "src/heisskleber/__init__.py"
|
||||
|
||||
[tool.hatch.envs.default]
|
||||
features = ["test"]
|
||||
scripts.test = "pytest {args}"
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
from dataclasses import asdict
|
||||
|
||||
import yaml
|
||||
|
||||
from heisskleber.mqtt.config import MqttConf
|
||||
from heisskleber.serial.config import SerialConf
|
||||
from heisskleber.tcp.config import TcpConf
|
||||
from heisskleber.udp.config import UdpConf
|
||||
from heisskleber.zmq.config import ZmqConf
|
||||
|
||||
configs = {"mqtt": MqttConf(), "zmq": ZmqConf(), "udp": UdpConf(), "tcp": TcpConf(), "serial": SerialConf()}
|
||||
|
||||
|
||||
for name, config in configs.items():
|
||||
with open(f"./config/heisskleber/{name}.yaml", "w") as file:
|
||||
file.write(f"# Heisskleber config file for {config.__class__.__name__}\n")
|
||||
file.write(yaml.dump(asdict(config)))
|
||||
@@ -1,14 +0,0 @@
|
||||
from heisskleber.udp import UdpConf, UdpSubscriber
|
||||
|
||||
|
||||
def main() -> None:
|
||||
conf = UdpConf(host="192.168.137.1", port=6600)
|
||||
subscriber = UdpSubscriber(conf)
|
||||
|
||||
while True:
|
||||
topic, data = subscriber.receive()
|
||||
print(topic, data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
38
src/heisskleber/__init__.py
Normal file
38
src/heisskleber/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Heisskleber."""
|
||||
|
||||
from heisskleber.console import ConsoleReceiver, ConsoleSender
|
||||
from heisskleber.core import Receiver, Sender
|
||||
from heisskleber.mqtt import MqttConf, MqttReceiver, MqttSender
|
||||
from heisskleber.serial import SerialConf, SerialReceiver, SerialSender
|
||||
from heisskleber.tcp import TcpConf, TcpReceiver, TcpSender
|
||||
from heisskleber.udp import UdpConf, UdpReceiver, UdpSender
|
||||
from heisskleber.zmq import ZmqConf, ZmqReceiver, ZmqSender
|
||||
|
||||
__all__ = [
|
||||
"Sender",
|
||||
"Receiver",
|
||||
# mqtt
|
||||
"MqttConf",
|
||||
"MqttSender",
|
||||
"MqttReceiver",
|
||||
# zmq
|
||||
"ZmqConf",
|
||||
"ZmqSender",
|
||||
"ZmqReceiver",
|
||||
# udp
|
||||
"UdpConf",
|
||||
"UdpSender",
|
||||
"UdpReceiver",
|
||||
# tcp
|
||||
"TcpConf",
|
||||
"TcpSender",
|
||||
"TcpReceiver",
|
||||
# serial
|
||||
"SerialConf",
|
||||
"SerialSender",
|
||||
"SerialReceiver",
|
||||
# console
|
||||
"ConsoleSender",
|
||||
"ConsoleReceiver",
|
||||
]
|
||||
__version__ = "1.0.0"
|
||||
4
src/heisskleber/console/__init__.py
Normal file
4
src/heisskleber/console/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from heisskleber.console.receiver import ConsoleReceiver
|
||||
from heisskleber.console.sender import ConsoleSender
|
||||
|
||||
__all__ = ["ConsoleReceiver", "ConsoleSender"]
|
||||
46
src/heisskleber/console/receiver.py
Normal file
46
src/heisskleber/console/receiver.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from heisskleber.core import Receiver, Unpacker, json_unpacker
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ConsoleReceiver(Receiver[T]):
|
||||
"""Read stdin from console and create data of type T."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unpacker: Unpacker[T] = json_unpacker, # type: ignore[assignment]
|
||||
) -> None:
|
||||
self.queue: asyncio.Queue[tuple[T, dict[str, Any]]] = asyncio.Queue(maxsize=10)
|
||||
self.unpack = unpacker
|
||||
self.task: asyncio.Task[None] | None = None
|
||||
|
||||
async def _listener_task(self) -> None:
|
||||
while True:
|
||||
payload = sys.stdin.readline().encode() # I know this is stupid, but I adhere to the interface for now
|
||||
data, extra = self.unpack(payload)
|
||||
await self.queue.put((data, extra))
|
||||
|
||||
async def receive(self) -> tuple[T, dict[str, Any]]:
|
||||
"""Receive the next message from the console input."""
|
||||
if not self.task:
|
||||
self.task = asyncio.create_task(self._listener_task())
|
||||
|
||||
data, extra = await self.queue.get()
|
||||
return data, extra
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of ConsoleSource."""
|
||||
return f"{self.__class__.__name__}"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start ConsoleSource."""
|
||||
self.task = asyncio.create_task(self._listener_task())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop ConsoleSource."""
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
35
src/heisskleber/console/sender.py
Normal file
35
src/heisskleber/console/sender.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from heisskleber.core import Packer, Sender, json_packer
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ConsoleSender(Sender[T]):
|
||||
"""Send data to console out."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pretty: bool = False,
|
||||
verbose: bool = False,
|
||||
packer: Packer[T] = json_packer, # type: ignore[assignment]
|
||||
) -> None:
|
||||
self.verbose = verbose
|
||||
self.pretty = pretty
|
||||
self.packer = packer
|
||||
|
||||
async def send(self, data: T, topic: str | None = None, **kwargs: dict[str, Any]) -> None:
|
||||
"""Serialize data and write to console output."""
|
||||
serialized = self.packer(data)
|
||||
output = f"{topic}:\t{serialized.decode()}" if topic else serialized.decode()
|
||||
print(output) # noqa: T201
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string reprensentation of ConsoleSink."""
|
||||
return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Not implemented."""
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Not implemented."""
|
||||
23
src/heisskleber/core/__init__.py
Normal file
23
src/heisskleber/core/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Core classes of the heisskleber library."""
|
||||
|
||||
from .config import BaseConf, ConfigType
|
||||
from .packer import JSONPacker, Packer, PackerError
|
||||
from .receiver import Receiver
|
||||
from .sender import Sender
|
||||
from .unpacker import JSONUnpacker, Unpacker, UnpackError
|
||||
|
||||
json_packer = JSONPacker()
|
||||
json_unpacker = JSONUnpacker()
|
||||
|
||||
__all__ = [
|
||||
"Packer",
|
||||
"Unpacker",
|
||||
"Sender",
|
||||
"Receiver",
|
||||
"json_packer",
|
||||
"json_unpacker",
|
||||
"BaseConf",
|
||||
"ConfigType",
|
||||
"PackerError",
|
||||
"UnpackError",
|
||||
]
|
||||
127
src/heisskleber/core/config.py
Normal file
127
src/heisskleber/core/config.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Configuration baseclass."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, fields
|
||||
from pathlib import Path
|
||||
from typing import Any, TextIO, TypeVar, Union
|
||||
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
logger = logging.getLogger("heisskleber")
|
||||
|
||||
ConfigType = TypeVar(
|
||||
"ConfigType",
|
||||
bound="BaseConf",
|
||||
) # https://stackoverflow.com/a/46227137 , https://docs.python.org/3/library/typing.html#typing.TypeVar
|
||||
|
||||
|
||||
def _parse_yaml(file: TextIO) -> dict[str, Any]:
|
||||
try:
|
||||
return dict(yaml.safe_load(file))
|
||||
except yaml.YAMLError as e:
|
||||
msg = "Failed to parse config file!"
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
|
||||
def _parse_json(file: TextIO) -> dict[str, Any]:
|
||||
import json
|
||||
|
||||
try:
|
||||
return dict(json.load(file))
|
||||
except json.JSONDecodeError as e:
|
||||
msg = "Failed to parse config file!"
|
||||
logger.exception(msg)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
|
||||
def _parser(path: Path) -> dict[str, Any]:
|
||||
suffix = path.suffix.lower()
|
||||
|
||||
with path.open() as f:
|
||||
if suffix in [".yaml", ".yml"]:
|
||||
return _parse_yaml(f)
|
||||
if suffix == ".json":
|
||||
return _parse_json(f)
|
||||
msg = f"Unsupported file format {suffix}."
|
||||
logger.exception(msg)
|
||||
raise ValueError
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseConf:
|
||||
"""Default configuration class for generic configuration info."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Check if all attributes are the same type as the original defition of the dataclass."""
|
||||
for field in fields(self):
|
||||
value = getattr(self, field.name)
|
||||
if value is None: # Allow optional fields
|
||||
continue
|
||||
if not isinstance(value, field.type): # Failed field comparison
|
||||
raise TypeError
|
||||
if ( # Failed Union comparison
|
||||
hasattr(field.type, "__origin__")
|
||||
and field.type.__origin__ is Union
|
||||
and not any(isinstance(value, t) for t in field.type.__args__)
|
||||
):
|
||||
raise TypeError
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: type[ConfigType], config_dict: dict[str, Any]) -> ConfigType:
|
||||
"""Create a config instance from a dictionary, including only fields defined in the dataclass.
|
||||
|
||||
Arguments:
|
||||
config_dict: Dictionary containing configuration values.
|
||||
Keys should match dataclass field names.
|
||||
|
||||
Returns:
|
||||
An instance of the configuration class with values from the dictionary.
|
||||
|
||||
Raises:
|
||||
TypeError: If provided values don't match field types.
|
||||
|
||||
Example:
|
||||
>>> from dataclasses import dataclass
|
||||
>>> @dataclass
|
||||
... class ServerConfig(BaseConf):
|
||||
... host: str
|
||||
... port: int
|
||||
... debug: bool = False
|
||||
>>>
|
||||
>>> config = ServerConfig.from_dict({
|
||||
... "host": "localhost",
|
||||
... "port": 8080,
|
||||
... "debug": True,
|
||||
... "invalid_key": "ignored" # Will be filtered out
|
||||
... })
|
||||
>>> config.host
|
||||
'localhost'
|
||||
>>> config.port
|
||||
8080
|
||||
>>> config.debug
|
||||
True
|
||||
>>> hasattr(config, "invalid_key") # Extra keys are ignored
|
||||
False
|
||||
>>>
|
||||
>>> # Type validation
|
||||
>>> try:
|
||||
... ServerConfig.from_dict({"host": "localhost", "port": "8080"}) # Wrong type
|
||||
... except TypeError as e:
|
||||
... print("TypeError raised as expected")
|
||||
TypeError raised as expected
|
||||
|
||||
"""
|
||||
valid_fields = {f.name for f in fields(cls)}
|
||||
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_fields}
|
||||
return cls(**filtered_dict)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls: type[ConfigType], file_path: str | Path) -> ConfigType:
|
||||
"""Create a config instance from a file - accepts yaml or json."""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
logger.exception("Config file not found: %(path)s", {"path": path})
|
||||
raise FileNotFoundError
|
||||
|
||||
return cls.from_dict(_parser(path))
|
||||
80
src/heisskleber/core/packer.py
Normal file
80
src/heisskleber/core/packer.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Packer and unpacker for network data."""
|
||||
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Protocol, TypeVar
|
||||
|
||||
T_contra = TypeVar("T_contra", contravariant=True)
|
||||
|
||||
|
||||
class PackerError(Exception):
|
||||
"""Raised when unpacking operations fail.
|
||||
|
||||
This exception wraps underlying errors that may occur during unpacking,
|
||||
providing a consistent interface for error handling.
|
||||
|
||||
Arguments:
|
||||
data: The data object that caused the PackerError
|
||||
|
||||
"""
|
||||
|
||||
PREVIEW_LENGTH = 100
|
||||
|
||||
def __init__(self, data: Any) -> None:
|
||||
"""Initialize the error with the failed payload and cause."""
|
||||
message = "Failed to pack data."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class Packer(Protocol[T_contra]):
|
||||
"""Packer Interface.
|
||||
|
||||
This class defines a protocol for packing data.
|
||||
It takes data and converts it into a bytes payload.
|
||||
|
||||
Attributes:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, data: T_contra) -> bytes:
|
||||
"""Packs the data dictionary into a bytes payload.
|
||||
|
||||
Arguments:
|
||||
data (T_contra): The input data dictionary to be packed.
|
||||
|
||||
Returns:
|
||||
bytes: The packed payload.
|
||||
|
||||
Raises:
|
||||
PackerError: The data dictionary could not be packed.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class JSONPacker(Packer[dict[str, Any]]):
|
||||
"""Converts a dictionary into JSON-formatted bytes.
|
||||
|
||||
Arguments:
|
||||
data: A dictionary with string keys and arbitrary values to be serialized into JSON format.
|
||||
|
||||
Returns:
|
||||
bytes: The JSON-encoded data as a bytes object.
|
||||
|
||||
Raises:
|
||||
PackerError: If the data cannot be serialized to JSON.
|
||||
|
||||
Example:
|
||||
>>> packer = JSONPacker()
|
||||
>>> result = packer({"key": "value"})
|
||||
b'{"key": "value"}'
|
||||
|
||||
"""
|
||||
|
||||
def __call__(self, data: dict[str, Any]) -> bytes:
|
||||
"""Pack the data."""
|
||||
try:
|
||||
return json.dumps(data).encode()
|
||||
except (UnicodeEncodeError, TypeError) as err:
|
||||
raise PackerError(data) from err
|
||||
101
src/heisskleber/core/receiver.py
Normal file
101
src/heisskleber/core/receiver.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Asynchronous data source interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic
|
||||
|
||||
from .unpacker import T_co, Unpacker
|
||||
|
||||
|
||||
class Receiver(ABC, Generic[T_co]):
|
||||
"""Abstract interface for asynchronous data sources.
|
||||
|
||||
This class defines a protocol for receiving data from various input streams
|
||||
asynchronously. It supports both async iteration and context manager patterns,
|
||||
and ensures proper resource management.
|
||||
|
||||
The source is covariant in its type parameter, allowing for type-safe subtyping
|
||||
relationships.
|
||||
|
||||
Attributes:
|
||||
unpacker: Component responsible for deserializing incoming data into type T_co.
|
||||
|
||||
Example:
|
||||
>>> async with CustomSource(unpacker) as source:
|
||||
... async for data, metadata in source:
|
||||
... print(f"Received: {data}, metadata: {metadata}")
|
||||
|
||||
"""
|
||||
|
||||
unpacker: Unpacker[T_co]
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self) -> tuple[T_co, dict[str, Any]]:
|
||||
"""Receive data from the implemented input stream.
|
||||
|
||||
Returns:
|
||||
tuple[T_co, dict[str, Any]]: A tuple containing:
|
||||
- The received and unpacked data of type T_co
|
||||
- A dictionary of metadata associated with the received data
|
||||
|
||||
Raises:
|
||||
Any implementation-specific exceptions that might occur during receiving.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""Initialize and start any background processes and tasks of the source."""
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop any background processes and tasks."""
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
"""A string reprensatiion of the source."""
|
||||
|
||||
async def __aiter__(self) -> AsyncGenerator[tuple[T_co, dict[str, Any]], None]:
|
||||
"""Implement async iteration over the source's data stream.
|
||||
|
||||
Yields:
|
||||
tuple[T_co, dict[str, Any]]: Each data item and its associated metadata
|
||||
as returned by receive().
|
||||
|
||||
Raises:
|
||||
Any exceptions that might occur during receive().
|
||||
|
||||
"""
|
||||
while True:
|
||||
data, meta = await self.receive()
|
||||
yield data, meta
|
||||
|
||||
async def __aenter__(self) -> "Receiver[T_co]":
|
||||
"""Initialize the source for use in an async context manager.
|
||||
|
||||
Returns:
|
||||
AsyncSource[T_co]: The initialized source instance.
|
||||
|
||||
Raises:
|
||||
Any exceptions that might occur during start().
|
||||
|
||||
"""
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
"""Cleanup the source when exiting an async context manager.
|
||||
|
||||
Arguments:
|
||||
exc_type: The type of the exception that was raised, if any.
|
||||
exc_value: The instance of the exception that was raised, if any.
|
||||
traceback: The traceback of the exception that was raised, if any.
|
||||
|
||||
"""
|
||||
await self.stop()
|
||||
79
src/heisskleber/core/sender.py
Normal file
79
src/heisskleber/core/sender.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Asyncronous data sink interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from .packer import Packer
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Sender(ABC, Generic[T]):
|
||||
"""Abstract interface for asynchronous data sinks.
|
||||
|
||||
This class defines a protocol for sending data to various output streams
|
||||
asynchronously. It supports context manager usage and ensures proper
|
||||
resource management.
|
||||
|
||||
Attributes:
|
||||
packer: Component responsible for serializing type T data before sending.
|
||||
|
||||
"""
|
||||
|
||||
packer: Packer[T]
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, data: T, **kwargs: Any) -> None:
|
||||
"""Send data through the implemented output stream.
|
||||
|
||||
Arguments:
|
||||
data: The data to be sent, of type T.
|
||||
**kwargs: Additional implementation-specific arguments.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""Initialize and start the sink's background processes and tasks."""
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop and cleanup the sink's background processes and tasks.
|
||||
|
||||
This method should be called when the sink is no longer needed.
|
||||
It should handle cleanup of any resources initialized in start().
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
"""A string representation of the sink."""
|
||||
|
||||
async def __aenter__(self) -> "Sender[T]":
|
||||
"""Initialize the sink for use in an async context manager.
|
||||
|
||||
Returns:
|
||||
AsyncSink[T]: The initialized sink instance.
|
||||
|
||||
Raises:
|
||||
Any exceptions that might occur during start().
|
||||
|
||||
"""
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
"""Cleanup the sink when exiting an async context manager.
|
||||
|
||||
Arguments:
|
||||
exc_type: The type of the exception that was raised, if any.
|
||||
exc_value: The instance of the exception that was raised, if any.
|
||||
traceback: The traceback of the exception that was raised, if any.
|
||||
|
||||
"""
|
||||
await self.stop()
|
||||
84
src/heisskleber/core/unpacker.py
Normal file
84
src/heisskleber/core/unpacker.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Unpacker protocol definition and example implemetation."""
|
||||
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Protocol, TypeVar
|
||||
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
class UnpackError(Exception):
|
||||
"""Raised when unpacking operations fail.
|
||||
|
||||
This exception wraps underlying errors that may occur during unpacking,
|
||||
providing a consistent interface for error handling.
|
||||
|
||||
Arguments:
|
||||
payload: The bytes payload that failed to unpack.
|
||||
|
||||
"""
|
||||
|
||||
PREVIEW_LENGTH = 100
|
||||
|
||||
def __init__(self, payload: bytes) -> None:
|
||||
"""Initialize the error with the failed payload and cause."""
|
||||
self.payload = payload
|
||||
preview = payload[: self.PREVIEW_LENGTH] + b"..." if len(payload) > self.PREVIEW_LENGTH else payload
|
||||
message = f"Failed to unpack payload: {preview!r}. "
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class Unpacker(Protocol[T_co]):
|
||||
"""Unpacker Interface.
|
||||
|
||||
This abstract base class defines an interface for unpacking payloads.
|
||||
It takes a payload of bytes, creates a data dictionary and an optional topic,
|
||||
and returns a tuple containing the topic and data.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, payload: bytes) -> tuple[T_co, dict[str, Any]]:
|
||||
"""Unpacks the payload into a data object and optional meta-data dictionary.
|
||||
|
||||
Args:
|
||||
payload (bytes): The input payload to be unpacked.
|
||||
|
||||
Returns:
|
||||
tuple[T, Optional[dict[str, Any]]]: A tuple containing:
|
||||
- T: The data object generated from the input data, e.g. dict or dataclass
|
||||
- dict[str, Any]: The meta data associated with the unpack operation, such as topic, timestamp or errors
|
||||
|
||||
Raises:
|
||||
UnpackError: The payload could not be unpacked.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class JSONUnpacker(Unpacker[dict[str, Any]]):
|
||||
"""Deserializes JSON-formatted bytes into dictionaries.
|
||||
|
||||
Arguments:
|
||||
payload: JSON-formatted bytes to deserialize.
|
||||
|
||||
Returns:
|
||||
tuple[dict[str, Any], dict[str, Any]]: A tuple containing:
|
||||
- The deserialized JSON data as a dictionary
|
||||
- An empty dictionary for metadata (not used in JSON unpacking)
|
||||
|
||||
Raises:
|
||||
UnpackError: If the payload cannot be decoded as valid JSON.
|
||||
|
||||
Example:
|
||||
>>> unpacker = JSONUnpacker()
|
||||
>>> data, metadata = unpacker(b'{"hotglue": "very_nais"}')
|
||||
>>> print(data)
|
||||
{'hotglue': 'very_nais'}
|
||||
|
||||
"""
|
||||
|
||||
def __call__(self, payload: bytes) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Unpack the payload."""
|
||||
try:
|
||||
return json.loads(payload), {}
|
||||
except json.JSONDecodeError as e:
|
||||
raise UnpackError(payload) from e
|
||||
44
src/heisskleber/core/utils.py
Normal file
44
src/heisskleber/core/utils.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import asyncio
|
||||
from collections.abc import Coroutine
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def retry(
|
||||
every: int = 5,
|
||||
strategy: str = "always",
|
||||
catch: type[Exception] | tuple[type[Exception], ...] = Exception,
|
||||
logger_fn: Callable[[str, dict[str, Any]], None] | None = None,
|
||||
) -> Callable[[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]]:
|
||||
"""Retry a coroutine."""
|
||||
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
retries = 0
|
||||
while True:
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
break
|
||||
except catch as e:
|
||||
if logger_fn:
|
||||
logger_fn(
|
||||
"Error occurred: %(err). Retrying in %(seconds) seconds",
|
||||
{"err": e, "seconds": every},
|
||||
)
|
||||
retries += 1
|
||||
await asyncio.sleep(every)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
if strategy != "always":
|
||||
raise NotImplementedError
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
13
src/heisskleber/mqtt/__init__.py
Normal file
13
src/heisskleber/mqtt/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Async wrappers for mqtt functionality.
|
||||
|
||||
MQTT implementation is achieved via the `aiomqtt`_ package, which is an async wrapper around the `paho-mqtt`_ package.
|
||||
|
||||
.. _aiomqtt: https://github.com/mossblaser/aiomqtt
|
||||
.. _paho-mqtt: https://github.com/eclipse/paho.mqtt.python
|
||||
"""
|
||||
|
||||
from .config import MqttConf
|
||||
from .receiver import MqttReceiver
|
||||
from .sender import MqttSender
|
||||
|
||||
__all__ = ["MqttConf", "MqttReceiver", "MqttSender"]
|
||||
50
src/heisskleber/mqtt/config.py
Normal file
50
src/heisskleber/mqtt/config.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Mqtt config."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from aiomqtt import Will
|
||||
|
||||
from heisskleber.core import BaseConf
|
||||
|
||||
|
||||
@dataclass
|
||||
class WillConf(BaseConf):
|
||||
"""MQTT Last Will and Testament message configuration."""
|
||||
|
||||
topic: str
|
||||
payload: str | None = None
|
||||
qos: int = 0
|
||||
retain: bool = False
|
||||
|
||||
def to_aiomqtt_will(self) -> Will:
|
||||
"""Create an aiomqtt style will."""
|
||||
return Will(topic=self.topic, payload=self.payload, qos=self.qos, retain=self.retain, properties=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MqttConf(BaseConf):
|
||||
"""MQTT configuration class."""
|
||||
|
||||
# transport
|
||||
host: str = "localhost"
|
||||
port: int = 1883
|
||||
ssl: bool = False
|
||||
|
||||
# mqtt
|
||||
user: str = ""
|
||||
password: str = ""
|
||||
qos: int = 0
|
||||
retain: bool = False
|
||||
max_saved_messages: int = 100
|
||||
timeout: int = 60
|
||||
keep_alive: int = 60
|
||||
will: Will | None = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> "MqttConf":
|
||||
"""Create a MqttConf object from a dictionary."""
|
||||
if "will" in config_dict:
|
||||
config_dict = config_dict.copy()
|
||||
config_dict["will"] = WillConf.from_dict(config_dict["will"]).to_aiomqtt_will()
|
||||
return super().from_dict(config_dict)
|
||||
133
src/heisskleber/mqtt/receiver.py
Normal file
133
src/heisskleber/mqtt/receiver.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from asyncio import Queue, Task, create_task
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from aiomqtt import Client, Message, MqttError
|
||||
|
||||
from heisskleber.core import Receiver, Unpacker, json_unpacker
|
||||
from heisskleber.core.utils import retry
|
||||
from heisskleber.mqtt import MqttConf
|
||||
|
||||
T = TypeVar("T")
|
||||
logger = logging.getLogger("heisskleber.mqtt")
|
||||
|
||||
|
||||
class MqttReceiver(Receiver[T]):
|
||||
"""Asynchronous MQTT subscriber based on aiomqtt.
|
||||
|
||||
This class implements an asynchronous MQTT subscriber that handles connection, subscription, and message reception from an MQTT broker. It uses aiomqtt as the underlying MQTT client implementation.
|
||||
|
||||
The subscriber maintains a queue of received messages which can be accessed through the `receive` method.
|
||||
|
||||
Attributes:
|
||||
config (MqttConf): Stored configuration for MQTT connection.
|
||||
topics (Union[str, List[str]]): Topics to subscribe to.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MqttConf,
|
||||
topic: str | list[str],
|
||||
unpacker: Unpacker[T] = json_unpacker, # type: ignore[assignment]
|
||||
) -> None:
|
||||
"""Initialize the MQTT source.
|
||||
|
||||
Args:
|
||||
config: Configuration object containing:
|
||||
- host (str): MQTT broker hostname
|
||||
- port (int): MQTT broker port
|
||||
- user (str): Username for authentication
|
||||
- password (str): Password for authentication
|
||||
- qos (int): Default Quality of Service level
|
||||
- max_saved_messages (int): Maximum queue size
|
||||
topic: Single topic string or list of topics to subscribe to
|
||||
unpacker: Function to deserialize received messages, defaults to json_unpacker
|
||||
|
||||
"""
|
||||
self.config = config
|
||||
self.topics = topic if isinstance(topic, list) else [topic]
|
||||
self.unpacker = unpacker
|
||||
self._message_queue: Queue[Message] = Queue(self.config.max_saved_messages)
|
||||
self._listener_task: Task[None] | None = None
|
||||
|
||||
async def receive(self) -> tuple[T, dict[str, Any]]:
|
||||
"""Receive and process the next message from the queue.
|
||||
|
||||
Returns:
|
||||
tuple[T, dict[str, Any]]
|
||||
- The unpacked message data
|
||||
- A dictionary with metadata including the message topic
|
||||
|
||||
Raises:
|
||||
TypeError: If the message payload is not of type bytes.
|
||||
UnpackError: If the message could not be unpacked with the unpacker protocol.
|
||||
|
||||
"""
|
||||
if not self._listener_task:
|
||||
await self.start()
|
||||
|
||||
message = await self._message_queue.get()
|
||||
if not isinstance(message.payload, bytes):
|
||||
error_msg = "Payload is not of type bytes."
|
||||
raise TypeError(error_msg)
|
||||
|
||||
data, extra = self.unpacker(message.payload)
|
||||
extra["topic"] = message.topic.value
|
||||
return (data, extra)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of Mqtt Source class."""
|
||||
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the MQTT listener task."""
|
||||
self._listener_task = create_task(self._run())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the MQTT listener task."""
|
||||
if not self._listener_task:
|
||||
return
|
||||
|
||||
self._listener_task.cancel()
|
||||
try:
|
||||
await self._listener_task
|
||||
except asyncio.CancelledError:
|
||||
# Raise if the stop task was cancelled
|
||||
# kudos:https://superfastpython.com/asyncio-cancel-task-and-wait/
|
||||
task = asyncio.current_task()
|
||||
if task and task.cancelled():
|
||||
raise
|
||||
self._listener_task = None
|
||||
|
||||
async def subscribe(self, topic: str, qos: int | None = None) -> None:
|
||||
"""Subscribe to an additional MQTT topic.
|
||||
|
||||
Args:
|
||||
topic: The topic to subscribe to
|
||||
qos: Quality of Service level, uses config.qos if None
|
||||
|
||||
"""
|
||||
qos = qos or self.config.qos
|
||||
self.topics.append(topic)
|
||||
await self._client.subscribe(topic, qos)
|
||||
|
||||
@retry(every=1, catch=MqttError, logger_fn=logger.exception)
|
||||
async def _run(self) -> None:
|
||||
"""Background task for MQTT connection."""
|
||||
async with Client(
|
||||
hostname=self.config.host,
|
||||
port=self.config.port,
|
||||
username=self.config.user,
|
||||
password=self.config.password,
|
||||
timeout=self.config.timeout,
|
||||
keepalive=self.config.keep_alive,
|
||||
will=self.config.will,
|
||||
) as client:
|
||||
self._client = client
|
||||
logger.info("subscribing to %(topics)s", {"topics": self.topics})
|
||||
await client.subscribe([(topic, self.config.qos) for topic in self.topics])
|
||||
|
||||
async for message in client.messages:
|
||||
await self._message_queue.put(message)
|
||||
99
src/heisskleber/mqtt/sender.py
Normal file
99
src/heisskleber/mqtt/sender.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Async mqtt sink implementation."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from asyncio import CancelledError, create_task
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import aiomqtt
|
||||
|
||||
from heisskleber.core import Packer, Sender, json_packer
|
||||
from heisskleber.core.utils import retry
|
||||
|
||||
from .config import MqttConf
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger("heisskleber.mqtt")
|
||||
|
||||
|
||||
class MqttSender(Sender[T]):
|
||||
"""MQTT publisher with queued message handling.
|
||||
|
||||
This sink implementation provides asynchronous MQTT publishing capabilities with automatic connection management and message queueing.
|
||||
Network operations are handled in a separate task.
|
||||
|
||||
Attributes:
|
||||
config: MQTT configuration in a dataclass.
|
||||
packer: Callable to pack data from type T to bytes for transport.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment]
|
||||
self.config = config
|
||||
self.packer = packer
|
||||
self._send_queue: asyncio.Queue[tuple[T, str]] = asyncio.Queue()
|
||||
self._sender_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def send(self, data: T, topic: str = "mqtt", qos: int = 0, retain: bool = False, **kwargs: Any) -> None:
|
||||
"""Queue data for asynchronous publication to the mqtt broker.
|
||||
|
||||
Arguments:
|
||||
data: The data to be published.
|
||||
topic: The mqtt topic to publish to.
|
||||
qos: MQTT QOS level (0, 1, or 2). Defaults to 0.o
|
||||
retain: Whether to set the MQTT retain flag. Defaults to False.
|
||||
**kwargs: Not implemented.
|
||||
|
||||
"""
|
||||
if not self._sender_task:
|
||||
await self.start()
|
||||
|
||||
await self._send_queue.put((data, topic))
|
||||
|
||||
@retry(every=5, catch=aiomqtt.MqttError, logger_fn=logger.exception)
|
||||
async def _send_work(self) -> None:
|
||||
async with aiomqtt.Client(
|
||||
hostname=self.config.host,
|
||||
port=self.config.port,
|
||||
username=self.config.user,
|
||||
password=self.config.password,
|
||||
timeout=float(self.config.timeout),
|
||||
keepalive=self.config.keep_alive,
|
||||
will=self.config.will,
|
||||
) as client:
|
||||
try:
|
||||
while True:
|
||||
data, topic = await self._send_queue.get()
|
||||
payload = self.packer(data)
|
||||
await client.publish(topic=topic, payload=payload)
|
||||
except CancelledError:
|
||||
logger.info("MqttSink background loop cancelled. Emptying queue...")
|
||||
while not self._send_queue.empty():
|
||||
_ = self._send_queue.get_nowait()
|
||||
raise
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of the MQTT sink object."""
|
||||
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the send queue in a separate task.
|
||||
|
||||
The task will retry connections every 5 seconds on failure.
|
||||
"""
|
||||
self._sender_task = create_task(self._send_work())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the background task."""
|
||||
if not self._sender_task:
|
||||
return
|
||||
self._sender_task.cancel()
|
||||
try:
|
||||
await self._sender_task
|
||||
except asyncio.CancelledError:
|
||||
# If the stop task was cancelled, we raise.
|
||||
task = asyncio.current_task()
|
||||
if task and task.cancelled():
|
||||
raise
|
||||
self._sender_task = None
|
||||
7
src/heisskleber/serial/__init__.py
Normal file
7
src/heisskleber/serial/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Asyncronous implementations to read and write to a serial interface."""
|
||||
|
||||
from .config import SerialConf
|
||||
from .receiver import SerialReceiver
|
||||
from .sender import SerialSender
|
||||
|
||||
__all__ = ["SerialConf", "SerialSender", "SerialReceiver"]
|
||||
29
src/heisskleber/serial/config.py
Normal file
29
src/heisskleber/serial/config.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from heisskleber.core.config import BaseConf
|
||||
|
||||
|
||||
@dataclass
|
||||
class SerialConf(BaseConf):
|
||||
"""Serial Config class.
|
||||
|
||||
Attributes:
|
||||
port: The port to connect to. Defaults to /dev/serial0.
|
||||
baudrate: The baudrate of the serial connection. Defaults to 9600.
|
||||
bytesize: The bytesize of the messages. Defaults to 8.
|
||||
encoding: The string encoding of the messages. Defaults to ascii.
|
||||
parity: The parity checking value. One of "N" for none, "E" for even, "O" for odd. Defaults to None.
|
||||
stopbits: Stopbits. One of 1, 2 or 1.5. Defaults to 1.
|
||||
|
||||
Note:
|
||||
stopbits 1.5 is not yet implemented.
|
||||
|
||||
"""
|
||||
|
||||
port: str = "/dev/serial0"
|
||||
baudrate: int = 9600
|
||||
bytesize: int = 8
|
||||
encoding: str = "ascii"
|
||||
parity: Literal["N", "O", "E"] = "N" # definitions from serial.PARTITY_'N'ONE / 'O'DD / 'E'VEN
|
||||
stopbits: Literal[1, 2] = 1 # 1.5 not yet implemented
|
||||
98
src/heisskleber/serial/receiver.py
Normal file
98
src/heisskleber/serial/receiver.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import serial # type: ignore[import-untyped]
|
||||
|
||||
from heisskleber.core import Receiver, Unpacker
|
||||
|
||||
from .config import SerialConf
|
||||
|
||||
T = TypeVar("T")
|
||||
logger = logging.getLogger("heisskleber.serial")
|
||||
|
||||
|
||||
class SerialReceiver(Receiver[T]):
|
||||
"""An asynchronous source for reading data from a serial port.
|
||||
|
||||
This class implements the AsyncSource interface for reading data from a serial port.
|
||||
It uses a thread pool executor to perform blocking I/O operations asynchronously.
|
||||
|
||||
Attributes:
|
||||
config: Configuration for the serial port.
|
||||
unpacker: Function to unpack received data.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: SerialConf, unpack: Unpacker[T]) -> None:
|
||||
self.config = config
|
||||
self.unpacker = unpack
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._executor = ThreadPoolExecutor(max_workers=2)
|
||||
self._lock = asyncio.Lock()
|
||||
self._is_connected = False
|
||||
self._cancel_read_timeout = 1
|
||||
|
||||
async def receive(self) -> tuple[T, dict[str, Any]]:
|
||||
"""Receive data from the serial port.
|
||||
|
||||
This method reads a line from the serial port, unpacks it, and returns the data.
|
||||
If the serial port is not connected, it will attempt to connect first.
|
||||
|
||||
Returns:
|
||||
tuple[T, dict[str, Any]]: A tuple containing the unpacked data and any extra information.
|
||||
|
||||
Raises:
|
||||
UnpackError: If the data could not be unpacked with the provided unpacker.
|
||||
|
||||
"""
|
||||
if not self._is_connected:
|
||||
await self.start()
|
||||
|
||||
try:
|
||||
payload = await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.readline, -1)
|
||||
except asyncio.CancelledError:
|
||||
await asyncio.shield(self._cancel_read())
|
||||
raise
|
||||
|
||||
data, extra = self.unpacker(payload=payload)
|
||||
logger.debug(
|
||||
"SerialSource(%(port)s): Unpacked: %(data)s, extra information: %(extra)s",
|
||||
{"port": self.config.port, "data": data, "extra": extra},
|
||||
)
|
||||
return (data, extra)
|
||||
|
||||
async def _cancel_read(self) -> None:
|
||||
if not hasattr(self, "_ser"):
|
||||
return
|
||||
logger.warning(
|
||||
"SerialSource(%(port)s).read() cancelled, waiting for %(timeout)s",
|
||||
{"port": self.config.port, "timeout": self._cancel_read_timeout},
|
||||
)
|
||||
await asyncio.wait_for(
|
||||
asyncio.get_running_loop().run_in_executor(self._executor, self._ser.cancel_read),
|
||||
self._cancel_read_timeout,
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Open serial device."""
|
||||
if hasattr(self, "_ser"):
|
||||
return
|
||||
|
||||
self._ser = serial.Serial(
|
||||
port=self.config.port,
|
||||
baudrate=self.config.baudrate,
|
||||
bytesize=self.config.bytesize,
|
||||
parity=self.config.parity,
|
||||
stopbits=self.config.stopbits,
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Close serial connection."""
|
||||
await self._cancel_read()
|
||||
self._ser.close()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of Serial Source."""
|
||||
return f"SerialSource({self.config.port}, baudrate={self.config.baudrate})"
|
||||
91
src/heisskleber/serial/sender.py
Normal file
91
src/heisskleber/serial/sender.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Asynchronous sink implementation for sending data via serial port."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import serial # type: ignore[import-untyped]
|
||||
|
||||
from heisskleber.core import Packer, Sender
|
||||
|
||||
from .config import SerialConf
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SerialSender(Sender[T]):
|
||||
"""An asynchronous sink for writing data to a serial port.
|
||||
|
||||
This class implements the AsyncSink interface for writing data to a serial port.
|
||||
It uses a thread pool executor to perform blocking I/O operations asynchronously.
|
||||
|
||||
Attributes:
|
||||
config: Configuration for the serial port.
|
||||
packer: Function to pack data for sending.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SerialConf, pack: Packer[T]) -> None:
|
||||
"""SerialSink constructor."""
|
||||
self.config = config
|
||||
self.packer = pack
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._executor = ThreadPoolExecutor(max_workers=2)
|
||||
self._lock = asyncio.Lock()
|
||||
self._is_connected = False
|
||||
self._cancel_write_timeout = 1
|
||||
|
||||
async def send(self, data: T, **kwargs: dict[str, Any]) -> None:
|
||||
"""Send data to the serial port.
|
||||
|
||||
This method packs the data, writes it to the serial port, and then flushes the port.
|
||||
|
||||
Arguments:
|
||||
data: The data to be sent.
|
||||
**kwargs: Not implemented.
|
||||
|
||||
Raises:
|
||||
PackerError: If data could not be packed to bytes with the provided packer.
|
||||
|
||||
Note:
|
||||
If the serial port is not connected, it will implicitly attempt to connect first.
|
||||
|
||||
"""
|
||||
if not self._is_connected:
|
||||
await self.start()
|
||||
|
||||
payload = self.packer(data)
|
||||
try:
|
||||
await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.write, payload)
|
||||
await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.flush)
|
||||
except asyncio.CancelledError:
|
||||
await asyncio.shield(self._cancel_write())
|
||||
raise
|
||||
|
||||
async def _cancel_write(self) -> None:
|
||||
if not hasattr(self, "_ser"):
|
||||
return
|
||||
await asyncio.wait_for(
|
||||
asyncio.get_running_loop().run_in_executor(self._executor, self._ser.cancel_write),
|
||||
self._cancel_write_timeout,
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Open serial connection."""
|
||||
if hasattr(self, "_ser"):
|
||||
return
|
||||
|
||||
self._ser = serial.Serial(
|
||||
port=self.config.port,
|
||||
baudrate=self.config.baudrate,
|
||||
bytesize=self.config.bytesize,
|
||||
parity=self.config.parity,
|
||||
stopbits=self.config.stopbits,
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Close serial connection."""
|
||||
self._ser.close()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of SerialSink."""
|
||||
return f"SerialSink({self.config.port}, baudrate={self.config.baudrate})"
|
||||
5
src/heisskleber/tcp/__init__.py
Normal file
5
src/heisskleber/tcp/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .config import TcpConf
|
||||
from .receiver import TcpReceiver
|
||||
from .sender import TcpSender
|
||||
|
||||
__all__ = ["TcpReceiver", "TcpSender", "TcpConf"]
|
||||
22
src/heisskleber/tcp/config.py
Normal file
22
src/heisskleber/tcp/config.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from heisskleber.core import BaseConf
|
||||
|
||||
|
||||
@dataclass
|
||||
class TcpConf(BaseConf):
|
||||
"""Configuration dataclass for TCP connections."""
|
||||
|
||||
class RestartBehavior(Enum):
|
||||
"""The three types of restart behaviour."""
|
||||
|
||||
NEVER = 0 # Never restart on failure
|
||||
ONCE = 1 # Restart once
|
||||
ALWAYS = 2 # Restart until the connection succeeds
|
||||
|
||||
host: str = "localhost"
|
||||
port: int = 6000
|
||||
timeout: int = 60
|
||||
retry_delay: float = 0.5
|
||||
restart_behavior: RestartBehavior = RestartBehavior.ALWAYS
|
||||
107
src/heisskleber/tcp/receiver.py
Normal file
107
src/heisskleber/tcp/receiver.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Async TCP Source - get data from arbitrary TCP server."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from heisskleber.core import Receiver, Unpacker, json_unpacker
|
||||
from heisskleber.tcp.config import TcpConf
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger("heisskleber.tcp")
|
||||
|
||||
|
||||
class TcpReceiver(Receiver[T]):
|
||||
"""Async TCP connection, connects to host:port and reads byte encoded strings."""
|
||||
|
||||
def __init__(self, config: TcpConf, unpacker: Unpacker[T] = json_unpacker) -> None: # type: ignore [assignment]
|
||||
self.config = config
|
||||
self.unpack = unpacker
|
||||
self.is_connected = False
|
||||
self.timeout = config.timeout
|
||||
self._start_task: asyncio.Task[None] | None = None
|
||||
self.reader: asyncio.StreamReader | None = None
|
||||
self.writer: asyncio.StreamWriter | None = None
|
||||
|
||||
async def receive(self) -> tuple[T, dict[str, Any]]:
|
||||
"""Receive data from a connection.
|
||||
|
||||
Attempt to read data from the connection and handle the process of re-establishing the connection if necessary.
|
||||
|
||||
Returns:
|
||||
tuple[T, dict[str, Any]]
|
||||
- The unpacked message data
|
||||
- A dictionary with metadata including the message topic
|
||||
|
||||
Raises:
|
||||
TypeError: If the message payload is not of type bytes.
|
||||
UnpackError: If the message could not be unpacked with the unpacker protocol.
|
||||
|
||||
"""
|
||||
data = b""
|
||||
retry_delay = self.config.retry_delay
|
||||
while not data:
|
||||
await self._ensure_connected()
|
||||
data = await self.reader.readline() # type: ignore [union-attr]
|
||||
if not data:
|
||||
self.is_connected = False
|
||||
logger.warning(
|
||||
"%(self)s nothing received, retrying connect in %(seconds)s",
|
||||
{"self": self, "seconds": retry_delay},
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
retry_delay = min(self.config.timeout, retry_delay * 2)
|
||||
|
||||
payload, extra = self.unpack(data)
|
||||
return payload, extra
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start TcpSource."""
|
||||
await self._connect()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop TcpSource."""
|
||||
if self.is_connected:
|
||||
logger.info("%(self)s stopping", {"self": self})
|
||||
|
||||
async def _ensure_connected(self) -> None:
|
||||
if self.is_connected:
|
||||
return
|
||||
|
||||
# Not connected, try to (re-)connect
|
||||
if not self._start_task:
|
||||
# Possibly multiple reconnects, so can't just await once
|
||||
self._start_task = asyncio.create_task(self._connect())
|
||||
|
||||
try:
|
||||
await self._start_task
|
||||
finally:
|
||||
self._start_task = None
|
||||
|
||||
async def _connect(self) -> None:
|
||||
logger.info("%(self)s waiting for connection.", {"self": self})
|
||||
|
||||
num_attempts = 0
|
||||
while True:
|
||||
try:
|
||||
self.reader, self.writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(self.config.host, self.config.port),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
logger.info("%(self)s connected successfully!", {"self": self})
|
||||
break
|
||||
except ConnectionRefusedError as e:
|
||||
logger.exception("%(self)s: %(error_type)s", {"self": self, "error_type": type(e).__name__})
|
||||
if self.config.restart_behavior == TcpConf.RestartBehavior.NEVER:
|
||||
raise
|
||||
num_attempts += 1
|
||||
if self.config.restart_behavior == TcpConf.RestartBehavior.ONCE and num_attempts > 1:
|
||||
raise
|
||||
# otherwise retry indefinitely
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of TcpSource."""
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
41
src/heisskleber/tcp/sender.py
Normal file
41
src/heisskleber/tcp/sender.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from heisskleber.core import Sender
|
||||
from heisskleber.core.packer import Packer
|
||||
|
||||
from .config import TcpConf
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TcpSender(Sender[T]):
|
||||
"""Async TCP Sink.
|
||||
|
||||
Attributes:
|
||||
config: The TcpConf configuration object.
|
||||
packer: The packer protocol to serialize data before sending.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: TcpConf, packer: Packer[T]) -> None:
|
||||
self.config = config
|
||||
self.packer = packer
|
||||
|
||||
async def send(self, data: T, **kwargs: dict[str, Any]) -> None:
|
||||
"""Send data via tcp connection.
|
||||
|
||||
Arguments:
|
||||
data: The data to be sent.
|
||||
kwargs: Not implemented.
|
||||
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of TcpSink."""
|
||||
return f"TcpSink({self.config.host}:{self.config.port})"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start TcpSink."""
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop TcpSink."""
|
||||
5
src/heisskleber/udp/__init__.py
Normal file
5
src/heisskleber/udp/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .config import UdpConf
|
||||
from .receiver import UdpReceiver
|
||||
from .sender import UdpSender
|
||||
|
||||
__all__ = ["UdpReceiver", "UdpSender", "UdpConf"]
|
||||
@@ -1,17 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from heisskleber.config import BaseConf
|
||||
from heisskleber.core import BaseConf
|
||||
|
||||
|
||||
@dataclass
|
||||
class UdpConf(BaseConf):
|
||||
"""
|
||||
UDP configuration.
|
||||
"""
|
||||
"""UDP configuration."""
|
||||
|
||||
port: int = 1234
|
||||
host: str = "127.0.0.1"
|
||||
packer: str = "json"
|
||||
max_queue_size: int = 1000
|
||||
encoding: str = "utf-8"
|
||||
delimiter: str = "\r\n"
|
||||
86
src/heisskleber/udp/receiver.py
Normal file
86
src/heisskleber/udp/receiver.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from heisskleber.core import Receiver, Unpacker, json_unpacker
|
||||
from heisskleber.udp.config import UdpConf
|
||||
|
||||
logger = logging.getLogger("heisskleber.udp")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class UdpProtocol(asyncio.DatagramProtocol):
|
||||
"""Protocol for udp connection.
|
||||
|
||||
Arguments:
|
||||
queue: The asyncioQueue to put messages into.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, queue: asyncio.Queue[bytes]) -> None:
|
||||
super().__init__()
|
||||
self.queue = queue
|
||||
|
||||
def datagram_received(self, data: bytes, addr: tuple[str | Any, int]) -> None:
|
||||
"""Handle received udp message."""
|
||||
self.queue.put_nowait(data)
|
||||
|
||||
def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore[override]
|
||||
"""Log successful connection."""
|
||||
logger.info("UdpSource: Connection made")
|
||||
|
||||
|
||||
class UdpReceiver(Receiver[T]):
|
||||
"""An asynchronous UDP subscriber based on asyncio.protocols.DatagramProtocol."""
|
||||
|
||||
def __init__(self, config: UdpConf, unpacker: Unpacker[T] = json_unpacker) -> None: # type: ignore[assignment]
|
||||
self.config = config
|
||||
self.EOF = self.config.delimiter.encode(self.config.encoding)
|
||||
self.unpacker = unpacker
|
||||
self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.config.max_queue_size)
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
self._is_connected = False
|
||||
self._transport: asyncio.DatagramTransport | None = None
|
||||
self._protocol: asyncio.DatagramProtocol | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start udp connection."""
|
||||
loop = asyncio.get_event_loop()
|
||||
self._transport, self._protocol = await loop.create_datagram_endpoint(
|
||||
lambda: UdpProtocol(self._queue),
|
||||
local_addr=(self.config.host, self.config.port),
|
||||
)
|
||||
self._is_connected = True
|
||||
logger.info("Udp connection established.")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the udp connection."""
|
||||
if self._transport is not None:
|
||||
self._transport.close()
|
||||
self._transport = None
|
||||
self._is_connected = False
|
||||
|
||||
async def receive(self) -> tuple[T, dict[str, Any]]:
|
||||
"""Get the next message from the udp connection.
|
||||
|
||||
Returns:
|
||||
tuple[T, dict[str, Any]]
|
||||
- The data as returned by the unpacker.
|
||||
- A dictionary containing extra information.
|
||||
|
||||
Raises:
|
||||
UnpackError: If the received message could not be unpacked.
|
||||
|
||||
"""
|
||||
if not self._is_connected:
|
||||
await self.start()
|
||||
|
||||
while True:
|
||||
data = None
|
||||
data = await self._queue.get()
|
||||
payload, extra = self.unpacker(data)
|
||||
return (payload, extra)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of UdpSource."""
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
88
src/heisskleber/udp/sender.py
Normal file
88
src/heisskleber/udp/sender.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from heisskleber.core import Packer, Sender, json_packer
|
||||
from heisskleber.udp.config import UdpConf
|
||||
|
||||
logger = logging.getLogger("heisskleber.udp")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class UdpProtocol(asyncio.DatagramProtocol):
|
||||
"""UDP protocol handler that tracks connection state.
|
||||
|
||||
Arguments:
|
||||
is_connected: Flag tracking if protocol is connected
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, is_connected: bool) -> None:
|
||||
super().__init__()
|
||||
self.is_connected = is_connected
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
"""Update state and log a lost connection."""
|
||||
logger.info("UDP Connection lost")
|
||||
self.is_connected = False
|
||||
|
||||
|
||||
class UdpSender(Sender[T]):
|
||||
"""UDP sink for sending data via UDP protocol.
|
||||
|
||||
Arguments:
|
||||
config: UDP configuration parameters
|
||||
packer: Function to serialize data, defaults to JSON packing
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: UdpConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment]
|
||||
self.config = config
|
||||
self.pack = packer
|
||||
self.is_connected = False
|
||||
self._transport: asyncio.DatagramTransport | None = None
|
||||
self._protocol: UdpProtocol | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect the UdpSink."""
|
||||
await self._ensure_connection()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Disconnect the UdpSink connection."""
|
||||
if self._transport is not None:
|
||||
self._transport.close()
|
||||
self.is_connected = False
|
||||
self._transport = None
|
||||
self._protocol = None
|
||||
|
||||
async def _ensure_connection(self) -> None:
|
||||
"""Create UDP endpoint if not connected.
|
||||
|
||||
Creates datagram endpoint using protocol handler if no connection exists.
|
||||
Updates connected state on successful connection.
|
||||
|
||||
"""
|
||||
if not self.is_connected or self._transport is None:
|
||||
loop = asyncio.get_running_loop()
|
||||
self._transport, _ = await loop.create_datagram_endpoint(
|
||||
lambda: UdpProtocol(self.is_connected),
|
||||
remote_addr=(self.config.host, self.config.port),
|
||||
)
|
||||
self.is_connected = True
|
||||
|
||||
async def send(self, data: T, **kwargs: dict[str, Any]) -> None:
|
||||
"""Send data over UDP connection.
|
||||
|
||||
Arguments:
|
||||
data: Data to send
|
||||
**kwargs: Additional arguments passed to send
|
||||
|
||||
"""
|
||||
await self._ensure_connection() # we know that self._transport is intialized
|
||||
payload = self.pack(data)
|
||||
self._transport.sendto(payload) # type: ignore [union-attr]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of UdpSink."""
|
||||
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"
|
||||
5
src/heisskleber/zmq/__init__.py
Normal file
5
src/heisskleber/zmq/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .config import ZmqConf
|
||||
from .receiver import ZmqReceiver
|
||||
from .sender import ZmqSender
|
||||
|
||||
__all__ = ["ZmqConf", "ZmqSender", "ZmqReceiver"]
|
||||
@@ -1,10 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from heisskleber.config import BaseConf
|
||||
from heisskleber.core import BaseConf
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZmqConf(BaseConf):
|
||||
"""ZMQ Configuration file."""
|
||||
|
||||
protocol: str = "tcp"
|
||||
host: str = "127.0.0.1"
|
||||
publisher_port: int = 5555
|
||||
@@ -13,8 +15,10 @@ class ZmqConf(BaseConf):
|
||||
|
||||
@property
|
||||
def publisher_address(self) -> str:
|
||||
"""Return the full url to connect to the publisher port."""
|
||||
return f"{self.protocol}://{self.host}:{self.publisher_port}"
|
||||
|
||||
@property
|
||||
def subscriber_address(self) -> str:
|
||||
"""Return the full url to connect to the subscriber port."""
|
||||
return f"{self.protocol}://{self.host}:{self.subscriber_port}"
|
||||
85
src/heisskleber/zmq/receiver.py
Normal file
85
src/heisskleber/zmq/receiver.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from heisskleber.core import Receiver, Unpacker, json_unpacker
|
||||
from heisskleber.zmq.config import ZmqConf
|
||||
|
||||
logger = logging.getLogger("heisskleber.zmq")
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ZmqReceiver(Receiver[T]):
|
||||
"""Async source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function.
|
||||
|
||||
Attributes:
|
||||
config: The ZmqConf configuration object for the connection.
|
||||
unpacker : The unpacker function to use for deserializing the data.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZmqConf, topic: str | list[str], unpacker: Unpacker[T] = json_unpacker) -> None: # type: ignore [assignment]
|
||||
self.config = config
|
||||
self.topic = topic
|
||||
self.context = zmq.asyncio.Context.instance()
|
||||
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB)
|
||||
self.unpack = unpacker
|
||||
self.is_connected = False
|
||||
|
||||
async def receive(self) -> tuple[T, dict[str, Any]]:
|
||||
"""Read a message from the zmq bus and return it.
|
||||
|
||||
Returns:
|
||||
tuple(topic: str, message: dict): the message received
|
||||
|
||||
Raises:
|
||||
UnpackError: If payload could not be unpacked with provided unpacker.
|
||||
|
||||
"""
|
||||
if not self.is_connected:
|
||||
await self.start()
|
||||
(topic, payload) = await self.socket.recv_multipart()
|
||||
data, extra = self.unpack(payload)
|
||||
extra["topic"] = topic.decode()
|
||||
return data, extra
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect to the zmq socket."""
|
||||
try:
|
||||
self.socket.connect(self.config.subscriber_address)
|
||||
except Exception:
|
||||
logger.exception("Failed to bind to zeromq socket")
|
||||
else:
|
||||
self.is_connected = True
|
||||
self.subscribe(self.topic)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Close the zmq socket."""
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
def subscribe(self, topic: str | list[str] | tuple[str]) -> None:
|
||||
"""Subscribe to the given topic(s) on the zmq socket.
|
||||
|
||||
Arguments:
|
||||
---------
|
||||
topic: The topic or list of topics to subscribe to.
|
||||
|
||||
"""
|
||||
if isinstance(topic, (list, tuple)):
|
||||
for t in topic:
|
||||
self._subscribe_single_topic(t)
|
||||
else:
|
||||
self._subscribe_single_topic(topic)
|
||||
|
||||
def _subscribe_single_topic(self, topic: str) -> None:
|
||||
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of ZmqSource."""
|
||||
return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})"
|
||||
54
src/heisskleber/zmq/sender.py
Normal file
54
src/heisskleber/zmq/sender.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import logging
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from heisskleber.core import Packer, Sender, json_packer
|
||||
|
||||
from .config import ZmqConf
|
||||
|
||||
logger = logging.getLogger("heisskleber.zmq")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ZmqSender(Sender[T]):
|
||||
"""Async publisher that sends messages to a ZMQ PUB socket.
|
||||
|
||||
Attributes:
|
||||
config: The ZmqConf configuration object for the connection.
|
||||
packer : The packer strategy to use for serializing the data.
|
||||
Defaults to json packer with utf-8 encoding.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZmqConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment]
|
||||
self.config = config
|
||||
self.context = zmq.asyncio.Context.instance()
|
||||
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB)
|
||||
self.packer = packer
|
||||
self.is_connected = False
|
||||
|
||||
async def send(self, data: T, topic: str = "zmq", **kwargs: Any) -> None:
|
||||
"""Take the data as a dict, serialize it with the given packer and send it to the zmq socket."""
|
||||
if not self.is_connected:
|
||||
await self.start()
|
||||
payload = self.packer(data)
|
||||
logger.debug("sending payload %(payload)b to topic %(topic)s", {"payload": payload, "topic": topic})
|
||||
await self.socket.send_multipart([topic.encode(), payload])
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect to the zmq socket."""
|
||||
logger.info("Connecting to %(addr)s", {"addr": self.config.publisher_address})
|
||||
self.socket.connect(self.config.publisher_address)
|
||||
self.is_connected = True
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Close the zmq socket."""
|
||||
self.socket.close()
|
||||
self.is_connected = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of ZmqSink."""
|
||||
return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})"
|
||||
@@ -1,2 +0,0 @@
|
||||
verbose: True
|
||||
print_stdout: False
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user