From e95cbda2070a73905ea008817e9a605eafee7958 Mon Sep 17 00:00:00 2001 From: Felix Weiler Date: Tue, 6 May 2025 13:37:39 +0200 Subject: [PATCH] Fix receiver aiter and anext interface. (#215) --- src/heisskleber/core/receiver.py | 33 ++++++++++++++------- tests/core/test_receiver.py | 50 ++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 11 deletions(-) create mode 100644 tests/core/test_receiver.py diff --git a/src/heisskleber/core/receiver.py b/src/heisskleber/core/receiver.py index 5d821bc..3984fdc 100644 --- a/src/heisskleber/core/receiver.py +++ b/src/heisskleber/core/receiver.py @@ -1,7 +1,6 @@ """Asynchronous data source interface.""" from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator from types import TracebackType from typing import Any, Generic @@ -56,20 +55,32 @@ class Receiver(ABC, Generic[T_co]): def __repr__(self) -> str: """A string reprensatiion of the source.""" - async def __aiter__(self) -> AsyncGenerator[tuple[T_co, dict[str, Any]], None]: - """Implement async iteration over the source's data stream. + def __aiter__(self) -> "Receiver[T_co]": + """Return self as the async iterator object. - Yields: - tuple[T_co, dict[str, Any]]: Each data item and its associated metadata - as returned by receive(). + This method enables the receiver to be used in async for loops. + + Returns: + Self as the async iterator. + """ + return self + + async def __anext__(self) -> tuple[T_co, dict[str, Any]]: + """Get the next item from the receiver's data stream. + + This method implements async iteration by calling the receive() method + and returning its result. + + Returns: + tuple[T_co, dict[str, Any]]: The next data item and its associated + metadata as returned by receive(). Raises: - Any exceptions that might occur during receive(). - + StopAsyncIteration: When the receiver has no more data. + Any other exceptions that might occur during receive(). """ - while True: - data, meta = await self.receive() - yield data, meta + data, meta = await self.receive() + return data, meta async def __aenter__(self) -> "Receiver[T_co]": """Initialize the source for use in an async context manager. diff --git a/tests/core/test_receiver.py b/tests/core/test_receiver.py new file mode 100644 index 0000000..b8e7aac --- /dev/null +++ b/tests/core/test_receiver.py @@ -0,0 +1,50 @@ +from typing import Any + +import pytest + +from heisskleber import Receiver + + +class MockReceiver(Receiver): + def __init__(self) -> None: + self.n_called = 0 + + async def receive(self) -> tuple[bool, dict[str, Any]]: + self.n_called += 1 + return True, {"msg": "Called MockReceiver", "count": self.n_called} + + async def start(self) -> None: + return + + async def stop(self) -> None: + return + + def __repr__(self) -> str: + return "MockReceiver" + + +@pytest.mark.asyncio +async def test_mock_receiver_can_be_iterated_over() -> None: + count = 1 + + async for data, meta in MockReceiver(): + assert data + assert "msg" in meta + assert meta["count"] == count + count += 1 + if count == 3: + break + + +@pytest.mark.asyncio +async def test_mock_receiver_call_anext() -> None: + receiver = MockReceiver() + + data, meta = await anext(receiver) + + assert data + assert meta["count"] == 1 + + data, meta = await anext(receiver) + + assert meta["count"] == 2