mirror of
https://github.com/OMGeeky/flucto-heisskleber.git
synced 2026-02-23 15:38:33 +01:00
Feature/async sources (#46)
* WIP: Added async file reader. * Async resampling and synchronization refactored. * Add async mqtt publisher. Remove queue from joint. * Add async zmq publisher and subscriber. * Modify integration tests for streaming. * Name refactoring resampler. * Added async source/sink to factory. * Refactor joint and add integration tests. * Add termcolor dev dependency * Add conosole source and sink * Add cli interface for different protocols * Removed files unfit for merge. * Fix review requests. * Restore use of $MSB_CONFIG_DIR for now. It seems that the default behaviour is not looking for .config/heisskleber * Remove version test, causing unnecessary failures.
This commit is contained in:
@@ -1,13 +1,18 @@
|
||||
"""Heisskleber."""
|
||||
from .core.async_factories import get_async_sink, get_async_source
|
||||
from .core.factories import get_publisher, get_sink, get_source, get_subscriber
|
||||
from .core.types import Sink, Source
|
||||
from .core.types import AsyncSink, AsyncSource, Sink, Source
|
||||
|
||||
__all__ = [
|
||||
"get_source",
|
||||
"get_sink",
|
||||
"get_publisher",
|
||||
"get_subscriber",
|
||||
"get_async_source",
|
||||
"get_async_sink",
|
||||
"Sink",
|
||||
"Source",
|
||||
"AsyncSink",
|
||||
"AsyncSource",
|
||||
]
|
||||
__version__ = "0.3.1"
|
||||
__version__ = "0.4.0"
|
||||
|
||||
0
heisskleber/console/__init__.py
Normal file
0
heisskleber/console/__init__.py
Normal file
37
heisskleber/console/sink.py
Normal file
37
heisskleber/console/sink.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from typing import TextIO
|
||||
|
||||
from heisskleber.core.types import AsyncSink, Serializable, Sink
|
||||
|
||||
|
||||
def pretty_print(data: dict[str, Serializable]) -> str:
|
||||
return json.dumps(data, indent=4)
|
||||
|
||||
|
||||
class ConsoleSink(Sink):
|
||||
def __init__(self, stream: TextIO = sys.stdout, pretty: bool = False):
|
||||
self.stream = stream
|
||||
self.print = pretty_print if pretty else json.dumps
|
||||
|
||||
def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
self.stream.write(self.print(data)) # type: ignore[operator]
|
||||
self.stream.write("\n")
|
||||
|
||||
|
||||
class AsyncConsoleSink(AsyncSink):
|
||||
def __init__(self, stream: TextIO = sys.stdout, pretty: bool = False):
|
||||
self.stream = stream
|
||||
self.print = pretty_print if pretty else json.dumps
|
||||
|
||||
async def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
self.stream.write(self.print(data)) # type: ignore[operator]
|
||||
self.stream.write("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sink = ConsoleSink()
|
||||
while True:
|
||||
sink.send({"test": "test"}, "test")
|
||||
time.sleep(1)
|
||||
34
heisskleber/console/source.py
Normal file
34
heisskleber/console/source.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from queue import SimpleQueue
|
||||
from threading import Thread
|
||||
|
||||
from heisskleber.core.types import Serializable, Source
|
||||
|
||||
|
||||
class ConsoleSource(Source):
|
||||
def __init__(self, topic: str | list[str] | tuple[str] = "console") -> None:
|
||||
self.topic = "console"
|
||||
self.queue = SimpleQueue()
|
||||
self.listener_daemon = Thread(target=self.listener_task, daemon=True)
|
||||
self.listener_daemon.start()
|
||||
self.pack = json.loads
|
||||
|
||||
def listener_task(self):
|
||||
while True:
|
||||
data = sys.stdin.readline()
|
||||
payload = self.pack(data)
|
||||
self.queue.put(payload)
|
||||
|
||||
def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
data = self.queue.get()
|
||||
return self.topic, data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
console_source = ConsoleSource()
|
||||
|
||||
while True:
|
||||
print(console_source.receive())
|
||||
time.sleep(1)
|
||||
68
heisskleber/core/async_factories.py
Normal file
68
heisskleber/core/async_factories.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import os
|
||||
|
||||
from heisskleber.config import BaseConf, load_config
|
||||
from heisskleber.mqtt import AsyncMqttPublisher, AsyncMqttSubscriber, MqttConf
|
||||
from heisskleber.zmq import ZmqAsyncPublisher, ZmqAsyncSubscriber, ZmqConf
|
||||
|
||||
from .types import AsyncSink, AsyncSource
|
||||
|
||||
_registered_async_sinks: dict[str, tuple[type[AsyncSink], type[BaseConf]]] = {
|
||||
"mqtt": (AsyncMqttPublisher, MqttConf),
|
||||
"zmq": (ZmqAsyncPublisher, ZmqConf),
|
||||
}
|
||||
|
||||
_registered_async_sources: dict[str, tuple] = {
|
||||
"mqtt": (AsyncMqttSubscriber, MqttConf),
|
||||
"zmq": (ZmqAsyncSubscriber, ZmqConf),
|
||||
}
|
||||
|
||||
|
||||
def get_async_sink(name: str) -> AsyncSink:
|
||||
"""
|
||||
Factory function to create a sink object.
|
||||
|
||||
Parameters:
|
||||
name: Name of the sink to create.
|
||||
config: Configuration object to use for the sink.
|
||||
"""
|
||||
|
||||
if name not in _registered_async_sinks:
|
||||
error_message = f"{name} is not a registered Sink."
|
||||
raise KeyError(error_message)
|
||||
|
||||
pub_cls, conf_cls = _registered_async_sinks[name]
|
||||
|
||||
if "MSB_CONFIG_DIR" in os.environ:
|
||||
print(f"loading {name} config")
|
||||
config = load_config(conf_cls(), name, read_commandline=False)
|
||||
else:
|
||||
print(f"using default {name} config")
|
||||
config = conf_cls()
|
||||
|
||||
return pub_cls(config)
|
||||
|
||||
|
||||
def get_async_source(name: str, topic: str | list[str] | tuple[str]) -> AsyncSource:
|
||||
"""
|
||||
Factory function to create a source object.
|
||||
|
||||
Parameters:
|
||||
name: Name of the source to create.
|
||||
config: Configuration object to use for the source.
|
||||
topic: Topic to subscribe to.
|
||||
"""
|
||||
|
||||
if name not in _registered_async_sources:
|
||||
error_message = f"{name} is not a registered Source."
|
||||
raise KeyError(error_message)
|
||||
|
||||
sub_cls, conf_cls = _registered_async_sources[name]
|
||||
|
||||
if "MSB_CONFIG_DIR" in os.environ:
|
||||
print(f"loading {name} config")
|
||||
config = load_config(conf_cls(), name, read_commandline=False)
|
||||
else:
|
||||
print(f"using default {name} config")
|
||||
config = conf_cls()
|
||||
|
||||
return sub_cls(config, topic)
|
||||
@@ -14,7 +14,6 @@ class Sink(ABC):
|
||||
"""
|
||||
|
||||
pack: Callable[[dict[str, Serializable]], str]
|
||||
config: BaseConf
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: BaseConf) -> None:
|
||||
@@ -37,7 +36,6 @@ class Source(ABC):
|
||||
"""
|
||||
|
||||
unpack: Callable[[str], dict[str, Serializable]]
|
||||
config: BaseConf
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: BaseConf, topic: str | list[str]) -> None:
|
||||
@@ -77,9 +75,46 @@ class AsyncSubscriber(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncSource(ABC):
|
||||
"""
|
||||
AsyncSubscriber interface
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def run(self) -> None:
|
||||
def __init__(self, config: Any, topic: str | list[str]) -> None:
|
||||
"""
|
||||
Run the subscriber loop.
|
||||
Initialize the subscriber with a topic and a configuration object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self) -> tuple[str, dict[str, Serializable]]:
|
||||
"""
|
||||
Blocking function to receive data from the implemented input stream.
|
||||
|
||||
Data is returned as a tuple of (topic, data).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncSink(ABC):
|
||||
"""
|
||||
Sink interface to send() data to.
|
||||
"""
|
||||
|
||||
pack: Callable[[dict[str, Serializable]], str]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: BaseConf) -> None:
|
||||
"""
|
||||
Initialize the publisher with a configuration object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, data: dict[str, Any], topic: str) -> None:
|
||||
"""
|
||||
Send data via the implemented output stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .config import MqttConf
|
||||
from .publisher import MqttPublisher
|
||||
from .publisher_async import AsyncMqttPublisher
|
||||
from .subscriber import MqttSubscriber
|
||||
from .subscriber_async import AsyncMqttSubscriber
|
||||
|
||||
__all__ = ["MqttConf", "MqttPublisher", "MqttSubscriber", "AsyncMqttSubscriber"]
|
||||
__all__ = ["MqttConf", "MqttPublisher", "MqttSubscriber", "AsyncMqttSubscriber", "AsyncMqttPublisher"]
|
||||
|
||||
51
heisskleber/mqtt/publisher_async.py
Normal file
51
heisskleber/mqtt/publisher_async.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import aiomqtt
|
||||
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import AsyncSink
|
||||
|
||||
from .config import MqttConf
|
||||
|
||||
|
||||
class AsyncMqttPublisher(AsyncSink):
|
||||
"""
|
||||
MQTT publisher class.
|
||||
Can be used everywhere that a flucto style publishing connection is required.
|
||||
|
||||
Network message loop is handled in a separated thread.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf) -> None:
|
||||
self.config = config
|
||||
self.pack = get_packer(config.packstyle)
|
||||
self._send_queue = asyncio.Queue()
|
||||
self._sender_task = asyncio.create_task(self.send_work())
|
||||
|
||||
async def send(self, data: dict[str, Any], topic: str) -> None:
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
|
||||
Publishing is asynchronous
|
||||
"""
|
||||
|
||||
await self._send_queue.put((data, topic))
|
||||
|
||||
async def send_work(self) -> None:
|
||||
"""
|
||||
Takes python dictionary, serializes it according to the packstyle
|
||||
and sends it to the broker.
|
||||
|
||||
Publishing is asynchronous
|
||||
"""
|
||||
async with aiomqtt.Client(
|
||||
hostname=self.config.broker, port=self.config.port, username=self.config.user, password=self.config.password
|
||||
) as client:
|
||||
while True:
|
||||
data, topic = await self._send_queue.get()
|
||||
payload = self.pack(data)
|
||||
await client.publish(topic, payload)
|
||||
@@ -1,18 +1,16 @@
|
||||
from asyncio import Queue, sleep
|
||||
from asyncio import Queue, Task, create_task, sleep
|
||||
|
||||
from aiomqtt import Client, Message, MqttError
|
||||
|
||||
from heisskleber.core.packer import get_unpacker
|
||||
from heisskleber.core.types import AsyncSubscriber, Serializable
|
||||
from heisskleber.core.types import AsyncSource, Serializable
|
||||
from heisskleber.mqtt import MqttConf
|
||||
|
||||
|
||||
class AsyncMqttSubscriber(AsyncSubscriber):
|
||||
class AsyncMqttSubscriber(AsyncSource):
|
||||
"""Asynchronous MQTT susbsciber based on aiomqtt.
|
||||
|
||||
Data is received by one of two methods:
|
||||
1. The `receive` method returns the newest message in the queue. For this to work, the `run` method must be called in a separae task.
|
||||
2. The `generator` method is a generator function that yields a topic and dict payload.
|
||||
Data is received by the `receive` method returns the newest message in the queue.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MqttConf, topic: str | list[str]) -> None:
|
||||
@@ -26,6 +24,7 @@ class AsyncMqttSubscriber(AsyncSubscriber):
|
||||
self.topics = topic
|
||||
self.unpack = get_unpacker(self.config.packstyle)
|
||||
self.message_queue: Queue[Message] = Queue(self.config.max_saved_messages)
|
||||
self._listener_task: Task = create_task(self.create_listener())
|
||||
|
||||
"""
|
||||
Await the newest message in the queue and return Tuple
|
||||
@@ -35,41 +34,21 @@ class AsyncMqttSubscriber(AsyncSubscriber):
|
||||
mqtt_message: Message = await self.message_queue.get()
|
||||
return self._handle_message(mqtt_message)
|
||||
|
||||
"""
|
||||
Listen to incoming messages asynchronously and put them into a queue
|
||||
"""
|
||||
|
||||
async def run(self) -> None:
|
||||
"""
|
||||
Run the async mqtt listening loop.
|
||||
Includes reconnecting to mqtt broker.
|
||||
"""
|
||||
# Manage connection to mqtt
|
||||
async def create_listener(self):
|
||||
while True:
|
||||
try:
|
||||
async with self.client:
|
||||
await self._subscribe_topics()
|
||||
await self._listen_mqtt_loop()
|
||||
except MqttError:
|
||||
except MqttError as e:
|
||||
print(f"MqttError: {e}")
|
||||
print("Connection to MQTT failed. Retrying...")
|
||||
await sleep(1)
|
||||
|
||||
"""
|
||||
Generator function that yields topic and dict payload.
|
||||
Listen to incoming messages asynchronously and put them into a queue
|
||||
"""
|
||||
|
||||
async def generator(self):
|
||||
while True:
|
||||
try:
|
||||
async with self.client:
|
||||
await self._subscribe_topics()
|
||||
async with self.client.messages() as messages:
|
||||
async for message in messages:
|
||||
yield self._handle_message(message)
|
||||
except MqttError:
|
||||
print("Connection to MQTT failed. Retrying...")
|
||||
await sleep(1)
|
||||
|
||||
async def _listen_mqtt_loop(self) -> None:
|
||||
async with self.client.messages() as messages:
|
||||
# async with self.client.filtered_messages(self.topics) as messages:
|
||||
@@ -77,10 +56,11 @@ class AsyncMqttSubscriber(AsyncSubscriber):
|
||||
await self.message_queue.put(message)
|
||||
|
||||
def _handle_message(self, message: Message) -> tuple[str, dict[str, Serializable]]:
|
||||
topic = str(message.topic)
|
||||
if not isinstance(message.payload, bytes):
|
||||
error_msg = "Payload is not of type bytes."
|
||||
raise TypeError(error_msg)
|
||||
|
||||
topic = str(message.topic)
|
||||
message_returned = self.unpack(message.payload.decode())
|
||||
return (topic, message_returned)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
from heisskleber.core.types import AsyncSubscriber
|
||||
from heisskleber.stream.resampler import Resampler, ResamplerConf
|
||||
|
||||
|
||||
@@ -18,65 +17,83 @@ class Joint:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, conf: ResamplerConf, subscribers: list[AsyncSubscriber]):
|
||||
def __init__(self, conf: ResamplerConf, resamplers: list[Resampler]):
|
||||
self.conf = conf
|
||||
self.subscribers = subscribers
|
||||
self.generators = []
|
||||
self.resampler_timestamps = []
|
||||
self.latest_timestamp = 0
|
||||
self.latest_data = {}
|
||||
self.tasks = []
|
||||
self.resamplers = resamplers
|
||||
self.output_queue = asyncio.Queue()
|
||||
self.initialized = asyncio.Event()
|
||||
self.initalize_task = asyncio.create_task(self.sync())
|
||||
self.output_task = asyncio.create_task(self.output_work())
|
||||
|
||||
async def receive(self):
|
||||
old_value = self.latest_data.copy()
|
||||
await self._update()
|
||||
return old_value
|
||||
self.combined_dict = {}
|
||||
|
||||
async def generate(self):
|
||||
while True:
|
||||
yield self.latest_data
|
||||
await self._update()
|
||||
"""
|
||||
Main interaction coroutine: Get next value out of the queue.
|
||||
"""
|
||||
|
||||
"""Set up the streamer joint, which will activate all subscribers."""
|
||||
async def receive(self) -> dict:
|
||||
return await self.output_queue.get()
|
||||
|
||||
async def setup(self):
|
||||
for sub in self.subscribers:
|
||||
# Start an async task to run the subscriber loop
|
||||
task = asyncio.create_task(sub.run())
|
||||
self.tasks.append(task)
|
||||
self.generators.append(Resampler(self.conf, sub).resample())
|
||||
|
||||
await self._synchronize()
|
||||
|
||||
async def _synchronize(self):
|
||||
async def sync(self) -> None:
|
||||
print("Starting sync")
|
||||
datas = await asyncio.gather(*[source.receive() for source in self.resamplers])
|
||||
output_data = {}
|
||||
data = {}
|
||||
# first pass to initialize resamplers
|
||||
for resampler in self.generators:
|
||||
data = await anext(resampler)
|
||||
self.resampler_timestamps.append(data["epoch"])
|
||||
|
||||
if data["epoch"] > self.latest_timestamp:
|
||||
self.latest_timestamp = data["epoch"]
|
||||
self.latest_data = dict(data)
|
||||
latest_timestamp: float = 0.0
|
||||
timestamps = []
|
||||
|
||||
for resampler, timestamp in zip(self.generators, self.resampler_timestamps):
|
||||
if timestamp == self.latest_timestamp:
|
||||
continue
|
||||
print("Syncing...")
|
||||
for data in datas:
|
||||
if not isinstance(data["epoch"], float):
|
||||
error = "Timestamps must be floats"
|
||||
raise TypeError(error)
|
||||
|
||||
while timestamp < self.latest_timestamp:
|
||||
data = await anext(resampler)
|
||||
timestamp = data["epoch"]
|
||||
ts = float(data["epoch"])
|
||||
|
||||
self.latest_data.update(data)
|
||||
print(f"Syncing..., got {ts}")
|
||||
|
||||
async def _update(self):
|
||||
data: dict = {}
|
||||
for resampler in self.generators:
|
||||
try:
|
||||
data = await anext(resampler)
|
||||
timestamps.append(ts)
|
||||
if ts > latest_timestamp:
|
||||
latest_timestamp = ts
|
||||
|
||||
if data["epoch"] >= self.latest_timestamp:
|
||||
self.latest_timestamp = data["epoch"]
|
||||
self.latest_data.update(data)
|
||||
except Exception:
|
||||
print(Exception)
|
||||
# only take the piece of the latest data
|
||||
output_data = data
|
||||
|
||||
for resampler, ts in zip(self.resamplers, timestamps):
|
||||
while ts < latest_timestamp:
|
||||
data = await resampler.receive()
|
||||
ts = float(data["epoch"])
|
||||
|
||||
output_data.update(data)
|
||||
|
||||
await self.output_queue.put(output_data)
|
||||
|
||||
print("Finished initalization")
|
||||
self.initialized.set()
|
||||
|
||||
"""
|
||||
Coroutine that waits for new queue data and updates dict.
|
||||
"""
|
||||
|
||||
async def update_dict(self, resampler):
|
||||
# queue is passed by reference, python y u so weird!
|
||||
data = await resampler.receive()
|
||||
if self.combined_dict and self.combined_dict["epoch"] != data["epoch"]:
|
||||
print("Oh shit, this is bad!")
|
||||
self.combined_dict.update(data)
|
||||
|
||||
"""
|
||||
Output worker: iterate through queues, read data and join into output queue.
|
||||
"""
|
||||
|
||||
async def output_work(self):
|
||||
print("Output worker waiting for intitialization")
|
||||
await self.initialized.wait()
|
||||
print("Output worker resuming")
|
||||
|
||||
while True:
|
||||
self.combined_dict = {}
|
||||
tasks = [asyncio.create_task(self.update_dict(res)) for res in self.resamplers]
|
||||
await asyncio.gather(*tasks)
|
||||
await self.output_queue.put(self.combined_dict)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import math
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from asyncio import Queue, Task, create_task
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
|
||||
from heisskleber.core.types import AsyncSubscriber, Serializable
|
||||
from heisskleber.core.types import AsyncSource, Serializable
|
||||
|
||||
from .config import ResamplerConf
|
||||
|
||||
@@ -27,7 +28,7 @@ def timestamp_generator(start_epoch: float, timedelta_in_ms: int) -> Generator[f
|
||||
next_timestamp += delta
|
||||
|
||||
|
||||
def interpolate(t1, y1, t2, y2, t_target):
|
||||
def interpolate(t1, y1, t2, y2, t_target) -> list[float]:
|
||||
"""Perform linear interpolation between two data points."""
|
||||
y1, y2 = np.array(y1), np.array(y2)
|
||||
fraction = (t_target - t1) / (t2 - t1)
|
||||
@@ -47,13 +48,18 @@ class Resampler:
|
||||
Asynchronous Subscriber
|
||||
"""
|
||||
|
||||
def __init__(self, config: ResamplerConf, subscriber: AsyncSubscriber) -> None:
|
||||
def __init__(self, config: ResamplerConf, subscriber: AsyncSource) -> None:
|
||||
self.config = config
|
||||
self.subscriber = subscriber
|
||||
self.resample_rate = self.config.resample_rate
|
||||
self.delta_t = round(self.resample_rate / 1_000, 3)
|
||||
self.message_queue: Queue[dict[str, Serializable]] = Queue(maxsize=50)
|
||||
self.resample_task: Task = create_task(self.resample())
|
||||
|
||||
async def resample(self) -> AsyncGenerator[dict[str, Serializable], None]:
|
||||
async def receive(self) -> dict[str, Serializable]:
|
||||
return await self.message_queue.get()
|
||||
|
||||
async def resample(self) -> None:
|
||||
"""
|
||||
Resample data based on a fixed rate.
|
||||
|
||||
@@ -81,10 +87,7 @@ class Resampler:
|
||||
aggregated_timestamps.append(timestamp)
|
||||
aggregated_data.append(message)
|
||||
# timestamp, message = await self.buffer.get()
|
||||
try:
|
||||
topic, data = await self.subscriber.receive()
|
||||
except Exception as e:
|
||||
raise StopAsyncIteration from e
|
||||
topic, data = await self.subscriber.receive()
|
||||
timestamp, message = self._pack_data(data)
|
||||
# timestamp, message = self._pack_data(message)
|
||||
|
||||
@@ -113,7 +116,7 @@ class Resampler:
|
||||
last_timestamp = return_timestamp
|
||||
return_timestamp += self.delta_t
|
||||
next_timestamp = next(timestamps)
|
||||
yield self._unpack_data(last_timestamp, last_message)
|
||||
await self.message_queue.put(self._unpack_data(last_timestamp, last_message))
|
||||
|
||||
if self._is_upsampling:
|
||||
last_message = interpolate(
|
||||
@@ -127,12 +130,12 @@ class Resampler:
|
||||
# else:
|
||||
# return_timestamp += self.delta_t
|
||||
|
||||
yield self._unpack_data(last_timestamp, last_message)
|
||||
await self.message_queue.put(self._unpack_data(last_timestamp, last_message))
|
||||
|
||||
if len(aggregated_data) > 1:
|
||||
# Case 4 - downsampling: Multiple data points were during the resampling timeframe
|
||||
mean_message = np.mean(np.array(aggregated_data), axis=0)
|
||||
yield self._unpack_data(return_timestamp, mean_message)
|
||||
await self.message_queue.put(self._unpack_data(return_timestamp, mean_message))
|
||||
|
||||
# reset the aggregator
|
||||
aggregated_data.clear()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .config import ZmqConf
|
||||
from .publisher import ZmqPublisher
|
||||
from .subscriber import ZmqSubscriber
|
||||
from .publisher import ZmqAsyncPublisher, ZmqPublisher
|
||||
from .subscriber import ZmqAsyncSubscriber, ZmqSubscriber
|
||||
|
||||
__all__ = ["ZmqConf", "ZmqPublisher", "ZmqSubscriber"]
|
||||
__all__ = ["ZmqConf", "ZmqPublisher", "ZmqSubscriber", "ZmqAsyncPublisher", "ZmqAsyncSubscriber"]
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import sys
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from heisskleber.core.packer import get_packer
|
||||
from heisskleber.core.types import Serializable, Sink
|
||||
from heisskleber.core.types import AsyncSink, Serializable, Sink
|
||||
|
||||
from .config import ZmqConf
|
||||
|
||||
@@ -35,3 +36,32 @@ class ZmqPublisher(Sink):
|
||||
|
||||
def __del__(self):
|
||||
self.socket.close()
|
||||
|
||||
|
||||
class ZmqAsyncPublisher(AsyncSink):
|
||||
def __init__(self, config: ZmqConf):
|
||||
self.config = config
|
||||
|
||||
self.context = zmq.asyncio.Context.instance()
|
||||
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB)
|
||||
|
||||
self.pack = get_packer(config.packstyle)
|
||||
self.connect()
|
||||
|
||||
def connect(self) -> None:
|
||||
try:
|
||||
if self.config.verbose:
|
||||
print(f"connecting to {self.config.publisher_address}")
|
||||
self.socket.connect(self.config.publisher_address)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
|
||||
async def send(self, data: dict[str, Serializable], topic: str) -> None:
|
||||
payload = self.pack(data)
|
||||
if self.config.verbose:
|
||||
print(f"sending message {payload} to topic {topic}")
|
||||
await self.socket.send_multipart([topic.encode(), payload.encode()])
|
||||
|
||||
def __del__(self):
|
||||
self.socket.close()
|
||||
|
||||
@@ -3,9 +3,10 @@ from __future__ import annotations
|
||||
import sys
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from heisskleber.core.packer import get_unpacker
|
||||
from heisskleber.core.types import Source
|
||||
from heisskleber.core.types import AsyncSource, Source
|
||||
|
||||
from .config import ZmqConf
|
||||
|
||||
@@ -54,3 +55,47 @@ class ZmqSubscriber(Source):
|
||||
|
||||
def __del__(self):
|
||||
self.socket.close()
|
||||
|
||||
|
||||
class ZmqAsyncSubscriber(AsyncSource):
|
||||
def __init__(self, config: ZmqConf, topic: str):
|
||||
self.config = config
|
||||
self.context = zmq.asyncio.Context.instance()
|
||||
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB)
|
||||
self.connect()
|
||||
self.subscribe(topic)
|
||||
|
||||
self.unpack = get_unpacker(config.packstyle)
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
self.socket.connect(self.config.subscriber_address)
|
||||
except Exception as e:
|
||||
print(f"failed to bind to zeromq socket: {e}")
|
||||
sys.exit(-1)
|
||||
|
||||
def _subscribe_single_topic(self, topic: str):
|
||||
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
|
||||
|
||||
def subscribe(self, topic: str | list[str] | tuple[str]):
|
||||
# Accepts single topic or list of topics
|
||||
if isinstance(topic, (list, tuple)):
|
||||
for t in topic:
|
||||
self._subscribe_single_topic(t)
|
||||
else:
|
||||
self._subscribe_single_topic(topic)
|
||||
|
||||
async def receive(self) -> tuple[str, dict]:
|
||||
"""
|
||||
reads a message from the zmq bus and returns it
|
||||
|
||||
Returns:
|
||||
tuple(topic: str, message: dict): the message received
|
||||
"""
|
||||
(topic, payload) = await self.socket.recv_multipart()
|
||||
message = self.unpack(payload.decode())
|
||||
topic = topic.decode()
|
||||
return (topic, message)
|
||||
|
||||
def __del__(self):
|
||||
self.socket.close()
|
||||
|
||||
36
poetry.lock
generated
36
poetry.lock
generated
@@ -1287,30 +1287,50 @@ files = [
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:aa2267c6a303eb483de8d02db2871afb5c5fc15618d894300b88958f729ad74f"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:1707814f0d9791df063f8c19bb51b0d1278b8e9a2353abbb676c2f685dee6afe"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:1dc67314e7e1086c9fdf2680b7b6c2be1c0d8e3a8279f2e993ca2a7545fecf62"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win_amd64.whl", hash = "sha256:1758ce7d8e1a29d23de54a16ae867abd370f01b5a69e1a3ba75223eaa3ca1a1b"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:77159f5d5b5c14f7c34073862a6b7d34944075d9f93e681638f6d753606c6ce6"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:75e1ed13e1f9de23c5607fe6bd1aeaae21e523b32d83bb33918245361e9cc51b"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:305889baa4043a09e5b76f8e2a51d4ffba44259f6b4c72dec8ca56207d9c6fe1"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win32.whl", hash = "sha256:955eae71ac26c1ab35924203fda6220f84dce57d6d7884f189743e2abe3a9fbe"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:a1a45e0bb052edf6a1d3a93baef85319733a888363938e1fc9924cb00c8df24c"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win32.whl", hash = "sha256:84b554931e932c46f94ab306913ad7e11bba988104c5cff26d90d03f68258cd5"},
|
||||
{file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:25ac8c08322002b06fa1d49d1646181f0b2c72f5cbc15a85e80b4c30a544bb15"},
|
||||
{file = "ruamel.yaml.clib-0.2.8.tar.gz", hash = "sha256:beb2e0404003de9a4cab9753a8805a8fe9320ee6673136ed7f04255fe60bb512"},
|
||||
@@ -1641,6 +1661,20 @@ Sphinx = ">=5"
|
||||
lint = ["docutils-stubs", "flake8", "mypy"]
|
||||
test = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "2.4.0"
|
||||
description = "ANSI color formatting for output in terminal"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "termcolor-2.4.0-py3-none-any.whl", hash = "sha256:9297c0df9c99445c2412e832e882a7884038a25617c60cea2ad69488d4040d63"},
|
||||
{file = "termcolor-2.4.0.tar.gz", hash = "sha256:aab9e56047c8ac41ed798fa36d892a37aca6b3e9159f3e0c24bc64a9b3ac7b7a"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "tokenize-rt"
|
||||
version = "5.2.0"
|
||||
@@ -1827,4 +1861,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "62e84c22355bb5cf5e86a3b6e58dd801376d460877b65a3d2a972c2c5f72a13f"
|
||||
content-hash = "f1f0bc51241cb45c05aa4d5d99bcac21c6a14c24d9a331795d117e0801a3d0dd"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "heisskleber"
|
||||
version = "0.3.1"
|
||||
version = "0.4.0"
|
||||
description = "Heisskleber"
|
||||
authors = ["Felix Weiler <felix@flucto.tech>"]
|
||||
license = "MIT"
|
||||
@@ -40,6 +40,7 @@ typeguard = ">=2.13.3"
|
||||
xdoctest = { extras = ["colors"], version = ">=0.15.10" }
|
||||
myst-parser = { version = ">=0.16.1" }
|
||||
pytest-asyncio = "^0.21.1"
|
||||
termcolor = "^2.4.0"
|
||||
|
||||
|
||||
[tool.poetry.group.types.dependencies]
|
||||
@@ -48,7 +49,7 @@ types-pyyaml = "^6.0.12.12"
|
||||
types-paho-mqtt = "^1.6.0.7"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
heisskleber = "heisskleber.__main__:main"
|
||||
hkcli = "run.cli:main"
|
||||
|
||||
[tool.coverage.paths]
|
||||
source = ["heisskleber", "*/site-packages"]
|
||||
|
||||
34
run/cli.py
Normal file
34
run/cli.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import argparse
|
||||
import sys
|
||||
from typing import Union
|
||||
|
||||
from heisskleber import get_source
|
||||
from heisskleber.console.sink import ConsoleSink
|
||||
|
||||
TopicType = Union[str, list[str]]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--type", type=str, choices=["zmq", "mqtt", "serial"], default="zmq")
|
||||
parser.add_argument("--topic", type=str, default="#")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
source = get_source(args.type, args.topic)
|
||||
sink = ConsoleSink()
|
||||
|
||||
while True:
|
||||
topic, data = source.receive()
|
||||
sink.send(data, topic)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting...")
|
||||
sys.exit(0)
|
||||
@@ -7,9 +7,6 @@ from heisskleber.stream.resampler import Resampler, ResamplerConf
|
||||
|
||||
|
||||
async def main():
|
||||
# topic1 = "/msb-fwd-body/imu"
|
||||
# topic2 = "/msb-102-a/imu"
|
||||
# topic2 = "/msb-102-a/rpy"
|
||||
topic1 = "topic1"
|
||||
topic2 = "topic2"
|
||||
|
||||
@@ -22,18 +19,8 @@ async def main():
|
||||
resampler1 = Resampler(resampler_config, sub1)
|
||||
resampler2 = Resampler(resampler_config, sub2)
|
||||
|
||||
_ = asyncio.create_task(sub1.run())
|
||||
_ = asyncio.create_task(sub2.run())
|
||||
|
||||
# async for resampled_dict in resampler2.resample():
|
||||
# print(resampled_dict)
|
||||
|
||||
gen1 = resampler1.resample()
|
||||
gen2 = resampler2.resample()
|
||||
|
||||
while True:
|
||||
m1 = await anext(gen1)
|
||||
m2 = await anext(gen2)
|
||||
m1, m2 = await asyncio.gather(resampler1.receive(), resampler2.receive())
|
||||
|
||||
print(f"epoch: {m1['epoch']}")
|
||||
print(f"diff: {diff(m1, m2)}")
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
import asyncio
|
||||
|
||||
from heisskleber.mqtt import AsyncMqttSubscriber, MqttConf
|
||||
from heisskleber.stream import Joint, ResamplerConf
|
||||
from heisskleber.stream import Joint, Resampler, ResamplerConf
|
||||
|
||||
|
||||
async def main():
|
||||
topic1 = "topic1"
|
||||
topic2 = "topic2"
|
||||
topics = ["topic0", "topic1", "topic2", "topic3"]
|
||||
|
||||
config = MqttConf(broker="localhost", port=1883, user="", password="") # not a real password
|
||||
sub1 = AsyncMqttSubscriber(config, topic1)
|
||||
sub2 = AsyncMqttSubscriber(config, topic2)
|
||||
subs = [AsyncMqttSubscriber(config, topic=topic) for topic in topics]
|
||||
|
||||
resampler_config = ResamplerConf(resample_rate=250)
|
||||
resampler_config = ResamplerConf(resample_rate=1000)
|
||||
|
||||
joint = Joint(resampler_config, [sub1, sub2])
|
||||
await joint.setup()
|
||||
joint = Joint(resampler_config, [Resampler(resampler_config, sub) for sub in subs])
|
||||
|
||||
while True:
|
||||
data = await joint.receive()
|
||||
|
||||
@@ -8,7 +8,7 @@ async def main():
|
||||
sub = AsyncMqttSubscriber(conf, topic="#")
|
||||
# async for topic, message in sub:
|
||||
# print(message)
|
||||
_ = asyncio.create_task(sub.run())
|
||||
# _ = asyncio.create_task(sub.run())
|
||||
while True:
|
||||
topic, message = await sub.receive()
|
||||
print(message)
|
||||
|
||||
@@ -1,34 +1,35 @@
|
||||
import asyncio
|
||||
import time
|
||||
from random import random
|
||||
|
||||
from heisskleber.mqtt import MqttConf, MqttPublisher
|
||||
from termcolor import colored
|
||||
|
||||
from heisskleber.mqtt import AsyncMqttPublisher, MqttConf
|
||||
|
||||
colortable = ["red", "green", "yellow", "blue", "magenta", "cyan"]
|
||||
|
||||
|
||||
def main():
|
||||
config = MqttConf(broker="localhost", port=1883, user="", password="")
|
||||
pub = MqttPublisher(config)
|
||||
pub2 = MqttPublisher(config)
|
||||
|
||||
timestamp = 0
|
||||
dt1 = 0.7
|
||||
dt2 = 0.5
|
||||
t1 = 0
|
||||
t2 = 5
|
||||
async def send_every_n_miliseconds(frequency, value, pub, topic):
|
||||
start = time.time()
|
||||
while True:
|
||||
dt = random() # noqa: S311
|
||||
timestamp += dt
|
||||
print(f"timestamp at {timestamp} s")
|
||||
epoch = time.time() - start
|
||||
payload = {"epoch": epoch, f"value{value}": value}
|
||||
print_message = f"Pub #{int(value)} sending {payload}"
|
||||
print(colored(print_message, colortable[int(value)]))
|
||||
await pub.send(payload, topic)
|
||||
await asyncio.sleep(frequency)
|
||||
|
||||
while timestamp - t1 > dt1:
|
||||
t1 = timestamp + dt1
|
||||
pub.send({"value1": 1 + dt, "epoch": timestamp}, "topic1")
|
||||
print("Pub1 sending")
|
||||
while timestamp - t2 > dt2:
|
||||
t2 = timestamp + dt2
|
||||
pub2.send({"value2": 2 - dt, "epoch": timestamp}, "topic2")
|
||||
print("Pub2 sending")
|
||||
time.sleep(dt)
|
||||
|
||||
async def main2():
|
||||
config = MqttConf(broker="localhost", port=1883, user="", password="")
|
||||
|
||||
pubs = [AsyncMqttPublisher(config) for i in range(5)]
|
||||
tasks = []
|
||||
for i, pub in enumerate(pubs):
|
||||
tasks.append(asyncio.create_task(send_every_n_miliseconds(1 + i * 0.1, i, pub, f"topic{i}")))
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# main()
|
||||
asyncio.run(main2())
|
||||
|
||||
@@ -8,15 +8,12 @@ async def main():
|
||||
conf = MqttConf(broker="localhost", port=1883, user="", password="")
|
||||
sub = AsyncMqttSubscriber(conf, topic="#")
|
||||
|
||||
subscriber_task = asyncio.create_task(sub.run())
|
||||
|
||||
resampler = Resampler(ResamplerConf(), sub)
|
||||
|
||||
async for data in resampler.resample():
|
||||
while True:
|
||||
data = await resampler.receive()
|
||||
print(data)
|
||||
|
||||
subscriber_task.cancel()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
17
tests/integration/mqtt_sub.py
Normal file
17
tests/integration/mqtt_sub.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import asyncio
|
||||
|
||||
from heisskleber.mqtt import AsyncMqttSubscriber, MqttConf
|
||||
|
||||
|
||||
async def main():
|
||||
config = MqttConf(broker="localhost", port=1883, user="", password="")
|
||||
|
||||
sub = AsyncMqttSubscriber(config, topic="#")
|
||||
|
||||
while True:
|
||||
topic, data = await sub.receive()
|
||||
print(f"topic: {topic}, data: {data}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
17
tests/integration/zmq_pub.py
Normal file
17
tests/integration/zmq_pub.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import time
|
||||
|
||||
from heisskleber import get_sink
|
||||
|
||||
|
||||
def main():
|
||||
sink = get_sink("zmq")
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
sink.send({"test pub": i}, "test")
|
||||
time.sleep(1)
|
||||
i += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -7,9 +7,24 @@ from heisskleber.mqtt import AsyncMqttSubscriber
|
||||
from heisskleber.mqtt.config import MqttConf
|
||||
|
||||
|
||||
class MockAsyncClient:
|
||||
def __init__(self):
|
||||
self.messages = AsyncMock()
|
||||
self.messages.return_value = [{"epoch": i, "data": 1} for i in range(10)]
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
async def subscribe(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client():
|
||||
return AsyncMock()
|
||||
return MockAsyncClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -18,7 +33,8 @@ def mock_queue():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_topics_single(mock_client, mock_queue):
|
||||
async def test_subscribe_topics_single(mock_queue):
|
||||
mock_client = AsyncMock()
|
||||
config = MqttConf()
|
||||
topics = "single_topic"
|
||||
sub = AsyncMqttSubscriber(config, topics)
|
||||
@@ -31,7 +47,8 @@ async def test_subscribe_topics_single(mock_client, mock_queue):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_topics_multiple(mock_client, mock_queue):
|
||||
async def test_subscribe_topics_multiple(mock_queue):
|
||||
mock_client = AsyncMock()
|
||||
config = MqttConf()
|
||||
topics = ["topic1", "topic2"]
|
||||
sub = AsyncMqttSubscriber(config, topics)
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
from heisskleber.mqtt import AsyncMqttSubscriber
|
||||
from heisskleber.stream import Joint, ResamplerConf
|
||||
from heisskleber.stream import Joint, Resampler, ResamplerConf
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -52,8 +52,9 @@ async def test_two_streams_are_parallel():
|
||||
)
|
||||
conf = ResamplerConf(resample_rate=1000)
|
||||
|
||||
joiner = Joint(conf, [sub1, sub2])
|
||||
await joiner.setup()
|
||||
resamplers = [Resampler(conf, sub1), Resampler(conf, sub2)]
|
||||
|
||||
joiner = Joint(conf, resamplers)
|
||||
|
||||
return_data = await joiner.receive()
|
||||
assert return_data == {"epoch": 2, "x": 2, "y": 0}
|
||||
|
||||
@@ -114,7 +114,7 @@ def test_receive_with_message(mock_mqtt_conf: MqttConf, mock_mqtt_client, mock_q
|
||||
fake_message = create_fake_mqtt_message(topic, payload)
|
||||
|
||||
mock_queue.return_value.get.side_effect = [fake_message]
|
||||
subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf)
|
||||
subscriber = MqttSubscriber(topics=[topic.decode()], config=mock_mqtt_conf)
|
||||
|
||||
received_topic, received_payload = subscriber.receive()
|
||||
|
||||
@@ -129,7 +129,7 @@ def test_message_is_put_into_queue(mock_mqtt_conf: MqttConf, mock_mqtt_client, m
|
||||
fake_message = create_fake_mqtt_message(topic, payload)
|
||||
|
||||
mock_queue.return_value.get.side_effect = [fake_message]
|
||||
subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf)
|
||||
subscriber = MqttSubscriber(topics=[topic.decode()], config=mock_mqtt_conf)
|
||||
|
||||
subscriber._on_message(None, None, fake_message)
|
||||
|
||||
@@ -143,7 +143,7 @@ def test_message_is_put_into_queue_with_actual_queue(mock_mqtt_conf, mock_mqtt_c
|
||||
fake_message = create_fake_mqtt_message(topic, payload)
|
||||
|
||||
# mock_queue.return_value.get.side_effect = [fake_message]
|
||||
subscriber = MqttSubscriber(topics=[topic], config=mock_mqtt_conf)
|
||||
subscriber = MqttSubscriber(topics=[topic.decode()], config=mock_mqtt_conf)
|
||||
|
||||
subscriber._on_message(None, None, fake_message)
|
||||
|
||||
|
||||
@@ -71,11 +71,7 @@ async def test_resampler_multiple_modes(mock_subscriber):
|
||||
resampler = Resampler(config, mock_subscriber)
|
||||
|
||||
# Test the resample method
|
||||
resampled_data = []
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for data in resampler.resample():
|
||||
resampled_data.append(data)
|
||||
resampled_data = [await resampler.receive() for _ in range(3)]
|
||||
|
||||
assert resampled_data[0] == {"epoch": 0.0, "data": 1.5}
|
||||
assert resampled_data[1] == {"epoch": 1.0, "data": 3.5}
|
||||
@@ -96,10 +92,8 @@ async def test_resampler_upsampling(mock_subscriber):
|
||||
config = ResamplerConf(resample_rate=250) # Fill in your MQTT configuration
|
||||
resampler = Resampler(config, mock_subscriber)
|
||||
|
||||
resampled_data = []
|
||||
with pytest.raises(RuntimeError):
|
||||
async for data in resampler.resample():
|
||||
resampled_data.append(data)
|
||||
# Test the resample method
|
||||
resampled_data = [await resampler.receive() for _ in range(7)]
|
||||
|
||||
assert resampled_data[0] == {"epoch": 0.0, "data": 1.0}
|
||||
assert resampled_data[1] == {"epoch": 0.25, "data": 1.25}
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Test the version of the package."""
|
||||
import heisskleber
|
||||
|
||||
|
||||
def test_heisskleber_version() -> None:
|
||||
"""Test that the glue version is correct."""
|
||||
assert heisskleber.__version__ == "0.3.1"
|
||||
62
tests/zmq/test_zmq_asyncio.py
Normal file
62
tests/zmq/test_zmq_asyncio.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from multiprocessing import Process
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from heisskleber.broker.zmq_broker import zmq_broker
|
||||
from heisskleber.config import load_config
|
||||
from heisskleber.zmq.config import ZmqConf
|
||||
from heisskleber.zmq.publisher import ZmqAsyncPublisher, ZmqPublisher
|
||||
from heisskleber.zmq.subscriber import ZmqAsyncSubscriber
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def start_broker() -> Generator[Process, None, None]:
|
||||
# setup broker
|
||||
with patch(
|
||||
"heisskleber.config.parse.get_config_filepath",
|
||||
return_value="tests/resources/zmq.yaml",
|
||||
):
|
||||
broker_config = load_config(ZmqConf(), "zmq", read_commandline=False)
|
||||
broker_process = Process(
|
||||
target=zmq_broker,
|
||||
args=(broker_config,),
|
||||
)
|
||||
# start broker
|
||||
broker_process.start()
|
||||
|
||||
yield broker_process
|
||||
|
||||
broker_process.terminate()
|
||||
|
||||
|
||||
def test_instantiate_subscriber() -> None:
|
||||
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
sub = ZmqAsyncSubscriber(conf, "test")
|
||||
assert sub.config == conf
|
||||
|
||||
|
||||
def test_instantiate_publisher() -> None:
|
||||
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
pub = ZmqPublisher(conf)
|
||||
assert pub.config == conf
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_receive(start_broker) -> None:
|
||||
print("test_send_receive")
|
||||
topic = "test"
|
||||
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
|
||||
source = ZmqAsyncSubscriber(conf, topic)
|
||||
sink = ZmqAsyncPublisher(conf)
|
||||
time.sleep(1) # this is crucial, otherwise the source might hang
|
||||
for i in range(10):
|
||||
message = {"m": i}
|
||||
await sink.send(message, topic)
|
||||
print(f"sent {topic} {message}")
|
||||
t, m = await source.receive()
|
||||
print(f"received {t} {m}")
|
||||
assert t == topic
|
||||
assert m == {"m": i}
|
||||
Reference in New Issue
Block a user