Refactor/background tasks (#75)

* Add start, stop and __repr__ to sink and source types.

* Merge conflicts on mqtt async pub and resampler.

* Add start() and stop() functions to udp and zmq.

Change tests accordingly.

* Rename broker, ip, interface to common config name "host".

* Updated "host" entry in config files.

* Add lazyload to mqtt-source.
This commit is contained in:
Felix Weiler
2024-02-22 18:50:13 +08:00
committed by GitHub
parent 01eebe3cbd
commit 8c985bdf3c
38 changed files with 703 additions and 268 deletions

View File

@@ -1,4 +1,4 @@
broker: "10.47.36.1"
host: "10.47.36.1"
user: ""
password: ""
port: 1883

View File

@@ -1,4 +1,4 @@
protocol : "tcp" # ipc protocol
interface: "127.0.0.1" # the interface to bind to
publisher_port : 5555 # port used by primary producers
subscriber_port: 5556 # port used by primary consumers
protocol: "tcp" # ipc protocol
host: "127.0.0.1" # the interface to bind to
publisher_port: 5555 # port used by primary producers
subscriber_port: 5556 # port used by primary consumers

View File

@@ -1,4 +1,4 @@
from .config import BaseConf
from .parse import load_config
from .parse import ConfigType, load_config
__all__ = ["load_config", "BaseConf"]
__all__ = ["load_config", "BaseConf", "ConfigType"]

View File

@@ -16,6 +16,15 @@ class ConsoleSink(Sink):
else:
print(verbose_topic + str(data))
def __repr__(self) -> str:
return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})"
def start(self) -> None:
pass
def stop(self) -> None:
pass
class AsyncConsoleSink(AsyncSink):
def __init__(self, pretty: bool = False, verbose: bool = False) -> None:
@@ -29,6 +38,15 @@ class AsyncConsoleSink(AsyncSink):
else:
print(verbose_topic + str(data))
def __repr__(self) -> str:
return f"{self.__class__.__name__}(pretty={self.pretty}, verbose={self.verbose})"
def start(self) -> None:
pass
def stop(self) -> None:
pass
if __name__ == "__main__":
sink = ConsoleSink()

View File

@@ -1,34 +1,98 @@
import asyncio
import json
import sys
import time
from queue import SimpleQueue
from threading import Thread
from heisskleber.core.types import Serializable, Source
from heisskleber.core.types import AsyncSource, Serializable, Source
class ConsoleSource(Source):
def __init__(self, topic: str | list[str] | tuple[str] = "console") -> None:
self.topic = "console"
def __init__(self, topic: str = "console") -> None:
self.topic = topic
self.queue = SimpleQueue()
self.listener_daemon = Thread(target=self.listener_task, daemon=True)
self.listener_daemon.start()
self.pack = json.loads
self.thread: Thread | None = None
def listener_task(self):
while True:
data = sys.stdin.readline()
payload = self.pack(data)
self.queue.put(payload)
try:
data = sys.stdin.readline()
payload = self.pack(data)
self.queue.put(payload)
except json.decoder.JSONDecodeError:
print("Invalid JSON")
continue
except ValueError:
break
print("listener task finished")
def receive(self) -> tuple[str, dict[str, Serializable]]:
if not self.thread:
self.start()
data = self.queue.get()
return self.topic, data
def __repr__(self) -> str:
return f"{self.__class__.__name__}(topic={self.topic})"
def start(self) -> None:
self.thread = Thread(target=self.listener_task, daemon=True)
self.thread.start()
def stop(self) -> None:
if self.thread:
sys.stdin.close()
self.thread.join()
class AsyncConsoleSource(AsyncSource):
def __init__(self, topic: str = "console") -> None:
self.topic = topic
self.queue: asyncio.Queue[dict[str, Serializable]] = asyncio.Queue(maxsize=10)
self.pack = json.loads
self.task: asyncio.Task[None] | None = None
async def listener_task(self):
while True:
data = sys.stdin.readline()
payload = self.pack(data)
await self.queue.put(payload)
async def receive(self) -> tuple[str, dict[str, Serializable]]:
if not self.task:
self.start()
data = await self.queue.get()
return self.topic, data
def __repr__(self) -> str:
return f"{self.__class__.__name__}(topic={self.topic})"
def start(self) -> None:
self.task = asyncio.create_task(self.listener_task())
def stop(self) -> None:
if self.task:
self.task.cancel()
if __name__ == "__main__":
console_source = ConsoleSource()
console_source.start()
while True:
print(console_source.receive())
time.sleep(1)
print("Listening to console input.")
count = 0
try:
while True:
print(console_source.receive())
time.sleep(1)
count += 1
print(count)
except KeyboardInterrupt:
print("Stopped")
sys.exit(0)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Callable, Union
from heisskleber.config.config import BaseConf
from heisskleber.config import BaseConf
Serializable = Union[str, int, float]
@@ -23,12 +23,30 @@ class Sink(ABC):
pass
@abstractmethod
def send(self, data: dict[str, Any], topic: str) -> None:
def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Send data via the implemented output stream.
"""
pass
@abstractmethod
def __repr__(self) -> str:
pass
@abstractmethod
def start(self) -> None:
"""
Start any background processes and tasks.
"""
pass
@abstractmethod
def stop(self) -> None:
"""
Stop any background processes and tasks.
"""
pass
class Source(ABC):
"""
@@ -53,25 +71,21 @@ class Source(ABC):
"""
pass
class AsyncSubscriber(ABC):
"""
AsyncSubscriber interface
"""
@abstractmethod
def __repr__(self) -> str:
pass
@abstractmethod
def __init__(self, config: Any, topic: str | list[str]) -> None:
def start(self) -> None:
"""
Initialize the subscriber with a topic and a configuration object.
Start any background processes and tasks.
"""
pass
@abstractmethod
async def receive(self) -> tuple[str, dict[str, Serializable]]:
def stop(self) -> None:
"""
Blocking function to receive data from the implemented input stream.
Data is returned as a tuple of (topic, data).
Stop any background processes and tasks.
"""
pass
@@ -97,6 +111,24 @@ class AsyncSource(ABC):
"""
pass
@abstractmethod
def __repr__(self) -> str:
pass
@abstractmethod
def start(self) -> None:
"""
Start any background processes and tasks.
"""
pass
@abstractmethod
def stop(self) -> None:
"""
Stop any background processes and tasks.
"""
pass
class AsyncSink(ABC):
"""
@@ -118,3 +150,21 @@ class AsyncSink(ABC):
Send data via the implemented output stream.
"""
pass
@abstractmethod
def __repr__(self) -> str:
pass
@abstractmethod
def start(self) -> None:
"""
Start any background processes and tasks.
"""
pass
@abstractmethod
def stop(self) -> None:
"""
Stop any background processes and tasks.
"""
pass

