Handle payload based on type in all senders.

This commit is contained in:
Felix Weiler-Detjen
2025-01-23 16:28:10 +00:00
parent 1dd91517a7
commit 8bd332d27c
11 changed files with 70 additions and 24 deletions

View File

@@ -1,6 +1,7 @@
import json
from typing import Any, TypeVar
from heisskleber.core import Packer, Sender, json_packer
from heisskleber.core import Packer, Sender
from .config import ConsoleConf
@@ -13,7 +14,7 @@ class ConsoleSender(Sender[T]):
def __init__(
self,
config: ConsoleConf,
packer: Packer[T] = json_packer, # type: ignore[assignment]
packer: Packer[T] = lambda x: json.dumps(x), # type: ignore[assignment]
) -> None:
self.verbose = config.verbose
self.pretty = config.pretty
@@ -22,7 +23,8 @@ class ConsoleSender(Sender[T]):
async def send(self, data: T, topic: str | None = None, **kwargs: dict[str, Any]) -> None:
"""Serialize data and write to console output."""
serialized = self.packer(data)
output = f"{topic}:\t{serialized.decode()}" if topic else serialized.decode()
serialized = serialized.decode() if isinstance(serialized, bytes | bytearray) else serialized
output = f"{topic}:\t{serialized}" if topic else serialized
print(output) # noqa: T201
def __repr__(self) -> str:

View File

