Fix receiver aiter and anext interface. (#215)

This commit is contained in:
Felix Weiler
2025-05-06 13:37:39 +02:00
committed by GitHub
parent 317a09972f
commit e95cbda207
2 changed files with 72 additions and 11 deletions

View File

@@ -1,7 +1,6 @@
"""Asynchronous data source interface.""" """Asynchronous data source interface."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from types import TracebackType from types import TracebackType
from typing import Any, Generic from typing import Any, Generic
@@ -56,20 +55,32 @@ class Receiver(ABC, Generic[T_co]):
def __repr__(self) -> str: def __repr__(self) -> str:
"""A string reprensatiion of the source.""" """A string reprensatiion of the source."""
async def __aiter__(self) -> AsyncGenerator[tuple[T_co, dict[str, Any]], None]: def __aiter__(self) -> "Receiver[T_co]":
"""Implement async iteration over the source's data stream. """Return self as the async iterator object.
Yields: This method enables the receiver to be used in async for loops.
tuple[T_co, dict[str, Any]]: Each data item and its associated metadata
as returned by receive(). Returns:
Self as the async iterator.
"""
return self
async def __anext__(self) -> tuple[T_co, dict[str, Any]]:
"""Get the next item from the receiver's data stream.
This method implements async iteration by calling the receive() method
and returning its result.
Returns:
tuple[T_co, dict[str, Any]]: The next data item and its associated
metadata as returned by receive().
Raises: Raises:
Any exceptions that might occur during receive(). StopAsyncIteration: When the receiver has no more data.
Any other exceptions that might occur during receive().
""" """
while True: data, meta = await self.receive()
data, meta = await self.receive() return data, meta
yield data, meta
async def __aenter__(self) -> "Receiver[T_co]": async def __aenter__(self) -> "Receiver[T_co]":
"""Initialize the source for use in an async context manager. """Initialize the source for use in an async context manager.

View File

@@ -0,0 +1,50 @@
from typing import Any
import pytest
from heisskleber import Receiver
class MockReceiver(Receiver):
def __init__(self) -> None:
self.n_called = 0
async def receive(self) -> tuple[bool, dict[str, Any]]:
self.n_called += 1
return True, {"msg": "Called MockReceiver", "count": self.n_called}
async def start(self) -> None:
return
async def stop(self) -> None:
return
def __repr__(self) -> str:
return "MockReceiver"
@pytest.mark.asyncio
async def test_mock_receiver_can_be_iterated_over() -> None:
count = 1
async for data, meta in MockReceiver():
assert data
assert "msg" in meta
assert meta["count"] == count
count += 1
if count == 3:
break
@pytest.mark.asyncio
async def test_mock_receiver_call_anext() -> None:
receiver = MockReceiver()
data, meta = await anext(receiver)
assert data
assert meta["count"] == 1
data, meta = await anext(receiver)
assert meta["count"] == 2