View File

@@ -9,7 +9,7 @@ class MqttConf(BaseConf):
MQTT configuration class.
"""
broker: str = "localhost"
host: str = "localhost"
user: str = ""
password: str = ""
port: int = 1883

View File

@@ -34,8 +34,14 @@ class MqttBase:
def __init__(self, config: MqttConf) -> None:
self.config = config
self.client: mqtt_client | None = None
def start(self) -> None:
self.connect()
self.client.loop_start()
def stop(self) -> None:
if self.client:
self.client.loop_stop()
def connect(self) -> None:
self.client = mqtt_client()
@@ -52,7 +58,8 @@ class MqttBase:
# the default certification authority of the system is used.
self.client.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT)
self.client.connect(self.config.broker, self.config.port)
self.client.connect(self.config.host, self.config.port)
self.client.loop_start()
@staticmethod
def _raise_if_thread_died() -> None:
@@ -63,7 +70,7 @@ class MqttBase:
# MQTT callbacks
def _on_connect(self, client, userdata, flags, return_code) -> None:
if return_code == 0:
print(f"MQTT node connected to {self.config.broker}:{self.config.port}")
print(f"MQTT node connected to {self.config.host}:{self.config.port}")
else:
print("Connection failed!")
if self.config.verbose:
@@ -84,4 +91,4 @@ class MqttBase:
print(f"Received message: {message.payload!s}, topic: {message.topic}, qos: {message.qos}")
def __del__(self) -> None:
self.client.loop_stop()
self.stop()

View File

@@ -1,9 +1,7 @@
from __future__ import annotations
from typing import Any
from heisskleber.core.packer import get_packer
from heisskleber.core.types import Sink
from heisskleber.core.types import Serializable, Sink
from .config import MqttConf
from .mqtt_base import MqttBase
@@ -21,14 +19,26 @@ class MqttPublisher(MqttBase, Sink):
super().__init__(config)
self.pack = get_packer(config.packstyle)
def send(self, data: dict[str, Any], topic: str) -> None:
def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Takes python dictionary, serializes it according to the packstyle
and sends it to the broker.
Publishing is asynchronous
"""
if not self.client.is_connected():
self.start()
self._raise_if_thread_died()
payload = self.pack(data)
self.client.publish(topic, payload, qos=self.config.qos, retain=self.config.retain)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
def start(self) -> None:
super().start()
def stop(self) -> None:
super().stop()

View File

@@ -1,7 +1,4 @@
from __future__ import annotations
import asyncio
from asyncio import Queue, Task, create_task
from asyncio import Queue, Task, create_task, sleep
import aiomqtt
@@ -22,8 +19,8 @@ class AsyncMqttPublisher(AsyncSink):
def __init__(self, config: MqttConf) -> None:
self.config = config
self.pack = get_packer(config.packstyle)
self._send_queue: Queue[tuple[dict[str, Serializable], str]] = Queue(maxsize=config.max_saved_messages)
self._sender_task: Task[None] = create_task(self.send_work())
self._send_queue: Queue[tuple[dict[str, Serializable], str]] = Queue()
self._sender_task: Task[None] | None = None
async def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
@@ -32,6 +29,8 @@ class AsyncMqttPublisher(AsyncSink):
Publishing is asynchronous
"""
if not self._sender_task:
self.start()
await self._send_queue.put((data, topic))
@@ -45,7 +44,7 @@ class AsyncMqttPublisher(AsyncSink):
while True:
try:
async with aiomqtt.Client(
hostname=self.config.broker,
hostname=self.config.host,
port=self.config.port,
username=self.config.user,
password=self.config.password,
@@ -57,4 +56,15 @@ class AsyncMqttPublisher(AsyncSink):
await client.publish(topic, payload)
except aiomqtt.MqttError:
print("Connection to MQTT broker failed. Retrying in 5 seconds")
await asyncio.sleep(5)
await sleep(5)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
def start(self) -> None:
self._sender_task = create_task(self.send_work())
def stop(self) -> None:
if self._sender_task:
self._sender_task.cancel()
self._sender_task = None

View File

@@ -22,9 +22,8 @@ class MqttSubscriber(MqttBase, Source):
def __init__(self, config: MqttConf, topics: str | list[str]) -> None:
super().__init__(config)
self.topics = topics
self._message_queue: SimpleQueue[MQTTMessage] = SimpleQueue()
self.subscribe(topics)
self.client.on_message = self._on_message
self.unpack = get_unpacker(config.packstyle)
def subscribe(self, topics: str | list[str] | tuple[str]) -> None:
@@ -47,14 +46,29 @@ class MqttSubscriber(MqttBase, Source):
Messages are saved in a stack, if no message is available, this function blocks.
Returns:
tuple(topic: bytes, message: dict): the message received
tuple(topic: str, message: dict): the message received
"""
if not self.client:
self.start()
self._raise_if_thread_died()
mqtt_message = self._message_queue.get(block=True, timeout=self.config.timeout_s)
message_returned = self.unpack(mqtt_message.payload.decode())
return (mqtt_message.topic, message_returned)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
def start(self) -> None:
super().start()
self.subscribe(self.topics)
self.client.on_message = self._on_message
self.is_connected = True
def stop(self) -> None:
super().stop()
# callback to add incoming messages onto stack
def _on_message(self, client, userdata, message) -> None:
self._message_queue.put(message)

View File