@@ -3,7 +3,7 @@
from typing import Any
from .config import BaseConf, ConfigType
from .packer import JSONPacker, Packer, PackerError
from .packer import JSONPacker, Packer, PackerError, Payload
from .receiver import Receiver
from .sender import Sender
from .unpacker import JSONUnpacker, Unpacker, UnpackerError
@@ -29,6 +29,7 @@ __all__ = [
"ConfigType",
"Packer",
"PackerError",
"Payload",
"Receiver",
"Sender",
"Unpacker",

View File

@@ -2,10 +2,12 @@
import json
from abc import abstractmethod
from typing import Any, Protocol, TypeVar
from typing import Any, Protocol, TypeAlias, TypeVar
T_contra = TypeVar("T_contra", contravariant=True)
Payload: TypeAlias = str | bytes | bytearray
class PackerError(Exception):
"""Raised when unpacking operations fail.
@@ -38,7 +40,7 @@ class Packer(Protocol[T_contra]):
"""
@abstractmethod
def __call__(self, data: T_contra) -> bytes:
def __call__(self, data: T_contra) -> Payload:
"""Packs the data dictionary into a bytes payload.
Arguments:

View File

@@ -6,6 +6,8 @@ from typing import Any, Protocol, TypeVar
T_co = TypeVar("T_co", covariant=True)
Payload = str | bytes | bytearray
class UnpackerError(Exception):
"""Raised when unpacking operations fail.
@@ -20,10 +22,11 @@ class UnpackerError(Exception):
PREVIEW_LENGTH = 100
def __init__(self, payload: bytes) -> None:
def __init__(self, payload: Payload) -> None:
"""Initialize the error with the failed payload and cause."""
self.payload = payload
preview = payload[: self.PREVIEW_LENGTH] + b"..." if len(payload) > self.PREVIEW_LENGTH else payload
dots = b"..." if isinstance(payload, bytes | bytearray) else "..."
preview = payload[: self.PREVIEW_LENGTH] + dots if len(payload) > self.PREVIEW_LENGTH else payload
message = f"Failed to unpack payload: {preview!r}. "
super().__init__(message)
@@ -37,7 +40,7 @@ class Unpacker(Protocol[T_co]):
"""
@abstractmethod
def __call__(self, payload: bytes) -> tuple[T_co, dict[str, Any]]:
def __call__(self, payload: Payload) -> tuple[T_co, dict[str, Any]]:
"""Unpacks the payload into a data object and optional meta-data dictionary.
Args:
@@ -76,7 +79,7 @@ class JSONUnpacker(Unpacker[dict[str, Any]]):
"""
def __call__(self, payload: bytes) -> tuple[dict[str, Any], dict[str, Any]]:
def __call__(self, payload: Payload) -> tuple[dict[str, Any], dict[str, Any]]:
"""Unpack the payload."""
try:
return json.loads(payload), {}

View File

@@ -1,7 +1,7 @@
import asyncio
from collections.abc import Coroutine
from collections.abc import Callable, Coroutine
from functools import wraps
from typing import Any, Callable, ParamSpec, TypeVar
from typing import Any, ParamSpec, TypeVar
P = ParamSpec("P")
T = TypeVar("T")

View File

@@ -1,13 +1,15 @@
import asyncio
import contextlib
import json
import logging
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from io import BufferedWriter
from io import TextIOWrapper
from pathlib import Path
from typing import Any, TypeVar
from heisskleber.core import Packer, Sender, json_packer
from heisskleber.core import Packer, PackerError, Sender
from heisskleber.file.config import FileConf
T = TypeVar("T")
@@ -15,6 +17,19 @@ T = TypeVar("T")
logger = logging.getLogger("heisskleber.file")
def json_packer(data: dict[str, Any]) -> str:
    """Pack *data* to a JSON string.

    Args:
        data: Mapping to serialize.

    Returns:
        The JSON-encoded string.

    Raises:
        PackerError: If the data cannot be serialized to JSON.
    """
    try:
        return json.dumps(data)
    except (TypeError, ValueError):
        # json.dumps raises TypeError for unserializable objects and
        # ValueError for circular references — NOT json.JSONDecodeError,
        # which is only raised by json.loads. The original except clause
        # could never fire, so serialization errors escaped unwrapped.
        raise PackerError(data) from None
def csv_packer(data: dict[str, Any]) -> str:
    """Create csv string from data.

    Each value of *data* is converted with ``str`` and the results are
    joined with commas, in dictionary insertion order.
    """
    stringified = (str(value) for value in data.values())
    return ",".join(stringified)
class FileWriter(Sender[T]):
"""Asynchronous file writer implementation of the Sender interface.
@@ -22,34 +37,48 @@ class FileWriter(Sender[T]):
Files are named according to the configured datetime format.
"""
def __init__(self, config: FileConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment]
def __init__(
self,
config: FileConf,
packer: Packer[T] = json_packer, # type: ignore[assignment]
header_func: Callable[[T], list[str]] | None = None,
) -> None:
"""Initialize the file writer.
Args:
base_path: Directory path where files will be written
config: Configuration for file rollover and naming
header_func: Function to extract header from T
packer: Optional packer for serializing data
"""
self.base_path = Path(config.directory)
self.config = config
self.header_func = header_func
self.packer = packer
self._executor = ThreadPoolExecutor(max_workers=1)
self._loop = asyncio.get_running_loop()
self._header: list[str] | None = None
self._current_file: BufferedWriter | None = None
self._rollover_task: asyncio.Task | None = None
self._current_file: TextIOWrapper | None = None
self._rollover_task: asyncio.Task[None] | None = None
self._last_rollover: float = 0
self.filename: Path = Path()
async def _open_file(self, filename: Path) -> BufferedWriter:
async def _open_file(self, filename: Path) -> TextIOWrapper:
"""Open file asynchronously."""
return await self._loop.run_in_executor(self._executor, lambda: filename.open(mode="ab"))
return await self._loop.run_in_executor(self._executor, lambda: filename.open(mode="a"))
async def _close_file(self) -> None:
if self._current_file is not None:
await self._loop.run_in_executor(self._executor, self._current_file.close)
async def _write_header(self) -> None:
if not self._header or not self._current_file:
return
for line in self._header:
await self._loop.run_in_executor(self._executor, self._current_file.write, line)
await self._loop.run_in_executor(self._executor, self._current_file.write, "\n")
async def _rollover(self) -> None:
"""Close current file and open a new one."""
if self._current_file is not None:
@@ -60,6 +89,7 @@ class FileWriter(Sender[T]):
self._current_file = await self._open_file(self.filename)
self._last_rollover = self._loop.time()
logger.info("Rolled over to new file: %s", self.filename)
await self._write_header()
async def _rollover_loop(self) -> None:
"""Background task that handles periodic file rollover."""
@@ -83,10 +113,15 @@ class FileWriter(Sender[T]):
await self.start()
if not self._current_file:
raise RuntimeError("FileWriter not started")
if not self._header and self.header_func is not None:
self._header = self.header_func(data)
await self._write_header()
payload = self.packer(data)
if isinstance(payload, bytes | bytearray):
payload = payload.decode()
await self._loop.run_in_executor(self._executor, self._current_file.write, payload)
await self._loop.run_in_executor(self._executor, self._current_file.write, b"\n")
await self._loop.run_in_executor(self._executor, self._current_file.write, "\n")
async def start(self) -> None:
"""Start the file writer and rollover background task."""

View File

@@ -8,7 +8,7 @@ from typing import Any, TypeVar
import aiomqtt
from heisskleber.core import Packer, Sender, json_packer
from heisskleber.core import Packer, Payload, Sender, json_packer
from heisskleber.core.utils import retry
from .config import MqttConf
@@ -34,7 +34,7 @@ class MqttSender(Sender[T]):
def __init__(self, config: MqttConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment]
self.config = config
self.packer = packer
self._send_queue: asyncio.Queue[tuple[bytes, str]] = asyncio.Queue()
self._send_queue: asyncio.Queue[tuple[Payload, str]] = asyncio.Queue()
self._sender_task: asyncio.Task[None] | None = None
async def send(self, data: T, topic: str = "mqtt", qos: int = 0, retain: bool = False, **kwargs: Any) -> None:

View File

@@ -54,6 +54,7 @@ class SerialSender(Sender[T]):
await self.start()
payload = self.packer(data)
payload = payload.encode() if isinstance(payload, str) else payload
try:
await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.write, payload)
await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.flush)

View File

@@ -81,6 +81,7 @@ class UdpSender(Sender[T]):
"""
await self._ensure_connection() # we know that self._transport is intialized
payload = self.pack(data)
payload = payload.encode() if isinstance(payload, str) else payload
self._transport.sendto(payload) # type: ignore [union-attr]
def __repr__(self) -> str:

View File

@@ -71,7 +71,7 @@ class ZmqReceiver(Receiver[T]):
topic: The topic or list of topics to subscribe to.
"""
if isinstance(topic, (list, tuple)):
if isinstance(topic, list | tuple):
for t in topic:
self._subscribe_single_topic(t)
else:

View File

@@ -35,6 +35,7 @@ class ZmqSender(Sender[T]):
if not self.is_connected:
await self.start()
payload = self.packer(data)
payload = payload.encode() if isinstance(payload, str) else payload
logger.debug("sending payload %(payload)b to topic %(topic)s", {"payload": payload, "topic": topic})
await self.socket.send_multipart([topic.encode(), payload])