mirror of
https://github.com/OMGeeky/flucto-heisskleber.git
synced 2025-12-26 16:07:50 +01:00
Handle payload based on type in all senders.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from heisskleber.core import Packer, Sender, json_packer
|
||||
from heisskleber.core import Packer, Sender
|
||||
|
||||
from .config import ConsoleConf
|
||||
|
||||
@@ -13,7 +14,7 @@ class ConsoleSender(Sender[T]):
|
||||
def __init__(
|
||||
self,
|
||||
config: ConsoleConf,
|
||||
packer: Packer[T] = json_packer, # type: ignore[assignment]
|
||||
packer: Packer[T] = lambda x: json.dumps(x), # type: ignore[assignment]
|
||||
) -> None:
|
||||
self.verbose = config.verbose
|
||||
self.pretty = config.pretty
|
||||
@@ -22,7 +23,8 @@ class ConsoleSender(Sender[T]):
|
||||
async def send(self, data: T, topic: str | None = None, **kwargs: dict[str, Any]) -> None:
|
||||
"""Serialize data and write to console output."""
|
||||
serialized = self.packer(data)
|
||||
output = f"{topic}:\t{serialized.decode()}" if topic else serialized.decode()
|
||||
serialized = serialized.decode() if isinstance(serialized, bytes | bytearray) else serialized
|
||||
output = f"{topic}:\t{serialized}" if topic else serialized
|
||||
print(output) # noqa: T201
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from typing import Any
|
||||
|
||||
from .config import BaseConf, ConfigType
|
||||
from .packer import JSONPacker, Packer, PackerError
|
||||
from .packer import JSONPacker, Packer, PackerError, Payload
|
||||
from .receiver import Receiver
|
||||
from .sender import Sender
|
||||
from .unpacker import JSONUnpacker, Unpacker, UnpackerError
|
||||
@@ -29,6 +29,7 @@ __all__ = [
|
||||
"ConfigType",
|
||||
"Packer",
|
||||
"PackerError",
|
||||
"Payload",
|
||||
"Receiver",
|
||||
"Sender",
|
||||
"Unpacker",
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Protocol, TypeVar
|
||||
from typing import Any, Protocol, TypeAlias, TypeVar
|
||||
|
||||
T_contra = TypeVar("T_contra", contravariant=True)
|
||||
|
||||
Payload: TypeAlias = str | bytes | bytearray
|
||||
|
||||
|
||||
class PackerError(Exception):
|
||||
"""Raised when unpacking operations fail.
|
||||
@@ -38,7 +40,7 @@ class Packer(Protocol[T_contra]):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, data: T_contra) -> bytes:
|
||||
def __call__(self, data: T_contra) -> Payload:
|
||||
"""Packs the data dictionary into a bytes payload.
|
||||
|
||||
Arguments:
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Any, Protocol, TypeVar
|
||||
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
Payload = str | bytes | bytearray
|
||||
|
||||
|
||||
class UnpackerError(Exception):
|
||||
"""Raised when unpacking operations fail.
|
||||
@@ -20,10 +22,11 @@ class UnpackerError(Exception):
|
||||
|
||||
PREVIEW_LENGTH = 100
|
||||
|
||||
def __init__(self, payload: bytes) -> None:
|
||||
def __init__(self, payload: Payload) -> None:
|
||||
"""Initialize the error with the failed payload and cause."""
|
||||
self.payload = payload
|
||||
preview = payload[: self.PREVIEW_LENGTH] + b"..." if len(payload) > self.PREVIEW_LENGTH else payload
|
||||
dots = b"..." if isinstance(payload, bytes | bytearray) else "..."
|
||||
preview = payload[: self.PREVIEW_LENGTH] + dots if len(payload) > self.PREVIEW_LENGTH else payload
|
||||
message = f"Failed to unpack payload: {preview!r}. "
|
||||
super().__init__(message)
|
||||
|
||||
@@ -37,7 +40,7 @@ class Unpacker(Protocol[T_co]):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, payload: bytes) -> tuple[T_co, dict[str, Any]]:
|
||||
def __call__(self, payload: Payload) -> tuple[T_co, dict[str, Any]]:
|
||||
"""Unpacks the payload into a data object and optional meta-data dictionary.
|
||||
|
||||
Args:
|
||||
@@ -76,7 +79,7 @@ class JSONUnpacker(Unpacker[dict[str, Any]]):
|
||||
|
||||
"""
|
||||
|
||||
def __call__(self, payload: bytes) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
def __call__(self, payload: Payload) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Unpack the payload."""
|
||||
try:
|
||||
return json.loads(payload), {}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
from collections.abc import Coroutine
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, ParamSpec, TypeVar
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from io import BufferedWriter
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from heisskleber.core import Packer, Sender, json_packer
|
||||
from heisskleber.core import Packer, PackerError, Sender
|
||||
from heisskleber.file.config import FileConf
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -15,6 +17,19 @@ T = TypeVar("T")
|
||||
logger = logging.getLogger("heisskleber.file")
|
||||
|
||||
|
||||
def json_packer(data: dict[str, Any]) -> str:
|
||||
"""Pack to json string."""
|
||||
try:
|
||||
return json.dumps(data)
|
||||
except json.JSONDecodeError:
|
||||
raise PackerError(data) from None
|
||||
|
||||
|
||||
def csv_packer(data: dict[str, Any]) -> str:
|
||||
"""Create csv string from data."""
|
||||
return ",".join(map(str, data.values()))
|
||||
|
||||
|
||||
class FileWriter(Sender[T]):
|
||||
"""Asynchronous file writer implementation of the Sender interface.
|
||||
|
||||
@@ -22,34 +37,48 @@ class FileWriter(Sender[T]):
|
||||
Files are named according to the configured datetime format.
|
||||
"""
|
||||
|
||||
def __init__(self, config: FileConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment]
|
||||
def __init__(
|
||||
self,
|
||||
config: FileConf,
|
||||
packer: Packer[T] = json_packer, # type: ignore[assignment]
|
||||
header_func: Callable[[T], list[str]] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the file writer.
|
||||
|
||||
Args:
|
||||
base_path: Directory path where files will be written
|
||||
config: Configuration for file rollover and naming
|
||||
header_func: Function to extract header from T
|
||||
packer: Optional packer for serializing data
|
||||
"""
|
||||
self.base_path = Path(config.directory)
|
||||
self.config = config
|
||||
self.header_func = header_func
|
||||
self.packer = packer
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._header: list[str] | None = None
|
||||
|
||||
self._current_file: BufferedWriter | None = None
|
||||
self._rollover_task: asyncio.Task | None = None
|
||||
self._current_file: TextIOWrapper | None = None
|
||||
self._rollover_task: asyncio.Task[None] | None = None
|
||||
self._last_rollover: float = 0
|
||||
self.filename: Path = Path()
|
||||
|
||||
async def _open_file(self, filename: Path) -> BufferedWriter:
|
||||
async def _open_file(self, filename: Path) -> TextIOWrapper:
|
||||
"""Open file asynchronously."""
|
||||
return await self._loop.run_in_executor(self._executor, lambda: filename.open(mode="ab"))
|
||||
return await self._loop.run_in_executor(self._executor, lambda: filename.open(mode="a"))
|
||||
|
||||
async def _close_file(self) -> None:
|
||||
if self._current_file is not None:
|
||||
await self._loop.run_in_executor(self._executor, self._current_file.close)
|
||||
|
||||
async def _write_header(self) -> None:
|
||||
if not self._header or not self._current_file:
|
||||
return
|
||||
for line in self._header:
|
||||
await self._loop.run_in_executor(self._executor, self._current_file.write, line)
|
||||
await self._loop.run_in_executor(self._executor, self._current_file.write, "\n")
|
||||
|
||||
async def _rollover(self) -> None:
|
||||
"""Close current file and open a new one."""
|
||||
if self._current_file is not None:
|
||||
@@ -60,6 +89,7 @@ class FileWriter(Sender[T]):
|
||||
self._current_file = await self._open_file(self.filename)
|
||||
self._last_rollover = self._loop.time()
|
||||
logger.info("Rolled over to new file: %s", self.filename)
|
||||
await self._write_header()
|
||||
|
||||
async def _rollover_loop(self) -> None:
|
||||
"""Background task that handles periodic file rollover."""
|
||||
@@ -83,10 +113,15 @@ class FileWriter(Sender[T]):
|
||||
await self.start()
|
||||
if not self._current_file:
|
||||
raise RuntimeError("FileWriter not started")
|
||||
if not self._header and self.header_func is not None:
|
||||
self._header = self.header_func(data)
|
||||
await self._write_header()
|
||||
|
||||
payload = self.packer(data)
|
||||
if isinstance(payload, bytes | bytearray):
|
||||
payload = payload.decode()
|
||||
await self._loop.run_in_executor(self._executor, self._current_file.write, payload)
|
||||
await self._loop.run_in_executor(self._executor, self._current_file.write, b"\n")
|
||||
await self._loop.run_in_executor(self._executor, self._current_file.write, "\n")
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the file writer and rollover background task."""
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any, TypeVar
|
||||
|
||||
import aiomqtt
|
||||
|
||||
from heisskleber.core import Packer, Sender, json_packer
|
||||
from heisskleber.core import Packer, Payload, Sender, json_packer
|
||||
from heisskleber.core.utils import retry
|
||||
|
||||
from .config import MqttConf
|
||||
@@ -34,7 +34,7 @@ class MqttSender(Sender[T]):
|
||||
def __init__(self, config: MqttConf, packer: Packer[T] = json_packer) -> None: # type: ignore[assignment]
|
||||
self.config = config
|
||||
self.packer = packer
|
||||
self._send_queue: asyncio.Queue[tuple[bytes, str]] = asyncio.Queue()
|
||||
self._send_queue: asyncio.Queue[tuple[Payload, str]] = asyncio.Queue()
|
||||
self._sender_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def send(self, data: T, topic: str = "mqtt", qos: int = 0, retain: bool = False, **kwargs: Any) -> None:
|
||||
|
||||
@@ -54,6 +54,7 @@ class SerialSender(Sender[T]):
|
||||
await self.start()
|
||||
|
||||
payload = self.packer(data)
|
||||
payload = payload.encode() if isinstance(payload, str) else payload
|
||||
try:
|
||||
await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.write, payload)
|
||||
await asyncio.get_running_loop().run_in_executor(self._executor, self._ser.flush)
|
||||
|
||||
@@ -81,6 +81,7 @@ class UdpSender(Sender[T]):
|
||||
"""
|
||||
await self._ensure_connection() # we know that self._transport is intialized
|
||||
payload = self.pack(data)
|
||||
payload = payload.encode() if isinstance(payload, str) else payload
|
||||
self._transport.sendto(payload) # type: ignore [union-attr]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@@ -71,7 +71,7 @@ class ZmqReceiver(Receiver[T]):
|
||||
topic: The topic or list of topics to subscribe to.
|
||||
|
||||
"""
|
||||
if isinstance(topic, (list, tuple)):
|
||||
if isinstance(topic, list | tuple):
|
||||
for t in topic:
|
||||
self._subscribe_single_topic(t)
|
||||
else:
|
||||
|
||||
@@ -35,6 +35,7 @@ class ZmqSender(Sender[T]):
|
||||
if not self.is_connected:
|
||||
await self.start()
|
||||
payload = self.packer(data)
|
||||
payload = payload.encode() if isinstance(payload, str) else payload
|
||||
logger.debug("sending payload %(payload)b to topic %(topic)s", {"payload": payload, "topic": topic})
|
||||
await self.socket.send_multipart([topic.encode(), payload])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user