From 78ba98a587f06e501bf0a3979f217458fe81d2c6 Mon Sep 17 00:00:00 2001 From: Felix Weiler Date: Mon, 22 Jan 2024 11:23:00 +0100 Subject: [PATCH] Feature/async sources (#46) * WIP: Added async file reader. * Async resampling and synchronization refactored. * Add async mqtt publisher. Remove queue from joint. * Add async zmq publisher and subscriber. * Modify integration tests for streaming. * Name refactoring resampler. * Added async source/sink to factory. * Refactor joint and add integration tests. * Add termcolor dev dependency * Add conosole source and sink * Add cli interface for different protocols * Removed files unfit for merge. * Fix review requests. * Restore use of $MSB_CONFIG_DIR for now. It seems that the default behaviour is not looking for .config/heisskleber * Remove version test, causing unnecessary failures. --- heisskleber/__init__.py | 9 +- heisskleber/console/__init__.py | 0 heisskleber/console/sink.py | 37 ++++++++ heisskleber/console/source.py | 34 +++++++ heisskleber/core/async_factories.py | 68 ++++++++++++++ heisskleber/core/types.py | 43 ++++++++- heisskleber/mqtt/__init__.py | 3 +- heisskleber/mqtt/publisher_async.py | 51 +++++++++++ heisskleber/mqtt/subscriber_async.py | 42 +++------ heisskleber/stream/joint.py | 119 ++++++++++++++----------- heisskleber/stream/resampler.py | 27 +++--- heisskleber/zmq/__init__.py | 6 +- heisskleber/zmq/publisher.py | 32 ++++++- heisskleber/zmq/subscriber.py | 47 +++++++++- poetry.lock | 36 +++++++- pyproject.toml | 5 +- run/cli.py | 34 +++++++ tests/integration/async_streamer.py | 15 +--- tests/integration/integration_joint.py | 13 ++- tests/integration/mqtt_async.py | 2 +- tests/integration/mqtt_pub.py | 51 +++++------ tests/integration/mqtt_stream.py | 7 +- tests/integration/mqtt_sub.py | 17 ++++ tests/integration/zmq_pub.py | 17 ++++ tests/stream/test_async_mqtt.py | 23 ++++- tests/test_joint.py | 7 +- tests/test_mqtt.py | 6 +- tests/test_streamer.py | 12 +-- tests/test_version.py | 7 -- tests/zmq/test_zmq_asyncio.py | 62 +++++++++++++ 30 files changed, 645 insertions(+), 187 deletions(-) create mode 100644 heisskleber/console/__init__.py create mode 100644 heisskleber/console/sink.py create mode 100644 heisskleber/console/source.py create mode 100644 heisskleber/core/async_factories.py create mode 100644 heisskleber/mqtt/publisher_async.py create mode 100644 run/cli.py create mode 100644 tests/integration/mqtt_sub.py create mode 100644 tests/integration/zmq_pub.py delete mode 100644 tests/test_version.py create mode 100644 tests/zmq/test_zmq_asyncio.py diff --git a/heisskleber/__init__.py b/heisskleber/__init__.py index 359c3cf..2c671e1 100644 --- a/heisskleber/__init__.py +++ b/heisskleber/__init__.py @@ -1,13 +1,18 @@ """Heisskleber.""" +from .core.async_factories import get_async_sink, get_async_source from .core.factories import get_publisher, get_sink, get_source, get_subscriber -from .core.types import Sink, Source +from .core.types import AsyncSink, AsyncSource, Sink, Source __all__ = [ "get_source", "get_sink", "get_publisher", "get_subscriber", + "get_async_source", + "get_async_sink", "Sink", "Source", + "AsyncSink", + "AsyncSource", ] -__version__ = "0.3.1" +__version__ = "0.4.0" diff --git a/heisskleber/console/__init__.py b/heisskleber/console/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/heisskleber/console/sink.py b/heisskleber/console/sink.py new file mode 100644 index 0000000..1db96f2 --- /dev/null +++ b/heisskleber/console/sink.py @@ -0,0 +1,37 @@ +import json +import sys +import time +from typing import TextIO + +from heisskleber.core.types import AsyncSink, Serializable, Sink + + +def pretty_print(data: dict[str, Serializable]) -> str: + return json.dumps(data, indent=4) + + +class ConsoleSink(Sink): + def __init__(self, stream: TextIO = sys.stdout, pretty: bool = False): + self.stream = stream + self.print = pretty_print if pretty else json.dumps + + def send(self, data: dict[str, Serializable], topic: str) -> None: + self.stream.write(self.print(data)) # type: ignore[operator] + self.stream.write("\n") + + +class AsyncConsoleSink(AsyncSink): + def __init__(self, stream: TextIO = sys.stdout, pretty: bool = False): + self.stream = stream + self.print = pretty_print if pretty else json.dumps + + async def send(self, data: dict[str, Serializable], topic: str) -> None: + self.stream.write(self.print(data)) # type: ignore[operator] + self.stream.write("\n") + + +if __name__ == "__main__": + sink = ConsoleSink() + while True: + sink.send({"test": "test"}, "test") + time.sleep(1) diff --git a/heisskleber/console/source.py b/heisskleber/console/source.py new file mode 100644 index 0000000..48f6242 --- /dev/null +++ b/heisskleber/console/source.py @@ -0,0 +1,34 @@ +import json +import sys +import time +from queue import SimpleQueue +from threading import Thread + +from heisskleber.core.types import Serializable, Source + + +class ConsoleSource(Source): + def __init__(self, topic: str | list[str] | tuple[str] = "console") -> None: + self.topic = "console" + self.queue = SimpleQueue() + self.listener_daemon = Thread(target=self.listener_task, daemon=True) + self.listener_daemon.start() + self.pack = json.loads + + def listener_task(self): + while True: + data = sys.stdin.readline() + payload = self.pack(data) + self.queue.put(payload) + + def receive(self) -> tuple[str, dict[str, Serializable]]: + data = self.queue.get() + return self.topic, data + + +if __name__ == "__main__": + console_source = ConsoleSource() + + while True: + print(console_source.receive()) + time.sleep(1) diff --git a/heisskleber/core/async_factories.py b/heisskleber/core/async_factories.py new file mode 100644 index 0000000..57aaf8f --- /dev/null +++ b/heisskleber/core/async_factories.py @@ -0,0 +1,68 @@ +import os + +from heisskleber.config import BaseConf, load_config +from heisskleber.mqtt import AsyncMqttPublisher, AsyncMqttSubscriber, MqttConf +from heisskleber.zmq import ZmqAsyncPublisher, ZmqAsyncSubscriber, ZmqConf + +from .types import AsyncSink, AsyncSource + +_registered_async_sinks: dict[str, tuple[type[AsyncSink], type[BaseConf]]] = { + "mqtt": (AsyncMqttPublisher, MqttConf), + "zmq": (ZmqAsyncPublisher, ZmqConf), +} + +_registered_async_sources: dict[str, tuple] = { + "mqtt": (AsyncMqttSubscriber, MqttConf), + "zmq": (ZmqAsyncSubscriber, ZmqConf), +} + + +def get_async_sink(name: str) -> AsyncSink: + """ + Factory function to create a sink object. + + Parameters: + name: Name of the sink to create. + config: Configuration object to use for the sink. + """ + + if name not in _registered_async_sinks: + error_message = f"{name} is not a registered Sink." + raise KeyError(error_message) + + pub_cls, conf_cls = _registered_async_sinks[name] + + if "MSB_CONFIG_DIR" in os.environ: + print(f"loading {name} config") + config = load_config(conf_cls(), name, read_commandline=False) + else: + print(f"using default {name} config") + config = conf_cls() + + return pub_cls(config) + + +def get_async_source(name: str, topic: str | list[str] | tuple[str]) -> AsyncSource: + """ + Factory function to create a source object. + + Parameters: + name: Name of the source to create. + config: Configuration object to use for the source. + topic: Topic to subscribe to. + """ + + if name not in _registered_async_sources: + error_message = f"{name} is not a registered Source." + raise KeyError(error_message) + + sub_cls, conf_cls = _registered_async_sources[name] + + if "MSB_CONFIG_DIR" in os.environ: + print(f"loading {name} config") + config = load_config(conf_cls(), name, read_commandline=False) + else: + print(f"using default {name} config") + config = conf_cls() + + return sub_cls(config, topic) diff --git a/heisskleber/core/types.py b/heisskleber/core/types.py index 15fa7bf..3009dfc 100644 --- a/heisskleber/core/types.py +++ b/heisskleber/core/types.py @@ -14,7 +14,6 @@ class Sink(ABC): """ pack: Callable[[dict[str, Serializable]], str] - config: BaseConf @abstractmethod def __init__(self, config: BaseConf) -> None: @@ -37,7 +36,6 @@ class Source(ABC): """ unpack: Callable[[str], dict[str, Serializable]] - config: BaseConf @abstractmethod def __init__(self, config: BaseConf, topic: str | list[str]) -> None: @@ -77,9 +75,46 @@ class AsyncSubscriber(ABC): """ pass + +class AsyncSource(ABC): + """ + AsyncSubscriber interface + """ + @abstractmethod - def run(self) -> None: + def __init__(self, config: Any, topic: str | list[str]) -> None: """ - Run the subscriber loop. + Initialize the subscriber with a topic and a configuration object. + """ + pass + + @abstractmethod + async def receive(self) -> tuple[str, dict[str, Serializable]]: + """ + Blocking function to receive data from the implemented input stream. + + Data is returned as a tuple of (topic, data). + """ + pass + + +class AsyncSink(ABC): + """ + Sink interface to send() data to. + """ + + pack: Callable[[dict[str, Serializable]], str] + + @abstractmethod + def __init__(self, config: BaseConf) -> None: + """ + Initialize the publisher with a configuration object. + """ + pass + + @abstractmethod + async def send(self, data: dict[str, Any], topic: str) -> None: + """ + Send data via the implemented output stream. """ pass diff --git a/heisskleber/mqtt/__init__.py b/heisskleber/mqtt/__init__.py index 1794bb5..4efa290 100644 --- a/heisskleber/mqtt/__init__.py +++ b/heisskleber/mqtt/__init__.py @@ -1,6 +1,7 @@ from .config import MqttConf from .publisher import MqttPublisher +from .publisher_async import AsyncMqttPublisher from .subscriber import MqttSubscriber from .subscriber_async import AsyncMqttSubscriber -__all__ = ["MqttConf", "MqttPublisher", "MqttSubscriber", "AsyncMqttSubscriber"] +__all__ = ["MqttConf", "MqttPublisher", "MqttSubscriber", "AsyncMqttSubscriber", "AsyncMqttPublisher"] diff --git a/heisskleber/mqtt/publisher_async.py b/heisskleber/mqtt/publisher_async.py new file mode 100644 index 0000000..0b79799 --- /dev/null +++ b/heisskleber/mqtt/publisher_async.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import asyncio +from typing import Any + +import aiomqtt + +from heisskleber.core.packer import get_packer +from heisskleber.core.types import AsyncSink + +from .config import MqttConf + + +class AsyncMqttPublisher(AsyncSink): + """ + MQTT publisher class. + Can be used everywhere that a flucto style publishing connection is required. + + Network message loop is handled in a separated thread. + """ + + def __init__(self, config: MqttConf) -> None: + self.config = config + self.pack = get_packer(config.packstyle) + self._send_queue = asyncio.Queue() + self._sender_task = asyncio.create_task(self.send_work()) + + async def send(self, data: dict[str, Any], topic: str) -> None: + """ + Takes python dictionary, serializes it according to the packstyle + and sends it to the broker. + + Publishing is asynchronous + """ + + await self._send_queue.put((data, topic)) + + async def send_work(self) -> None: + """ + Takes python dictionary, serializes it according to the packstyle + and sends it to the broker. + + Publishing is asynchronous + """ + async with aiomqtt.Client( + hostname=self.config.broker, port=self.config.port, username=self.config.user, password=self.config.password + ) as client: + while True: + data, topic = await self._send_queue.get() + payload = self.pack(data) + await client.publish(topic, payload) diff --git a/heisskleber/mqtt/subscriber_async.py b/heisskleber/mqtt/subscriber_async.py index 9028293..901db09 100644 --- a/heisskleber/mqtt/subscriber_async.py +++ b/heisskleber/mqtt/subscriber_async.py @@ -1,18 +1,16 @@ -from asyncio import Queue, sleep +from asyncio import Queue, Task, create_task, sleep from aiomqtt import Client, Message, MqttError from heisskleber.core.packer import get_unpacker -from heisskleber.core.types import AsyncSubscriber, Serializable +from heisskleber.core.types import AsyncSource, Serializable from heisskleber.mqtt import MqttConf -class AsyncMqttSubscriber(AsyncSubscriber): +class AsyncMqttSubscriber(AsyncSource): """Asynchronous MQTT susbsciber based on aiomqtt. - Data is received by one of two methods: - 1. The `receive` method returns the newest message in the queue. For this to work, the `run` method must be called in a separae task. - 2. The `generator` method is a generator function that yields a topic and dict payload. + Data is received by the `receive` method returns the newest message in the queue. """ def __init__(self, config: MqttConf, topic: str | list[str]) -> None: @@ -26,6 +24,7 @@ class AsyncMqttSubscriber(AsyncSubscriber): self.topics = topic self.unpack = get_unpacker(self.config.packstyle) self.message_queue: Queue[Message] = Queue(self.config.max_saved_messages) + self._listener_task: Task = create_task(self.create_listener()) """ Await the newest message in the queue and return Tuple @@ -35,41 +34,21 @@ class AsyncMqttSubscriber(AsyncSubscriber): mqtt_message: Message = await self.message_queue.get() return self._handle_message(mqtt_message) - """ - Listen to incoming messages asynchronously and put them into a queue - """ - - async def run(self) -> None: - """ - Run the async mqtt listening loop. - Includes reconnecting to mqtt broker. - """ - # Manage connection to mqtt + async def create_listener(self): while True: try: async with self.client: await self._subscribe_topics() await self._listen_mqtt_loop() - except MqttError: + except MqttError as e: + print(f"MqttError: {e}") print("Connection to MQTT failed. Retrying...") await sleep(1) """ - Generator function that yields topic and dict payload. + Listen to incoming messages asynchronously and put them into a queue """ - async def generator(self): - while True: - try: - async with self.client: - await self._subscribe_topics() - async with self.client.messages() as messages: - async for message in messages: - yield self._handle_message(message) - except MqttError: - print("Connection to MQTT failed. Retrying...") - await sleep(1) - async def _listen_mqtt_loop(self) -> None: async with self.client.messages() as messages: # async with self.client.filtered_messages(self.topics) as messages: @@ -77,10 +56,11 @@ class AsyncMqttSubscriber(AsyncSubscriber): await self.message_queue.put(message) def _handle_message(self, message: Message) -> tuple[str, dict[str, Serializable]]: - topic = str(message.topic) if not isinstance(message.payload, bytes): error_msg = "Payload is not of type bytes." raise TypeError(error_msg) + + topic = str(message.topic) message_returned = self.unpack(message.payload.decode()) return (topic, message_returned) diff --git a/heisskleber/stream/joint.py b/heisskleber/stream/joint.py index 8a275d8..f55114d 100644 --- a/heisskleber/stream/joint.py +++ b/heisskleber/stream/joint.py @@ -1,6 +1,5 @@ import asyncio -from heisskleber.core.types import AsyncSubscriber from heisskleber.stream.resampler import Resampler, ResamplerConf @@ -18,65 +17,83 @@ class Joint: """ - def __init__(self, conf: ResamplerConf, subscribers: list[AsyncSubscriber]): + def __init__(self, conf: ResamplerConf, resamplers: list[Resampler]): self.conf = conf - self.subscribers = subscribers - self.generators = [] - self.resampler_timestamps = [] - self.latest_timestamp = 0 - self.latest_data = {} - self.tasks = [] + self.resamplers = resamplers + self.output_queue = asyncio.Queue() + self.initialized = asyncio.Event() + self.initalize_task = asyncio.create_task(self.sync()) + self.output_task = asyncio.create_task(self.output_work()) - async def receive(self): - old_value = self.latest_data.copy() - await self._update() - return old_value + self.combined_dict = {} - async def generate(self): - while True: - yield self.latest_data - await self._update() + """ + Main interaction coroutine: Get next value out of the queue. + """ - """Set up the streamer joint, which will activate all subscribers.""" + async def receive(self) -> dict: + return await self.output_queue.get() - async def setup(self): - for sub in self.subscribers: - # Start an async task to run the subscriber loop - task = asyncio.create_task(sub.run()) - self.tasks.append(task) - self.generators.append(Resampler(self.conf, sub).resample()) - - await self._synchronize() - - async def _synchronize(self): + async def sync(self) -> None: + print("Starting sync") + datas = await asyncio.gather(*[source.receive() for source in self.resamplers]) + output_data = {} data = {} - # first pass to initialize resamplers - for resampler in self.generators: - data = await anext(resampler) - self.resampler_timestamps.append(data["epoch"]) - if data["epoch"] > self.latest_timestamp: - self.latest_timestamp = data["epoch"] - self.latest_data = dict(data) + latest_timestamp: float = 0.0 + timestamps = [] - for resampler, timestamp in zip(self.generators, self.resampler_timestamps): - if timestamp == self.latest_timestamp: - continue + print("Syncing...") + for data in datas: + if not isinstance(data["epoch"], float): + error = "Timestamps must be floats" + raise TypeError(error) - while timestamp < self.latest_timestamp: - data = await anext(resampler) - timestamp = data["epoch"] + ts = float(data["epoch"]) - self.latest_data.update(data) + print(f"Syncing..., got {ts}") - async def _update(self): - data: dict = {} - for resampler in self.generators: - try: - data = await anext(resampler) + timestamps.append(ts) + if ts > latest_timestamp: + latest_timestamp = ts - if data["epoch"] >= self.latest_timestamp: - self.latest_timestamp = data["epoch"] - self.latest_data.update(data) - except Exception: - print(Exception) + # only take the piece of the latest data + output_data = data + + for resampler, ts in zip(self.resamplers, timestamps): + while ts < latest_timestamp: + data = await resampler.receive() + ts = float(data["epoch"]) + + output_data.update(data) + + await self.output_queue.put(output_data) + + print("Finished initalization") + self.initialized.set() + + """ + Coroutine that waits for new queue data and updates dict. + """ + + async def update_dict(self, resampler): + # queue is passed by reference, python y u so weird! + data = await resampler.receive() + if self.combined_dict and self.combined_dict["epoch"] != data["epoch"]: + print("Oh shit, this is bad!") + self.combined_dict.update(data) + + """ + Output worker: iterate through queues, read data and join into output queue. + """ + + async def output_work(self): + print("Output worker waiting for intitialization") + await self.initialized.wait() + print("Output worker resuming") + + while True: + self.combined_dict = {} + tasks = [asyncio.create_task(self.update_dict(res)) for res in self.resamplers] + await asyncio.gather(*tasks) + await self.output_queue.put(self.combined_dict) diff --git a/heisskleber/stream/resampler.py b/heisskleber/stream/resampler.py index 1742c56..01f0f1a 100644 --- a/heisskleber/stream/resampler.py +++ b/heisskleber/stream/resampler.py @@ -1,10 +1,11 @@ import math -from collections.abc import AsyncGenerator, Generator +from asyncio import Queue, Task, create_task +from collections.abc import Generator from datetime import datetime, timedelta import numpy as np -from heisskleber.core.types import AsyncSubscriber, Serializable +from heisskleber.core.types import AsyncSource, Serializable from .config import ResamplerConf @@ -27,7 +28,7 @@ def timestamp_generator(start_epoch: float, timedelta_in_ms: int) -> Generator[f next_timestamp += delta -def interpolate(t1, y1, t2, y2, t_target): +def interpolate(t1, y1, t2, y2, t_target) -> list[float]: """Perform linear interpolation between two data points.""" y1, y2 = np.array(y1), np.array(y2) fraction = (t_target - t1) / (t2 - t1) @@ -47,13 +48,18 @@ class Resampler: Asynchronous Subscriber """ - def __init__(self, config: ResamplerConf, subscriber: AsyncSubscriber) -> None: + def __init__(self, config: ResamplerConf, subscriber: AsyncSource) -> None: self.config = config self.subscriber = subscriber self.resample_rate = self.config.resample_rate self.delta_t = round(self.resample_rate / 1_000, 3) + self.message_queue: Queue[dict[str, Serializable]] = Queue(maxsize=50) + self.resample_task: Task = create_task(self.resample()) - async def resample(self) -> AsyncGenerator[dict[str, Serializable], None]: + async def receive(self) -> dict[str, Serializable]: + return await self.message_queue.get() + + async def resample(self) -> None: """ Resample data based on a fixed rate. @@ -81,10 +87,7 @@ class Resampler: aggregated_timestamps.append(timestamp) aggregated_data.append(message) # timestamp, message = await self.buffer.get() - try: - topic, data = await self.subscriber.receive() - except Exception as e: - raise StopAsyncIteration from e + topic, data = await self.subscriber.receive() timestamp, message = self._pack_data(data) # timestamp, message = self._pack_data(message) @@ -113,7 +116,7 @@ class Resampler: last_timestamp = return_timestamp return_timestamp += self.delta_t next_timestamp = next(timestamps) - yield self._unpack_data(last_timestamp, last_message) + await self.message_queue.put(self._unpack_data(last_timestamp, last_message)) if self._is_upsampling: last_message = interpolate( @@ -127,12 +130,12 @@ class Resampler: # else: # return_timestamp += self.delta_t - yield self._unpack_data(last_timestamp, last_message) + await self.message_queue.put(self._unpack_data(last_timestamp, last_message)) if len(aggregated_data) > 1: # Case 4 - downsampling: Multiple data points were during the resampling timeframe mean_message = np.mean(np.array(aggregated_data), axis=0) - yield self._unpack_data(return_timestamp, mean_message) + await self.message_queue.put(self._unpack_data(return_timestamp, mean_message)) # reset the aggregator aggregated_data.clear() diff --git a/heisskleber/zmq/__init__.py b/heisskleber/zmq/__init__.py index bbdfd1e..5f2019c 100644 --- a/heisskleber/zmq/__init__.py +++ b/heisskleber/zmq/__init__.py @@ -1,5 +1,5 @@ from .config import ZmqConf -from .publisher import ZmqPublisher -from .subscriber import ZmqSubscriber +from .publisher import ZmqAsyncPublisher, ZmqPublisher +from .subscriber import ZmqAsyncSubscriber, ZmqSubscriber -__all__ = ["ZmqConf", "ZmqPublisher", "ZmqSubscriber"] +__all__ = ["ZmqConf", "ZmqPublisher", "ZmqSubscriber", "ZmqAsyncPublisher", "ZmqAsyncSubscriber"] diff --git a/heisskleber/zmq/publisher.py b/heisskleber/zmq/publisher.py index 70c7182..1a754ae 100644 --- a/heisskleber/zmq/publisher.py +++ b/heisskleber/zmq/publisher.py @@ -1,9 +1,10 @@ import sys import zmq +import zmq.asyncio from heisskleber.core.packer import get_packer -from heisskleber.core.types import Serializable, Sink +from heisskleber.core.types import AsyncSink, Serializable, Sink from .config import ZmqConf @@ -35,3 +36,32 @@ class ZmqPublisher(Sink): def __del__(self): self.socket.close() + + +class ZmqAsyncPublisher(AsyncSink): + def __init__(self, config: ZmqConf): + self.config = config + + self.context = zmq.asyncio.Context.instance() + self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB) + + self.pack = get_packer(config.packstyle) + self.connect() + + def connect(self) -> None: + try: + if self.config.verbose: + print(f"connecting to {self.config.publisher_address}") + self.socket.connect(self.config.publisher_address) + except Exception as e: + print(f"failed to bind to zeromq socket: {e}") + sys.exit(-1) + + async def send(self, data: dict[str, Serializable], topic: str) -> None: + payload = self.pack(data) + if self.config.verbose: + print(f"sending message {payload} to topic {topic}") + await self.socket.send_multipart([topic.encode(), payload.encode()]) + + def __del__(self): + self.socket.close() diff --git a/heisskleber/zmq/subscriber.py b/heisskleber/zmq/subscriber.py index 43a764a..0297f41 100644 --- a/heisskleber/zmq/subscriber.py +++ b/heisskleber/zmq/subscriber.py @@ -3,9 +3,10 @@ from __future__ import annotations import sys import zmq +import zmq.asyncio from heisskleber.core.packer import get_unpacker -from heisskleber.core.types import Source +from heisskleber.core.types import AsyncSource, Source from .config import ZmqConf @@ -54,3 +55,47 @@ class ZmqSubscriber(Source): def __del__(self): self.socket.close() + + +class ZmqAsyncSubscriber(AsyncSource): + def __init__(self, config: ZmqConf, topic: str): + self.config = config + self.context = zmq.asyncio.Context.instance() + self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB) + self.connect() + self.subscribe(topic) + + self.unpack = get_unpacker(config.packstyle) + + def connect(self): + try: + self.socket.connect(self.config.subscriber_address) + except Exception as e: + print(f"failed to bind to zeromq socket: {e}") + sys.exit(-1) + + def _subscribe_single_topic(self, topic: str): + self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode()) + + def subscribe(self, topic: str | list[str] | tuple[str]): + # Accepts single topic or list of topics + if isinstance(topic, (list, tuple)): + for t in topic: + self._subscribe_single_topic(t) + else: + self._subscribe_single_topic(topic) + + async def receive(self) -> tuple[str, dict]: + """ + reads a message from the zmq bus and returns it + + Returns: + tuple(topic: str, message: dict): the message received + """ + (topic, payload) = await self.socket.recv_multipart() + message = self.unpack(payload.decode()) + topic = topic.decode() + return (topic, message) + + def __del__(self): + self.socket.close() diff --git a/poetry.lock b/poetry.lock index 36c2359..a27cf87 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1287,30 +1287,50 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:aa2267c6a303eb483de8d02db2871afb5c5fc15618d894300b88958f729ad74f"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:1707814f0d9791df063f8c19bb51b0d1278b8e9a2353abbb676c2f685dee6afe"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:1dc67314e7e1086c9fdf2680b7b6c2be1c0d8e3a8279f2e993ca2a7545fecf62"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win_amd64.whl", hash = "sha256:1758ce7d8e1a29d23de54a16ae867abd370f01b5a69e1a3ba75223eaa3ca1a1b"}, {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:77159f5d5b5c14f7c34073862a6b7d34944075d9f93e681638f6d753606c6ce6"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:75e1ed13e1f9de23c5607fe6bd1aeaae21e523b32d83bb33918245361e9cc51b"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:305889baa4043a09e5b76f8e2a51d4ffba44259f6b4c72dec8ca56207d9c6fe1"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win32.whl", hash = "sha256:955eae71ac26c1ab35924203fda6220f84dce57d6d7884f189743e2abe3a9fbe"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:a1a45e0bb052edf6a1d3a93baef85319733a888363938e1fc9924cb00c8df24c"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win32.whl", hash = "sha256:84b554931e932c46f94ab306913ad7e11bba988104c5cff26d90d03f68258cd5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:25ac8c08322002b06fa1d49d1646181f0b2c72f5cbc15a85e80b4c30a544bb15"}, {file = "ruamel.yaml.clib-0.2.8.tar.gz", hash = "sha256:beb2e0404003de9a4cab9753a8805a8fe9320ee6673136ed7f04255fe60bb512"}, @@ -1641,6 +1661,20 @@ Sphinx = ">=5" lint = ["docutils-stubs", "flake8", "mypy"] test = ["pytest"] +[[package]] +name = "termcolor" +version = "2.4.0" +description = "ANSI color formatting for output in terminal" +optional = false +python-versions = ">=3.8" +files = [ + {file = "termcolor-2.4.0-py3-none-any.whl", hash = "sha256:9297c0df9c99445c2412e832e882a7884038a25617c60cea2ad69488d4040d63"}, + {file = "termcolor-2.4.0.tar.gz", hash = "sha256:aab9e56047c8ac41ed798fa36d892a37aca6b3e9159f3e0c24bc64a9b3ac7b7a"}, +] + +[package.extras] +tests = ["pytest", "pytest-cov"] + [[package]] name = "tokenize-rt" version = "5.2.0" @@ -1827,4 +1861,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "62e84c22355bb5cf5e86a3b6e58dd801376d460877b65a3d2a972c2c5f72a13f" +content-hash = "f1f0bc51241cb45c05aa4d5d99bcac21c6a14c24d9a331795d117e0801a3d0dd" diff --git a/pyproject.toml b/pyproject.toml index acb2069..cb77f1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "heisskleber" -version = "0.3.1" +version = "0.4.0" description = "Heisskleber" authors = ["Felix Weiler "] license = "MIT" @@ -40,6 +40,7 @@ typeguard = ">=2.13.3" xdoctest = { extras = ["colors"], version = ">=0.15.10" } myst-parser = { version = ">=0.16.1" } pytest-asyncio = "^0.21.1" +termcolor = "^2.4.0" [tool.poetry.group.types.dependencies] @@ -48,7 +49,7 @@ types-pyyaml = "^6.0.12.12" types-paho-mqtt = "^1.6.0.7" [tool.poetry.scripts] -heisskleber = "heisskleber.__main__:main" +hkcli = "run.cli:main" [tool.coverage.paths] source = ["heisskleber", "*/site-packages"] diff --git a/run/cli.py b/run/cli.py new file mode 100644 index 0000000..1506f9b --- /dev/null +++ b/run/cli.py @@ -0,0 +1,34 @@ +import argparse +import sys +from typing import Union + +from heisskleber import get_source +from heisskleber.console.sink import ConsoleSink + +TopicType = Union[str, list[str]] + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--type", type=str, choices=["zmq", "mqtt", "serial"], default="zmq") + parser.add_argument("--topic", type=str, default="#") + + return parser.parse_args() + + +def main(): + args = parse_args() + source = get_source(args.type, args.topic) + sink = ConsoleSink() + + while True: + topic, data = source.receive() + sink.send(data, topic) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("Exiting...") + sys.exit(0) diff --git a/tests/integration/async_streamer.py b/tests/integration/async_streamer.py index 8a7081a..25f5227 100644 --- a/tests/integration/async_streamer.py +++ b/tests/integration/async_streamer.py @@ -7,9 +7,6 @@ from heisskleber.stream.resampler import Resampler, ResamplerConf async def main(): - # topic1 = "/msb-fwd-body/imu" - # topic2 = "/msb-102-a/imu" - # topic2 = "/msb-102-a/rpy" topic1 = "topic1" topic2 = "topic2" @@ -22,18 +19,8 @@ async def main(): resampler1 = Resampler(resampler_config, sub1) resampler2 = Resampler(resampler_config, sub2) - _ = asyncio.create_task(sub1.run()) - _ = asyncio.create_task(sub2.run()) - - # async for resampled_dict in resampler2.resample(): - # print(resampled_dict) - - gen1 = resampler1.resample() - gen2 = resampler2.resample() - while True: - m1 = await anext(gen1) - m2 = await anext(gen2) + m1, m2 = await asyncio.gather(resampler1.receive(), resampler2.receive()) print(f"epoch: {m1['epoch']}") print(f"diff: {diff(m1, m2)}") diff --git a/tests/integration/integration_joint.py b/tests/integration/integration_joint.py index 5f87848..8d2dbb3 100644 --- a/tests/integration/integration_joint.py +++ b/tests/integration/integration_joint.py @@ -1,21 +1,18 @@ import asyncio from heisskleber.mqtt import AsyncMqttSubscriber, MqttConf -from heisskleber.stream import Joint, ResamplerConf +from heisskleber.stream import Joint, Resampler, ResamplerConf async def main(): - topic1 = "topic1" - topic2 = "topic2" + topics = ["topic0", "topic1", "topic2", "topic3"] config = MqttConf(broker="localhost", port=1883, user="", password="") # not a real password - sub1 = AsyncMqttSubscriber(config, topic1) - sub2 = AsyncMqttSubscriber(config, topic2) + subs = [AsyncMqttSubscriber(config, topic=topic) for topic in topics] - resampler_config = ResamplerConf(resample_rate=250) + resampler_config = ResamplerConf(resample_rate=1000) - joint = Joint(resampler_config, [sub1, sub2]) - await joint.setup() + joint = Joint(resampler_config, [Resampler(resampler_config, sub) for sub in subs]) while True: data = await joint.receive() diff --git a/tests/integration/mqtt_async.py b/tests/integration/mqtt_async.py index aa5fe5a..64b7988 100644 --- a/tests/integration/mqtt_async.py +++ b/tests/integration/mqtt_async.py @@ -8,7 +8,7 @@ async def main(): sub = AsyncMqttSubscriber(conf, topic="#") # async for topic, message in sub: # print(message) - _ = asyncio.create_task(sub.run()) + # _ = asyncio.create_task(sub.run()) while True: topic, message = await sub.receive() print(message) diff --git a/tests/integration/mqtt_pub.py b/tests/integration/mqtt_pub.py index d5f0ed6..32d63b8 100644 --- a/tests/integration/mqtt_pub.py +++ b/tests/integration/mqtt_pub.py @@ -1,34 +1,35 @@ +import asyncio import time -from random import random -from heisskleber.mqtt import MqttConf, MqttPublisher +from termcolor import colored + +from heisskleber.mqtt import AsyncMqttPublisher, MqttConf + +colortable = ["red", "green", "yellow", "blue", "magenta", "cyan"] -def main(): - config = MqttConf(broker="localhost", port=1883, user="", password="") - pub = MqttPublisher(config) - pub2 = MqttPublisher(config) - - timestamp = 0 - dt1 = 0.7 - dt2 = 0.5 - t1 = 0 - t2 = 5 +async def send_every_n_miliseconds(frequency, value, pub, topic): + start = time.time() while True: - dt = random() # noqa: S311 - timestamp += dt - print(f"timestamp at {timestamp} s") + epoch = time.time() - start + payload = {"epoch": epoch, f"value{value}": value} + print_message = f"Pub #{int(value)} sending {payload}" + print(colored(print_message, colortable[int(value)])) + await pub.send(payload, topic) + await asyncio.sleep(frequency) - while timestamp - t1 > dt1: - t1 = timestamp + dt1 - pub.send({"value1": 1 + dt, "epoch": timestamp}, "topic1") - print("Pub1 sending") - while timestamp - t2 > dt2: - t2 = timestamp + dt2 - pub2.send({"value2": 2 - dt, "epoch": timestamp}, "topic2") - print("Pub2 sending") - time.sleep(dt) + +async def main2(): + config = MqttConf(broker="localhost", port=1883, user="", password="") + + pubs = [AsyncMqttPublisher(config) for i in range(5)] + tasks = [] + for i, pub in enumerate(pubs): + tasks.append(asyncio.create_task(send_every_n_miliseconds(1 + i * 0.1, i, pub, f"topic{i}"))) + + await asyncio.gather(*tasks) if __name__ == "__main__": - main() + # main() + asyncio.run(main2()) diff --git a/tests/integration/mqtt_stream.py b/tests/integration/mqtt_stream.py index caff62d..6d59dec 100644 --- a/tests/integration/mqtt_stream.py +++ b/tests/integration/mqtt_stream.py @@ -8,15 +8,12 @@ async def main(): conf = MqttConf(broker="localhost", port=1883, user="", password="") sub = AsyncMqttSubscriber(conf, topic="#") - subscriber_task = asyncio.create_task(sub.run()) - resampler = Resampler(ResamplerConf(), sub) - async for data in resampler.resample(): + while True: + data = await resampler.receive() print(data) - subscriber_task.cancel() - if __name__ == "__main__": asyncio.run(main()) diff --git a/tests/integration/mqtt_sub.py b/tests/integration/mqtt_sub.py new file mode 100644 index 0000000..f3a58d5 --- /dev/null +++ b/tests/integration/mqtt_sub.py @@ -0,0 +1,17 @@ +import asyncio + +from heisskleber.mqtt import AsyncMqttSubscriber, MqttConf + + +async def main(): + config = MqttConf(broker="localhost", port=1883, user="", password="") + + sub = AsyncMqttSubscriber(config, topic="#") + + while True: + topic, data = await sub.receive() + print(f"topic: {topic}, data: {data}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/integration/zmq_pub.py b/tests/integration/zmq_pub.py new file mode 100644 index 0000000..7de1429 --- /dev/null +++ b/tests/integration/zmq_pub.py @@ -0,0 +1,17 @@ +import time + +from heisskleber import get_sink + + +def main(): + sink = get_sink("zmq") + + i = 0 + while True: + sink.send({"test pub": i}, "test") + time.sleep(1) + i += 1 + + +if __name__ == "__main__": + main() diff --git a/tests/stream/test_async_mqtt.py b/tests/stream/test_async_mqtt.py index 6103627..27c4611 100644 --- a/tests/stream/test_async_mqtt.py +++ b/tests/stream/test_async_mqtt.py @@ -7,9 +7,24 @@ from heisskleber.mqtt import AsyncMqttSubscriber from heisskleber.mqtt.config import MqttConf +class MockAsyncClient: + def __init__(self): + self.messages = AsyncMock() + self.messages.return_value = [{"epoch": i, "data": 1} for i in range(10)] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + pass + + async def subscribe(self, *args): + pass + + @pytest.fixture def mock_client(): - return AsyncMock() + return MockAsyncClient() @pytest.fixture @@ -18,7 +33,8 @@ def mock_queue(): @pytest.mark.asyncio -async def test_subscribe_topics_single(mock_client, mock_queue): +async def test_subscribe_topics_single(mock_queue): + mock_client = AsyncMock() config = MqttConf() topics = "single_topic" sub = AsyncMqttSubscriber(config, topics) @@ -31,7 +47,8 @@ async def test_subscribe_topics_single(mock_client, mock_queue): @pytest.mark.asyncio -async def test_subscribe_topics_multiple(mock_client, mock_queue): +async def test_subscribe_topics_multiple(mock_queue): + mock_client = AsyncMock() config = MqttConf() topics = ["topic1", "topic2"] sub = AsyncMqttSubscriber(config, topics) diff --git a/tests/test_joint.py b/tests/test_joint.py index f076ba6..828cb79 100644 --- a/tests/test_joint.py +++ b/tests/test_joint.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest from heisskleber.mqtt import AsyncMqttSubscriber -from heisskleber.stream import Joint, ResamplerConf +from heisskleber.stream import Joint, Resampler, ResamplerConf @pytest.fixture @@ -52,8 +52,9 @@ async def test_two_streams_are_parallel(): ) conf = ResamplerConf(resample_rate=1000) - joiner = Joint(conf, [sub1, sub2]) - await joiner.setup() + resamplers = [Resampler(conf, sub1), Resampler(conf, sub2)] + + joiner = Joint(conf, resamplers) return_data = await joiner.receive() assert return_data == {"epoch": 2, "x": 2, "y": 0} diff --git a/tests/test_mqtt.py b/tests/test_mqtt.py index 4956bc2..8edb070 100644 --- a/tests/test_mqtt.py +++ b/tests/test_mqtt.py @@ -114,7 +114,7 @@ def test_receive_with_message(mock_mqtt_conf: MqttConf, mock_mqtt_client, mock_q fake_message = create_fake_mqtt_message(topic, payload) mock_queue.return_value.get.side_effect = [fake_message] - subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf) + subscriber = MqttSubscriber(topics=[topic.decode()], config=mock_mqtt_conf) received_topic, received_payload = subscriber.receive() @@ -129,7 +129,7 @@ def test_message_is_put_into_queue(mock_mqtt_conf: MqttConf, mock_mqtt_client, m fake_message = create_fake_mqtt_message(topic, payload) mock_queue.return_value.get.side_effect = [fake_message] - subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf) + subscriber = MqttSubscriber(topics=[topic.decode()], config=mock_mqtt_conf) subscriber._on_message(None, None, fake_message) @@ -143,7 +143,7 @@ def test_message_is_put_into_queue_with_actual_queue(mock_mqtt_conf, mock_mqtt_c fake_message = create_fake_mqtt_message(topic, payload) # mock_queue.return_value.get.side_effect = [fake_message] - subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf) + subscriber = MqttSubscriber(topics=[topic.decode()], config=mock_mqtt_conf) subscriber._on_message(None, None, fake_message) diff --git a/tests/test_streamer.py b/tests/test_streamer.py index fbf869d..e7b350b 100644 --- a/tests/test_streamer.py +++ b/tests/test_streamer.py @@ -71,11 +71,7 @@ async def test_resampler_multiple_modes(mock_subscriber): resampler = Resampler(config, mock_subscriber) # Test the resample method - resampled_data = [] - - with pytest.raises(RuntimeError): - async for data in resampler.resample(): - resampled_data.append(data) + resampled_data = [await resampler.receive() for _ in range(3)] assert resampled_data[0] == {"epoch": 0.0, "data": 1.5} assert resampled_data[1] == {"epoch": 1.0, "data": 3.5} @@ -96,10 +92,8 @@ async def test_resampler_upsampling(mock_subscriber): config = ResamplerConf(resample_rate=250) # Fill in your MQTT configuration resampler = Resampler(config, mock_subscriber) - resampled_data = [] - with pytest.raises(RuntimeError): - async for data in resampler.resample(): - resampled_data.append(data) + # Test the resample method + resampled_data = [await resampler.receive() for _ in range(7)] assert resampled_data[0] == {"epoch": 0.0, "data": 1.0} assert resampled_data[1] == {"epoch": 0.25, "data": 1.25} diff --git a/tests/test_version.py b/tests/test_version.py deleted file mode 100644 index 95f9abe..0000000 --- a/tests/test_version.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Test the version of the package.""" -import heisskleber - - -def test_heisskleber_version() -> None: - """Test that the glue version is correct.""" - assert heisskleber.__version__ == "0.3.1" diff --git a/tests/zmq/test_zmq_asyncio.py b/tests/zmq/test_zmq_asyncio.py new file mode 100644 index 0000000..1b9761b --- /dev/null +++ b/tests/zmq/test_zmq_asyncio.py @@ -0,0 +1,62 @@ +import time +from collections.abc import Generator +from multiprocessing import Process +from unittest.mock import patch + +import pytest + +from heisskleber.broker.zmq_broker import zmq_broker +from heisskleber.config import load_config +from heisskleber.zmq.config import ZmqConf +from heisskleber.zmq.publisher import ZmqAsyncPublisher, ZmqPublisher +from heisskleber.zmq.subscriber import ZmqAsyncSubscriber + + +@pytest.fixture +def start_broker() -> Generator[Process, None, None]: + # setup broker + with patch( + "heisskleber.config.parse.get_config_filepath", + return_value="tests/resources/zmq.yaml", + ): + broker_config = load_config(ZmqConf(), "zmq", read_commandline=False) + broker_process = Process( + target=zmq_broker, + args=(broker_config,), + ) + # start broker + broker_process.start() + + yield broker_process + + broker_process.terminate() + + +def test_instantiate_subscriber() -> None: + conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556) + sub = ZmqAsyncSubscriber(conf, "test") + assert sub.config == conf + + +def test_instantiate_publisher() -> None: + conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556) + pub = ZmqPublisher(conf) + assert pub.config == conf + + +@pytest.mark.asyncio +async def test_send_receive(start_broker) -> None: + print("test_send_receive") + topic = "test" + conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556) + source = ZmqAsyncSubscriber(conf, topic) + sink = ZmqAsyncPublisher(conf) + time.sleep(1) # this is crucial, otherwise the source might hang + for i in range(10): + message = {"m": i} + await sink.send(message, topic) + print(f"sent {topic} {message}") + t, m = await source.receive() + print(f"received {t} {m}") + assert t == topic + assert m == {"m": i}