diff --git a/src/heisskleber/console/sender.py b/src/heisskleber/console/sender.py index da685bd..90f4796 100644 --- a/src/heisskleber/console/sender.py +++ b/src/heisskleber/console/sender.py @@ -1,6 +1,7 @@ +import json from typing import Any, TypeVar -from heisskleber.core import Packer, Sender, json_packer +from heisskleber.core import Packer, Sender from .config import ConsoleConf @@ -13,7 +14,7 @@ class ConsoleSender(Sender[T]): def __init__( self, config: ConsoleConf, - packer: Packer[T] = json_packer, # type: ignore[assignment] + packer: Packer[T] = lambda x: json.dumps(x), # type: ignore[assignment] ) -> None: self.verbose = config.verbose self.pretty = config.pretty @@ -22,7 +23,8 @@ class ConsoleSender(Sender[T]): async def send(self, data: T, topic: str | None = None, **kwargs: dict[str, Any]) -> None: """Serialize data and write to console output.""" serialized = self.packer(data) - output = f"{topic}:\t{serialized.decode()}" if topic else serialized.decode() + serialized = serialized.decode() if isinstance(serialized, bytes | bytearray) else serialized + output = f"{topic}:\t{serialized}" if topic else serialized print(output) # noqa: T201 def __repr__(self) -> str: diff --git a/src/heisskleber/core/__init__.py b/src/heisskleber/core/__init__.py index c678aec..6a3f1b5 100644 --- a/src/heisskleber/core/__init__.py +++ b/src/heisskleber/core/__init__.py @@ -3,7 +3,7 @@ from typing import Any from .config import BaseConf, ConfigType -from .packer import JSONPacker, Packer, PackerError +from .packer import JSONPacker, Packer, PackerError, Payload from .receiver import Receiver from .sender import Sender from .unpacker import JSONUnpacker, Unpacker, UnpackerError @@ -29,6 +29,7 @@ __all__ = [ "ConfigType", "Packer", "PackerError", + "Payload", "Receiver", "Sender", "Unpacker", diff --git a/src/heisskleber/core/packer.py b/src/heisskleber/core/packer.py index 3ea010e..9baaca4 100644 --- a/src/heisskleber/core/packer.py +++ b/src/heisskleber/core/packer.py @@ -2,10 +2,12 @@ import json from abc import abstractmethod -from typing import Any, Protocol, TypeVar +from typing import Any, Protocol, TypeAlias, TypeVar T_contra = TypeVar("T_contra", contravariant=True) +Payload: TypeAlias = str | bytes | bytearray + class PackerError(Exception): """Raised when unpacking operations fail. @@ -38,7 +40,7 @@ class Packer(Protocol[T_contra]): """ @abstractmethod - def __call__(self, data: T_contra) -> bytes: + def __call__(self, data: T_contra) -> Payload: """Packs the data dictionary into a bytes payload. Arguments: diff --git a/src/heisskleber/core/unpacker.py b/src/heisskleber/core/unpacker.py index 5fd305a..df94803 100644 --- a/src/heisskleber/core/unpacker.py +++ b/src/heisskleber/core/unpacker.py @@ -6,6 +6,8 @@ from typing import Any, Protocol, TypeVar T_co = TypeVar("T_co", covariant=True) +Payload = str | bytes | bytearray + class UnpackerError(Exception): """Raised when unpacking operations fail. @@ -20,10 +22,11 @@ class UnpackerError(Exception): PREVIEW_LENGTH = 100 - def __init__(self, payload: bytes) -> None: + def __init__(self, payload: Payload) -> None: """Initialize the error with the failed payload and cause.""" self.payload = payload - preview = payload[: self.PREVIEW_LENGTH] + b"..." if len(payload) > self.PREVIEW_LENGTH else payload + dots = b"..." if isinstance(payload, bytes | bytearray) else "..." + preview = payload[: self.PREVIEW_LENGTH] + dots if len(payload) > self.PREVIEW_LENGTH else payload message = f"Failed to unpack payload: {preview!r}. " super().__init__(message) @@ -37,7 +40,7 @@ class Unpacker(Protocol[T_co]): """ @abstractmethod - def __call__(self, payload: bytes) -> tuple[T_co, dict[str, Any]]: + def __call__(self, payload: Payload) -> tuple[T_co, dict[str, Any]]: """Unpacks the payload into a data object and optional meta-data dictionary. Args: @@ -76,7 +79,7 @@ class JSONUnpacker(Unpacker[dict[str, Any]]): """ - def __call__(self, payload: bytes) -> tuple[dict[str, Any], dict[str, Any]]: + def __call__(self, payload: Payload) -> tuple[dict[str, Any], dict[str, Any]]: """Unpack the payload.""" try: return json.loads(payload), {} diff --git a/src/heisskleber/core/utils.py b/src/heisskleber/core/utils.py index c42418b..61b0c08 100644 --- a/src/heisskleber/core/utils.py +++ b/src/heisskleber/core/utils.py @@ -1,7 +1,7 @@ import asyncio -from collections.abc import Coroutine +from collections.abc import Callable, Coroutine from functools import wraps -from typing import Any, Callable, ParamSpec, TypeVar +from typing import Any, ParamSpec, TypeVar P = ParamSpec("P") T = TypeVar("T") diff --git a/src/heisskleber/file/sender.py b/src/heisskleber/file/sender.py index 6b6e728..29224be 100644 --- a/src/heisskleber/file/sender.py +++ b/src/heisskleber/file/sender.py @@ -1,13 +1,15 @@ import asyncio import contextlib +import json import logging +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from datetime import datetime -from io import BufferedWriter +from io import TextIOWrapper from pathlib import Path from typing import Any, TypeVar -from heisskleber.core import Packer, Sender, json_packer +from heisskleber.core import Packer, PackerError, Sender from heisskleber.file.config import FileConf T = TypeVar("T") @@ -15,6 +17,19 @@ T = TypeVar("T") logger = logging.getLogger("heisskleber.file") +def json_packer(data: dict[str, Any]) -> str: + """Pack to json string.""" + try: + return json.dumps(data) + except json.JSONDecodeError: + raise PackerError(data) from None + + +def csv_packer(data: dict[str, Any]) -> str: + """Create csv string from data.""" + return ",".join(map(str, data.values())) + + class FileWriter(Sender[T]): """Asynchronous file writer implementation of the Sender interface. @@ -22,34 +37,48 @@ class FileWriter(Sender[T]): Files are named according to the configured datetime format. """ - def __init__(self, config: FileConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment] + def __init__( + self, + config: FileConf, + packer: Packer[T] = json_packer, # type: ignore[assignment] + header_func: Callable[[T], list[str]] | None = None, + ) -> None: """Initialize the file writer. Args: - base_path: Directory path where files will be written config: Configuration for file rollover and naming + header_func: Function to extract header from T packer: Optional packer for serializing data """ self.base_path = Path(config.directory) self.config = config + self.header_func = header_func self.packer = packer self._executor = ThreadPoolExecutor(max_workers=1) self._loop = asyncio.get_running_loop() + self._header: list[str] | None = None - self._current_file: BufferedWriter | None = None - self._rollover_task: asyncio.Task | None = None + self._current_file: TextIOWrapper | None = None + self._rollover_task: asyncio.Task[None] | None = None self._last_rollover: float = 0 self.filename: Path = Path() - async def _open_file(self, filename: Path) -> BufferedWriter: + async def _open_file(self, filename: Path) -> TextIOWrapper: """Open file asynchronously.""" - return await self._loop.run_in_executor(self._executor, lambda: filename.open(mode="ab")) + return await self._loop.run_in_executor(self._executor, lambda: filename.open(mode="a")) async def _close_file(self) -> None: if self._current_file is not None: await self._loop.run_in_executor(self._executor, self._current_file.close) + async def _write_header(self) -> None: + if not self._header or not self._current_file: + return + for line in self._header: + await self._loop.run_in_executor(self._executor, self._current_file.write, line) + await self._loop.run_in_executor(self._executor, self._current_file.write, "\n") + async def _rollover(self) -> None: """Close current file and open a new one.""" if self._current_file is not None: @@ -60,6 +89,7 @@ class FileWriter(Sender[T]): self._current_file = await self._open_file(self.filename) self._last_rollover = self._loop.time() logger.info("Rolled over to new file: %s", self.filename) + await self._write_header() async def _rollover_loop(self) -> None: """Background task that handles periodic file rollover.""" @@ -83,10 +113,15 @@ class FileWriter(Sender[T]): await self.start() if not self._current_file: raise RuntimeError("FileWriter not started") + if not self._header and self.header_func is not None: + self._header = self.header_func(data) + await self._write_header() payload = self.packer(data) + if isinstance(payload, bytes | bytearray): + payload = payload.decode() await self._loop.run_in_executor(self._executor, self._current_file.write, payload) - await self._loop.run_in_executor(self._executor, self._current_file.write, b"\n") + await self._loop.run_in_executor(self._executor, self._current_file.write, "\n") async def start(self) -> None: """Start the file writer and rollover background task.""" diff --git a/src/heisskleber/mqtt/sender.py b/src/heisskleber/mqtt/sender.py index 743d419..c627691 100644 --- a/src/heisskleber/mqtt/sender.py +++ b/src/heisskleber/mqtt/sender.py @@ -8,7 +8,7 @@ from typing import Any, TypeVar import aiomqtt -from heisskleber.core import Packer, Sender, json_packer +from heisskleber.core import Packer, Payload, Sender, json_packer from heisskleber.core.utils import retry from .config import MqttConf @@ -34,7 +34,7 @@ class MqttSender(Sender[T]): def __init__(self, config: MqttConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment] self.config = config self.packer = packer - self._send_queue: asyncio.Queue[tuple[bytes, str]] = asyncio.Queue() + self._send_queue: asyncio.Queue[tuple[Payload, str]] = asyncio.Queue() self._sender_task: asyncio.Task[None] | None = None async def send(self, data: T, topic: str = "mqtt", qos: int = 0, retain: bool = False, **kwargs: Any) -> None: diff --git a/src/heisskleber/serial/sender.py b/src/heisskleber/serial/sender.py index f433bdd..01e0a9a 100644 --- a/src/heisskleber/serial/sender.py +++ b/src/heisskleber/serial/sender.py @@ -54,6 +54,7 @@ class SerialSender(Sender[T]): await self.start() payload = self.packer(data) + payload = payload.encode() if isinstance(payload, str) else payload try: await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.write, payload) await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.flush) diff --git a/src/heisskleber/udp/sender.py b/src/heisskleber/udp/sender.py index 8693db4..cfc1952 100644 --- a/src/heisskleber/udp/sender.py +++ b/src/heisskleber/udp/sender.py @@ -81,6 +81,7 @@ class UdpSender(Sender[T]): """ await self._ensure_connection() # we know that self._transport is intialized payload = self.pack(data) + payload = payload.encode() if isinstance(payload, str) else payload self._transport.sendto(payload) # type: ignore [union-attr] def __repr__(self) -> str: diff --git a/src/heisskleber/zmq/receiver.py b/src/heisskleber/zmq/receiver.py index 4a80dfd..880d23d 100644 --- a/src/heisskleber/zmq/receiver.py +++ b/src/heisskleber/zmq/receiver.py @@ -71,7 +71,7 @@ class ZmqReceiver(Receiver[T]): topic: The topic or list of topics to subscribe to. """ - if isinstance(topic, (list, tuple)): + if isinstance(topic, list | tuple): for t in topic: self._subscribe_single_topic(t) else: diff --git a/src/heisskleber/zmq/sender.py b/src/heisskleber/zmq/sender.py index 955c393..99265be 100644 --- a/src/heisskleber/zmq/sender.py +++ b/src/heisskleber/zmq/sender.py @@ -35,6 +35,7 @@ class ZmqSender(Sender[T]): if not self.is_connected: await self.start() payload = self.packer(data) + payload = payload.encode() if isinstance(payload, str) else payload logger.debug("sending payload %(payload)b to topic %(topic)s", {"payload": payload, "topic": topic}) await self.socket.send_multipart([topic.encode(), payload])