@@ -16,7 +16,7 @@ class AsyncMqttSubscriber(AsyncSource):
def __init__(self, config: MqttConf, topic: str | list[str]) -> None:
self.config: MqttConf = config
self.client = Client(
hostname=self.config.broker,
hostname=self.config.host,
port=self.config.port,
username=self.config.user,
password=self.config.password,
@@ -24,17 +24,32 @@ class AsyncMqttSubscriber(AsyncSource):
self.topics = topic
self.unpack = get_unpacker(self.config.packstyle)
self.message_queue: Queue[Message] = Queue(self.config.max_saved_messages)
self._listener_task: Task = create_task(self.create_listener())
self._listener_task: Task[None] | None = None
"""
Await the newest message in the queue and return Tuple
"""
def __repr__(self) -> str:
return f"{self.__class__.__name__}(broker={self.config.host}, port={self.config.port})"
def start(self) -> None:
self._listener_task = create_task(self.run())
def stop(self) -> None:
if self._listener_task:
self._listener_task.cancel()
self._listener_task = None
async def receive(self) -> tuple[str, dict[str, Serializable]]:
mqtt_message: Message = await self.message_queue.get()
"""
Await the newest message in the queue and return Tuple
"""
if not self._listener_task:
self.start()
mqtt_message = await self.message_queue.get()
return self._handle_message(mqtt_message)
async def create_listener(self):
async def run(self):
"""
Handle the connection to MQTT broker and run the message loop.
"""
while True:
try:
async with self.client:
@@ -45,11 +60,10 @@ class AsyncMqttSubscriber(AsyncSource):
print("Connection to MQTT failed. Retrying...")
await sleep(1)
"""
Listen to incoming messages asynchronously and put them into a queue
"""
async def _listen_mqtt_loop(self) -> None:
"""
Listen to incoming messages asynchronously and put them into a queue
"""
async with self.client.messages() as messages:
# async with self.client.filtered_messages(self.topics) as messages:
async for message in messages:

View File

@@ -1,12 +1,13 @@
import argparse
import sys
from typing import Union
from typing import Callable, Union
from heisskleber.config import load_config
from heisskleber.console.sink import ConsoleSink
from heisskleber.core.factories import _registered_sources
TopicType = Union[str, list[str]]
from heisskleber.mqtt import MqttSubscriber
from heisskleber.udp import UdpSubscriber
from heisskleber.zmq import ZmqSubscriber
def parse_args() -> argparse.Namespace:
@@ -33,14 +34,12 @@ def parse_args() -> argparse.Namespace:
"-H",
"--host",
type=str,
default="localhost",
help="Host or broker for MQTT, zmq and UDP.",
)
parser.add_argument(
"-P",
"--port",
type=int,
default=1883,
help="Port or serial interface for MQTT, zmq and UDP.",
)
parser.add_argument("-v", "--verbose", action="store_true")
@@ -49,7 +48,19 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
def run() -> None:
def keyboardexit(func) -> Callable:
def wrapper(*args, **kwargs) -> Union[None, int]:
try:
return func(*args, **kwargs)
except KeyboardInterrupt:
print("Exiting...")
sys.exit(0)
return wrapper
@keyboardexit
def main() -> None:
args = parse_args()
sink = ConsoleSink(pretty=args.pretty, verbose=args.verbose)
@@ -58,36 +69,27 @@ def run() -> None:
try:
config = load_config(conf_cls(), args.type, read_commandline=False)
except FileNotFoundError:
print(f"Using default config for {args.type}.")
print(f"No config file found for {args.type}, using default values and user input.")
config = conf_cls()
if args.port:
config.port = args.port
if args.host:
if args.type == "mqtt":
config.broker = args.host
elif args.type == "zmq":
config.interface = args.host
elif args.type == "udp":
config.ip = args.host
if args.type == "zmq" and args.topic == "#":
args.topic = ""
source = sub_cls(config, args.topic)
if isinstance(source, (MqttSubscriber, UdpSubscriber)):
source.config.host = args.host or source.config.host
source.config.port = args.port or source.config.port
elif isinstance(source, ZmqSubscriber):
source.config.host = args.host or source.config.host
source.config.subscriber_port = args.port or source.config.subscriber_port
args.topic = "" if args.topic == "#" else args.topic
elif isinstance(source, UdpSubscriber):
source.config.port = args.port or source.config.port
source.start()
sink.start()
while True:
topic, data = source.receive()
sink.send(data, topic)
def main() -> None:
try:
run()
except KeyboardInterrupt:
print("Exiting...")
sys.exit(0)
if __name__ == "__main__":
main()

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import sys
from typing import Callable, Optional
import serial
@@ -11,6 +12,7 @@ from .config import SerialConf
class SerialPublisher(Sink):
serial_connection: serial.Serial
"""
Publisher for serial devices.
Can be used everywhere that a flucto style publishing connection is required.
@@ -30,19 +32,36 @@ class SerialPublisher(Sink):
):
self.config = config
self.pack = pack_func if pack_func else get_packer("serial")
self._connect()
self.is_connected = False
def start(self) -> None:
"""
Start the serial connection.
"""
try:
self.serial_connection = serial.Serial(
port=self.config.port,
baudrate=self.config.baudrate,
bytesize=self.config.bytesize,
parity=serial.PARITY_NONE,
stopbits=serial.STOPBITS_ONE,
)
except serial.SerialException:
print(f"Failed to connect to serial device at port {self.config.port}")
sys.exit(1)
def _connect(self) -> None:
self.serial: serial.Serial = serial.Serial(
port=self.config.port,
baudrate=self.config.baudrate,
bytesize=self.config.bytesize,
parity=serial.PARITY_NONE,
stopbits=serial.STOPBITS_ONE,
)
print(f"Successfully connected to serial device at port {self.config.port}")
self.is_connected = True
def send(self, message: dict[str, Serializable], topic: str) -> None:
def stop(self) -> None:
"""
Stop the serial connection.
"""
if hasattr(self, "serial_connection") and self.serial_connection.is_open:
self.serial_connection.flush()
self.serial_connection.close()
def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Takes python dictionary, serializes it according to the packstyle
and sends it to the broker.
@@ -52,16 +71,17 @@ class SerialPublisher(Sink):
message : dict
object to be serialized and sent via the serial connection. Usually a dict.
"""
payload = self.pack(message)
self.serial.write(payload.encode(self.config.encoding))
self.serial.flush()
if not self.is_connected:
self.start()
payload = self.pack(data)
self.serial_connection.write(payload.encode(self.config.encoding))
self.serial_connection.flush()
if self.config.verbose:
print(f"{topic}: {payload}")
def __repr__(self) -> str:
return f"SerialPublisher(port={self.config.port}, baudrate={self.config.baudrate}, bytezize={self.config.bytesize}, encoding={self.config.encoding})"
def __del__(self) -> None:
if not hasattr(self, "serial"):
return
if not self.serial.is_open:
return
self.serial.flush()
self.serial.close()
self.stop()

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import sys
from collections.abc import Generator
from typing import Callable, Optional
from typing import Callable
import serial
@@ -11,6 +10,7 @@ from .config import SerialConf
class SerialSubscriber(Source):
serial_connection: serial.Serial
"""
Subscriber for serial devices. Connects to a serial port and reads from it.
@@ -30,22 +30,38 @@ class SerialSubscriber(Source):
self,
config: SerialConf,
topic: str | None = None,
custom_unpack: Optional[Callable] = None, # noqa: UP007
custom_unpack: Callable | None = None,
):
self.config = config
self.topic = topic
self.unpack = custom_unpack if custom_unpack else lambda x: x # types: ignore
self._connect()
self.is_connected = False
def _connect(self):
self.serial: serial.Serial = serial.Serial(
port=self.config.port,
baudrate=self.config.baudrate,
bytesize=self.config.bytesize,
parity=serial.PARITY_NONE,
stopbits=serial.STOPBITS_ONE,
)
def start(self) -> None:
"""
Start the serial connection.
"""
try:
self.serial_connection = serial.Serial(
port=self.config.port,
baudrate=self.config.baudrate,
bytesize=self.config.bytesize,
parity=serial.PARITY_NONE,
stopbits=serial.STOPBITS_ONE,
)
except serial.SerialException:
print(f"Failed to connect to serial device at port {self.config.port}")
sys.exit(1)
print(f"Successfully connected to serial device at port {self.config.port}")
self.is_connected = True
def stop(self) -> None:
"""
Stop the serial connection.
"""
if hasattr(self, "serial_connection") and self.serial_connection.is_open:
self.serial_connection.flush()
self.serial_connection.close()
def receive(self) -> tuple[str, dict]:
"""
@@ -57,6 +73,9 @@ class SerialSubscriber(Source):
topic is a placeholder to adhere to the Subscriber interface
payload is a dictionary containing the data from the serial port
"""
if not self.is_connected:
self.start()
# message is a string
message = next(self.read_serial_port())
# payload is a dictionary
@@ -76,7 +95,7 @@ class SerialSubscriber(Source):
buffer = ""
while True:
try:
buffer = self.serial.readline().decode(self.config.encoding, "ignore")
buffer = self.serial_connection.readline().decode(self.config.encoding, "ignore")
yield buffer
except UnicodeError as e:
if self.config.verbose:
@@ -84,10 +103,8 @@ class SerialSubscriber(Source):
print(e)
continue
def __repr__(self) -> str:
return f"SerialPublisher(port={self.config.port}, baudrate={self.config.baudrate}, bytezize={self.config.bytesize}, encoding={self.config.encoding})"
def __del__(self) -> None:
if not hasattr(self, "serial"):
return
if not self.serial.is_open:
return
self.serial.flush()
self.serial.close()
self.stop()

View File

@@ -25,18 +25,31 @@ class Joint:
self.output_queue: asyncio.Queue[dict[str, Serializable]] = asyncio.Queue()
self.initialized = asyncio.Event()
self.initalize_task = asyncio.create_task(self.sync())
self.output_task = asyncio.create_task(self.output_work())
self.combined_dict: dict[str, Serializable] = {}
self.task: asyncio.Task[None] | None = None
"""
Main interaction coroutine: Get next value out of the queue.
"""
def __repr__(self) -> str:
return f"""Joint(resample_rate={self.conf.resample_rate},
sources={len(self.resamplers)} of type(s): {{r.__class__.__name__ for r in self.resamplers}})"""
def start(self) -> None:
self.task = asyncio.create_task(self.output_work())
def stop(self) -> None:
if self.task:
self.task.cancel()
async def receive(self) -> dict[str, Any]:
"""
Main interaction coroutine: Get next value out of the queue.
"""
if not self.task:
self.start()
output = await self.output_queue.get()
return output
async def sync(self) -> None:
"""Synchronize the resamplers by pulling data from each until the timestamp is aligned. Retains first matching data."""
print("Starting sync")
datas = await asyncio.gather(*[source.receive() for source in self.resamplers])
print("Got data")

View File

@@ -10,5 +10,6 @@ class UdpConf(BaseConf):
"""
port: int = 1234
ip: str = "127.0.0.1"
host: str = "127.0.0.1"
packer: str = "json"
max_queue_size: int = 1000

View File

@@ -1,4 +1,5 @@
import socket
import sys
from heisskleber.core.packer import get_packer
from heisskleber.core.types import Serializable, Sink
@@ -8,15 +9,30 @@ from heisskleber.udp.config import UdpConf
class UdpPublisher(Sink):
def __init__(self, config: UdpConf) -> None:
self.config = config
self.ip = self.config.ip
self.ip = self.config.host
self.port = self.config.port
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.pack = get_packer(self.config.packer)
self.is_connected = False
def send(self, message: dict[str, Serializable], topic: str) -> None:
message["topic"] = topic
payload = self.pack(message).encode("utf-8")
def start(self) -> None:
try:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
except OSError as e:
print(f"failed to create socket: {e}")
sys.exit(-1)
else:
self.is_connected = True
def stop(self) -> None:
self.socket.close()
self.is_connected = True
def send(self, data: dict[str, Serializable], topic: str) -> None:
if not self.is_connected:
self.start()
data["topic"] = topic
payload = self.pack(data).encode("utf-8")
self.socket.sendto(payload, (self.ip, self.port))
def __del__(self) -> None:
self.socket.close()
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"

View File

@@ -1,6 +1,7 @@
import socket
import sys
import threading
from queue import SimpleQueue
from queue import Queue
from heisskleber.core.packer import get_unpacker
from heisskleber.core.types import Serializable, Source
@@ -11,15 +12,30 @@ class UdpSubscriber(Source):
def __init__(self, config: UdpConf, topic: str | None = None):
self.config = config
self.topic = topic
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.bind((self.config.ip, self.config.port))
self.unpacker = get_unpacker(self.config.packer)
self._queue: SimpleQueue[tuple[str, dict[str, Serializable]]] = SimpleQueue()
self._queue: Queue[tuple[str, dict[str, Serializable]]] = Queue(maxsize=self.config.max_queue_size)
self._running = threading.Event()
def start(self) -> None:
try:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
except OSError as e:
print(f"failed to create socket: {e}")
sys.exit(-1)
self.socket.bind((self.config.host, self.config.port))
self._running.set()
self._thread: threading.Thread | None = None
self._thread = threading.Thread(target=self._loop, daemon=True)
self._thread.start()
def stop(self) -> None:
self._running.clear()
# if self._thread is not None:
# self._thread.join()
self.socket.close()
def receive(self) -> tuple[str, dict[str, Serializable]]:
if not self._running.is_set():
self.start()
return self._queue.get()
def _loop(self) -> None:
@@ -33,15 +49,5 @@ class UdpSubscriber(Source):
error_message = f"Error in UDP listener loop: {e}"
print(error_message)
def start_loop(self) -> None:
self._thread = threading.Thread(target=self._loop, daemon=True)
self._thread.start()
def stop_loop(self) -> None:
self._running.clear()
if self._thread is not None:
self._thread.join()
self.socket.close()
def __del__(self) -> None:
self.stop_loop()
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.host}, port={self.config.port})"

View File

@@ -6,15 +6,15 @@ from heisskleber.config import BaseConf
@dataclass
class ZmqConf(BaseConf):
protocol: str = "tcp"
interface: str = "127.0.0.1"
host: str = "127.0.0.1"
publisher_port: int = 5555
subscriber_port: int = 5556
packstyle: str = "json"
@property
def publisher_address(self) -> str:
return f"{self.protocol}://{self.interface}:{self.publisher_port}"
return f"{self.protocol}://{self.host}:{self.publisher_port}"
@property
def subscriber_address(self) -> str:
return f"{self.protocol}://{self.interface}:{self.subscriber_port}"
return f"{self.protocol}://{self.host}:{self.subscriber_port}"

View File

@@ -1,4 +1,5 @@
import sys
from typing import Callable
import zmq
import zmq.asyncio
@@ -10,45 +11,46 @@ from .config import ZmqConf
class ZmqPublisher(Sink):
"""
Publisher that sends messages to a ZMQ PUB socket.
Attributes:
-----------
pack : Callable
The packer function to use for serializing the data.
Methods:
--------
send(data : dict, topic : str):
Send the data with the given topic.
start():
Connect to the socket.
stop():
Close the socket.
"""
def __init__(self, config: ZmqConf):
self.config = config
self.context = zmq.Context.instance()
self.socket = self.context.socket(zmq.PUB)
self.pack = get_packer(config.packstyle)
self.connect()
def connect(self) -> None:
try:
if self.config.verbose:
print(f"connecting to {self.config.publisher_address}")
self.socket.connect(self.config.publisher_address)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
self.is_connected = False
def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Take the data as a dict, serialize it with the given packer and send it to the zmq socket.
"""
if not self.is_connected:
self.start()
payload = self.pack(data)
if self.config.verbose:
print(f"sending message {payload} to topic {topic}")
self.socket.send_multipart([topic.encode(), payload.encode()])
def __del__(self):
self.socket.close()
class ZmqAsyncPublisher(AsyncSink):
def __init__(self, config: ZmqConf):
self.config = config
self.context = zmq.asyncio.Context.instance()
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB)
self.pack = get_packer(config.packstyle)
self.connect()
def connect(self) -> None:
def start(self) -> None:
"""Connect to the zmq socket."""
try:
if self.config.verbose:
print(f"connecting to {self.config.publisher_address}")
@@ -56,12 +58,72 @@ class ZmqAsyncPublisher(AsyncSink):
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
else:
self.is_connected = True
def stop(self) -> None:
self.socket.close()
self.is_connected = False
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})"
class ZmqAsyncPublisher(AsyncSink):
"""
Async publisher that sends messages to a ZMQ PUB socket.
Attributes:
-----------
pack : Callable
The packer function to use for serializing the data.
Methods:
--------
send(data : dict, topic : str):
Send the data with the given topic.
start():
Connect to the socket.
stop():
Close the socket.
"""
def __init__(self, config: ZmqConf):
self.config = config
self.context = zmq.asyncio.Context.instance()
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.PUB)
self.pack: Callable = get_packer(config.packstyle)
self.is_connected = False
async def send(self, data: dict[str, Serializable], topic: str) -> None:
"""
Take the data as a dict, serialize it with the given packer and send it to the zmq socket.
"""
if not self.is_connected:
self.start()
payload = self.pack(data)
if self.config.verbose:
print(f"sending message {payload} to topic {topic}")
await self.socket.send_multipart([topic.encode(), payload.encode()])
def __del__(self):
def start(self) -> None:
"""Connect to the zmq socket."""
try:
if self.config.verbose:
print(f"connecting to {self.config.publisher_address}")
self.socket.connect(self.config.publisher_address)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
else:
self.is_connected = True
def stop(self) -> None:
"""Close the zmq socket."""
self.socket.close()
self.is_connected = False
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.publisher_address}, port={self.config.publisher_port})"

View File

@@ -12,34 +12,43 @@ from .config import ZmqConf
class ZmqSubscriber(Source):
def __init__(self, config: ZmqConf, topic: str):
self.config = config
"""
Source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function.
Attributes:
-----------
unpack : Callable
The unpacker function to use for deserializing the data.
Methods:
--------
receive() -> tuple[str, dict]:
Send the data with the given topic.
start():
Connect to the socket.
stop():
Close the socket.
"""
def __init__(self, config: ZmqConf, topic: str | list[str]):
"""
Constructs new ZmqAsyncSubscriber instance.
Parameters:
-----------
config : ZmqConf
The configuration dataclass object for the zmq connection.
topic : str
The topic or list of topics to subscribe to.
"""
self.config = config
self.topic = topic
self.context = zmq.Context.instance()
self.socket = self.context.socket(zmq.SUB)
self.connect()
self.subscribe(topic)
self.unpack = get_unpacker(config.packstyle)
def connect(self):
try:
# print(f"Connecting to { self.config.consumer_connection }")
self.socket.connect(self.config.subscriber_address)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
def _subscribe_single_topic(self, topic: str):
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
def subscribe(self, topic: str | list[str] | tuple[str]):
# Accepts single topic or list of topics
if isinstance(topic, (list, tuple)):
for t in topic:
self._subscribe_single_topic(t)
else:
self._subscribe_single_topic(topic)
self.is_connected = False
def receive(self) -> tuple[str, dict]:
"""
@@ -48,34 +57,26 @@ class ZmqSubscriber(Source):
Returns:
tuple(topic: str, message: dict): the message received
"""
if not self.is_connected:
self.start()
(topic, payload) = self.socket.recv_multipart()
message = self.unpack(payload.decode())
topic = topic.decode()
return (topic, message)
def __del__(self):
self.socket.close()
class ZmqAsyncSubscriber(AsyncSource):
def __init__(self, config: ZmqConf, topic: str):
self.config = config
self.context = zmq.asyncio.Context.instance()
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB)
self.connect()
self.subscribe(topic)
self.unpack = get_unpacker(config.packstyle)
def connect(self):
def start(self):
try:
self.socket.connect(self.config.subscriber_address)
self.subscribe(self.topic)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
else:
self.is_connected = True
def _subscribe_single_topic(self, topic: str):
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
def stop(self):
self.socket.close()
self.is_connected = False
def subscribe(self, topic: str | list[str] | tuple[str]):
# Accepts single topic or list of topics
@@ -85,6 +86,52 @@ class ZmqAsyncSubscriber(AsyncSource):
else:
self._subscribe_single_topic(topic)
def _subscribe_single_topic(self, topic: str):
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})"
class ZmqAsyncSubscriber(AsyncSource):
"""
Async source that subscribes to one or many topics from a zmq broker and receives messages via the receive() function.
Attributes:
-----------
unpack : Callable
The unpacker function to use for deserializing the data.
Methods:
--------
receive() -> tuple[str, dict]:
Send the data with the given topic.
start():
Connect to the socket.
stop():
Close the socket.
"""
def __init__(self, config: ZmqConf, topic: str | list[str]):
"""
Constructs new ZmqAsyncSubscriber instance.
Parameters:
-----------
config : ZmqConf
The configuration dataclass object for the zmq connection.
topic : str
The topic or list of topics to subscribe to.
"""
self.config = config
self.topic = topic
self.context = zmq.asyncio.Context.instance()
self.socket: zmq.asyncio.Socket = self.context.socket(zmq.SUB)
self.unpack = get_unpacker(config.packstyle)
self.is_connected = True
async def receive(self) -> tuple[str, dict]:
"""
reads a message from the zmq bus and returns it
@@ -92,10 +139,43 @@ class ZmqAsyncSubscriber(AsyncSource):
Returns:
tuple(topic: str, message: dict): the message received
"""
if not self.is_connected:
self.start()
(topic, payload) = await self.socket.recv_multipart()
message = self.unpack(payload.decode())
topic = topic.decode()
return (topic, message)
def __del__(self):
def start(self):
"""Connect to the zmq socket."""
try:
self.socket.connect(self.config.subscriber_address)
except Exception as e:
print(f"failed to bind to zeromq socket: {e}")
sys.exit(-1)
else:
self.is_connected = True
self.subscribe(self.topic)
def stop(self):
"""Close the zmq socket."""
self.socket.close()
self.is_connected = False
def subscribe(self, topic: str | list[str] | tuple[str]):
"""
Subscribes to the given topic(s) on the zmq socket.
Accepts single topic or list of topics.
"""
if isinstance(topic, (list, tuple)):
for t in topic:
self._subscribe_single_topic(t)
else:
self._subscribe_single_topic(topic)
def _subscribe_single_topic(self, topic: str):
self.socket.setsockopt(zmq.SUBSCRIBE, topic.encode())
def __repr__(self) -> str:
return f"{self.__class__.__name__}(host={self.config.subscriber_address}, port={self.config.subscriber_port})"

View File

@@ -1,8 +1,8 @@
from heisskleber.udp import UdpSubscriber, UdpConf
from heisskleber.udp import UdpConf, UdpSubscriber
def main() -> None:
conf = UdpConf(ip="192.168.137.1", port=6600)
conf = UdpConf(host="192.168.137.1", port=6600)
subscriber = UdpSubscriber(conf)
while True:
@@ -12,4 +12,3 @@ def main() -> None:
if __name__ == "__main__":
main()

View File

@@ -10,7 +10,7 @@ async def main():
topic1 = "topic1"
topic2 = "topic2"
config = MqttConf(broker="localhost", port=1883, user="", password="") # not a real password
config = MqttConf(host="localhost", port=1883, user="", password="") # not a real password
sub1 = AsyncMqttSubscriber(config, topic1)
sub2 = AsyncMqttSubscriber(config, topic2)

View File

@@ -7,7 +7,7 @@ from heisskleber.stream import Joint, Resampler, ResamplerConf
async def main():
topics = ["topic0", "topic1", "topic2", "topic3"]
config = MqttConf(broker="localhost", port=1883, user="", password="") # not a real password
config = MqttConf(host="localhost", port=1883, user="", password="") # not a real password
subs = [AsyncMqttSubscriber(config, topic=topic) for topic in topics]
resampler_config = ResamplerConf(resample_rate=1000)

View File

@@ -4,7 +4,7 @@ from heisskleber.mqtt import AsyncMqttSubscriber, MqttConf
async def main():
conf = MqttConf(broker="localhost", port=1883, user="", password="")
conf = MqttConf(host="localhost", port=1883, user="", password="")
sub = AsyncMqttSubscriber(conf, topic="#")
# async for topic, message in sub:
# print(message)

View File

@@ -20,7 +20,7 @@ async def send_every_n_miliseconds(frequency, value, pub, topic):
async def main2():
config = MqttConf(broker="localhost", port=1883, user="", password="")
config = MqttConf(host="localhost", port=1883, user="", password="")
pubs = [AsyncMqttPublisher(config) for i in range(5)]
tasks = []

View File

@@ -5,7 +5,7 @@ from heisskleber.stream import Resampler, ResamplerConf
async def main():
conf = MqttConf(broker="localhost", port=1883, user="", password="")
conf = MqttConf(host="localhost", port=1883, user="", password="")
sub = AsyncMqttSubscriber(conf, topic="#")
resampler = Resampler(ResamplerConf(), sub)

View File

@@ -4,7 +4,7 @@ from heisskleber.mqtt import AsyncMqttSubscriber, MqttConf
async def main():
config = MqttConf(broker="localhost", port=1883, user="", password="")
config = MqttConf(host="localhost", port=1883, user="", password="")
sub = AsyncMqttSubscriber(config, topic="#")

View File

@@ -13,7 +13,7 @@ def main():
# topic2 = "topic2"
config = MqttConf(
broker="localhost", port=1883, user="", password=""
host="localhost", port=1883, user="", password=""
) # , not a real password port=1883, user="", password="")
sub1 = MqttSubscriber(config, topic1)
# sub2 = MqttSubscriber(config, topic2)

View File

@@ -1,4 +1,4 @@
protocol : "tcp" # ipc protocol
interface: "127.0.0.1" # the interface to bind to
publisher_port : 5555 # port used by primary producers
subscriber_port: 5556 # port used by primary consumers
protocol: "tcp" # ipc protocol
host: "127.0.0.1" # the interface to bind to
publisher_port: 5555 # port used by primary producers
subscriber_port: 5556 # port used by primary consumers

View File

@@ -39,6 +39,16 @@ def test_console_sink_pretty_verbose(capsys) -> None:
assert captured.out == 'test:\t{\n "key": 3\n}\n'
def test_console_repr() -> None:
sink = ConsoleSink()
assert repr(sink) == "ConsoleSink(pretty=False, verbose=False)"
def test_async_console_repr() -> None:
sink = AsyncConsoleSink()
assert repr(sink) == "AsyncConsoleSink(pretty=False, verbose=False)"
@pytest.mark.asyncio
async def test_async_console_sink(capsys) -> None:
sink = AsyncConsoleSink()

View File

@@ -14,7 +14,7 @@ from heisskleber.mqtt.subscriber import MqttSubscriber
@pytest.fixture
def mock_mqtt_conf() -> MqttConf:
return MqttConf(
broker="localhost",
host="localhost",
port=1883,
user="user",
password="passwd", # noqa: S106, this is a test password
@@ -40,12 +40,14 @@ def mock_queue():
def test_mqtt_base_intialization(mock_mqtt_client, mock_mqtt_conf):
"""Test that the intialization of the mqtt client is as expected."""
base = MqttBase(config=mock_mqtt_conf)
base.start()
mock_mqtt_client.assert_called_once()
mock_mqtt_client.return_value.loop_start.assert_called_once()
mock_client_instance = mock_mqtt_client.return_value
mock_client_instance.username_pw_set.assert_called_with(mock_mqtt_conf.user, mock_mqtt_conf.password)
mock_client_instance.connect.assert_called_with(mock_mqtt_conf.broker, mock_mqtt_conf.port)
mock_client_instance.connect.assert_called_with(mock_mqtt_conf.host, mock_mqtt_conf.port)
assert base.client
assert base.client.on_connect == base._on_connect
assert base.client.on_disconnect == base._on_disconnect
assert base.client.on_publish == base._on_publish
@@ -56,7 +58,7 @@ def test_mqtt_base_on_connect(mock_mqtt_client, mock_mqtt_conf, capsys):
base = MqttBase(config=mock_mqtt_conf)
base._on_connect(None, None, {}, 0)
captured = capsys.readouterr()
assert f"MQTT node connected to {mock_mqtt_conf.broker}:{mock_mqtt_conf.port}" in captured.out
assert f"MQTT node connected to {mock_mqtt_conf.host}:{mock_mqtt_conf.port}" in captured.out
def test_mqtt_base_on_disconnect_with_error(mock_mqtt_client, mock_mqtt_conf, capsys):
@@ -71,7 +73,8 @@ def test_mqtt_base_on_disconnect_with_error(mock_mqtt_client, mock_mqtt_conf, ca
def test_mqtt_subscribes_single_topic(mock_mqtt_client, mock_mqtt_conf):
"""Test that the mqtt client subscribes to a single topic."""
_ = MqttSubscriber(topics="singleTopic", config=mock_mqtt_conf)
sub = MqttSubscriber(topics="singleTopic", config=mock_mqtt_conf)
sub.start()
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
assert actual_calls == [call("singleTopic", mock_mqtt_conf.qos)]
@@ -82,7 +85,8 @@ def test_mqtt_subscribes_multiple_topics(mock_mqtt_client, mock_mqtt_conf):
I would love to do this via parametrization, but the call argument is built differently for single size lists and longer lists.
"""
_ = MqttSubscriber(topics=["multiple1", "multiple2"], config=mock_mqtt_conf)
sub = MqttSubscriber(topics=["multiple1", "multiple2"], config=mock_mqtt_conf)
sub.start()
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
assert actual_calls == [
@@ -92,7 +96,8 @@ def test_mqtt_subscribes_multiple_topics(mock_mqtt_client, mock_mqtt_conf):
def test_mqtt_subscribes_multiple_topics_tuple(mock_mqtt_client, mock_mqtt_conf):
"""Test that the mqtt client subscribes to multiple topics passed as tuple."""
_ = MqttSubscriber(topics=("multiple1", "multiple2"), config=mock_mqtt_conf)
sub = MqttSubscriber(topics=("multiple1", "multiple2"), config=mock_mqtt_conf)
sub.start()
actual_calls = mock_mqtt_client.return_value.subscribe.call_args_list
assert actual_calls == [

View File

@@ -29,10 +29,11 @@ def mock_serial_device_publisher():
def test_serial_subscriber_initialization(mock_serial_device_subscriber, serial_conf):
"""Test that the SerialSubscriber class initializes correctly.
Mocks the serial.Serial class to avoid opening a serial port."""
_ = SerialSubscriber(
sub = SerialSubscriber(
config=serial_conf,
topic="",
)
sub.start()
mock_serial_device_subscriber.assert_called_with(
port=serial_conf.port,
baudrate=serial_conf.baudrate,
@@ -45,6 +46,7 @@ def test_serial_subscriber_initialization(mock_serial_device_subscriber, serial_
def test_serial_subscriber_receive(mock_serial_device_subscriber, serial_conf):
"""Test that the SerialSubscriber class calls readline and unpack as expected."""
subscriber = SerialSubscriber(config=serial_conf, topic="")
subscriber.start()
# Set up the readline return value
mock_serial_instance = mock_serial_device_subscriber.return_value
@@ -69,6 +71,7 @@ def test_serial_subscriber_converts_bytes_to_str():
"""Test that the SerialSubscriber class converts bytes to str as expected."""
with patch("heisskleber.serial.subscriber.serial.Serial") as mock_serial:
subscriber = SerialSubscriber(config=SerialConf(), topic="", custom_unpack=lambda x: x)
subscriber.start()
# Set the readline method to raise UnicodeError
mock_serial_instance = mock_serial.return_value
@@ -86,6 +89,7 @@ def test_serial_publisher_initialization(mock_serial_device_publisher, serial_co
"""Test that the SerialPublisher class initializes correctly.
Mocks the serial.Serial class to avoid opening a serial port."""
publisher = SerialPublisher(config=serial_conf)
publisher.start()
mock_serial_device_publisher.assert_called_with(
port=serial_conf.port,
baudrate=serial_conf.baudrate,
@@ -93,7 +97,7 @@ def test_serial_publisher_initialization(mock_serial_device_publisher, serial_co
parity=serial.PARITY_NONE,
stopbits=serial.STOPBITS_ONE,
)
assert publisher.serial
assert publisher.serial_connection
def test_serial_publisher_send(mock_serial_device_publisher, serial_conf):

View File

@@ -67,11 +67,12 @@ async def test_resampler_multiple_modes(mock_subscriber):
]
)
config = ResamplerConf(resample_rate=1000) # Fill in your MQTT configuration
config = ResamplerConf(resample_rate=1000)
resampler = Resampler(config, mock_subscriber)
# Test the resample method
resampled_data = [await resampler.receive() for _ in range(3)]
resampler.stop()
assert resampled_data[0] == {"epoch": 0.0, "data": 1.5}
assert resampled_data[1] == {"epoch": 1.0, "data": 3.5}
@@ -89,11 +90,12 @@ async def test_resampler_upsampling(mock_subscriber):
]
)
config = ResamplerConf(resample_rate=250) # Fill in your MQTT configuration
config = ResamplerConf(resample_rate=250)
resampler = Resampler(config, mock_subscriber)
# Test the resample method
resampled_data = [await resampler.receive() for _ in range(7)]
resampler.stop()
assert resampled_data[0] == {"epoch": 0.0, "data": 1.0}
assert resampled_data[1] == {"epoch": 0.25, "data": 1.25}

View File

@@ -17,19 +17,22 @@ def mock_socket():
@pytest.fixture
def mock_conf():
return UdpConf(ip="127.0.0.1", port=12345, packer="json")
return UdpConf(host="127.0.0.1", port=12345, packer="json")
def test_connects_to_socket(mock_socket, mock_conf) -> None:
_ = UdpPublisher(mock_conf)
pub = UdpPublisher(mock_conf)
pub.start()
# constructor was called
mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_DGRAM)
pub.stop()
def test_closes_socket(mock_socket, mock_conf) -> None:
pub = UdpPublisher(mock_conf)
del pub
pub.start()
pub.stop()
# instace was closed
mock_socket.return_value.close.assert_called()
@@ -45,8 +48,9 @@ def test_packs_and_sends_message(mock_socket, mock_conf) -> None:
mock_socket.return_value.sendto.assert_called_with(
b'{"key": "val", "intkey": 1, "floatkey": 1.0, "topic": "test"}',
(str(mock_conf.ip), mock_conf.port),
(str(mock_conf.host), mock_conf.port),
)
pub.stop()
def test_subscriber_receives_message_from_queue(mock_conf) -> None:
@@ -59,13 +63,16 @@ def test_subscriber_receives_message_from_queue(mock_conf) -> None:
topic, data = sub.receive()
assert test_topic == topic
assert test_data == data
sub.stop()
@pytest.fixture
def udp_sub(mock_conf):
sub = UdpSubscriber(mock_conf)
sub.start_loop()
sub.config.port = 12346 # explicitly set port to avoid conflicts
sub.start()
yield sub
sub.stop()
def test_sends_message_between_pub_and_sub(udp_sub, mock_conf):

View File

@@ -33,9 +33,9 @@ def start_broker():
def test_config_parses_correctly():
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
assert conf.protocol == "tcp"
assert conf.interface == "localhost"
assert conf.host == "localhost"
assert conf.publisher_port == 5555
assert conf.subscriber_port == 5556
@@ -44,13 +44,13 @@ def test_config_parses_correctly():
def test_instantiate_subscriber():
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
sub = ZmqSubscriber(conf, "test")
assert sub.config == conf
def test_instantiate_publisher():
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
pub = ZmqPublisher(conf)
assert pub.config == conf
@@ -59,7 +59,9 @@ def test_send_receive(start_broker):
print("test_send_receive")
topic = "test"
source = get_source("zmq", topic)
source.start()
sink = get_sink("zmq")
sink.start()
time.sleep(1) # this is crucial, otherwise the source might hang
for i in range(10):
message = {"m": i}

View File

@@ -33,13 +33,13 @@ def start_broker() -> Generator[Process, None, None]:
def test_instantiate_subscriber() -> None:
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
sub = ZmqAsyncSubscriber(conf, "test")
assert sub.config == conf
def test_instantiate_publisher() -> None:
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
pub = ZmqPublisher(conf)
assert pub.config == conf
@@ -48,9 +48,11 @@ def test_instantiate_publisher() -> None:
async def test_send_receive(start_broker) -> None:
print("test_send_receive")
topic = "test"
conf = ZmqConf(protocol="tcp", interface="localhost", publisher_port=5555, subscriber_port=5556)
conf = ZmqConf(protocol="tcp", host="localhost", publisher_port=5555, subscriber_port=5556)
source = ZmqAsyncSubscriber(conf, topic)
sink = ZmqAsyncPublisher(conf)
source.start()
sink.start()
time.sleep(1) # this is crucial, otherwise the source might hang
for i in range(10):
message = {"m": i}