mirror of
https://github.com/OMGeeky/flucto-heisskleber.git
synced 2025-12-26 16:07:50 +01:00
Fix receiver aiter and anext interface. (#215)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
"""Asynchronous data source interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic
|
||||
|
||||
@@ -56,20 +55,32 @@ class Receiver(ABC, Generic[T_co]):
|
||||
def __repr__(self) -> str:
|
||||
"""A string reprensatiion of the source."""
|
||||
|
||||
async def __aiter__(self) -> AsyncGenerator[tuple[T_co, dict[str, Any]], None]:
|
||||
"""Implement async iteration over the source's data stream.
|
||||
def __aiter__(self) -> "Receiver[T_co]":
|
||||
"""Return self as the async iterator object.
|
||||
|
||||
Yields:
|
||||
tuple[T_co, dict[str, Any]]: Each data item and its associated metadata
|
||||
as returned by receive().
|
||||
This method enables the receiver to be used in async for loops.
|
||||
|
||||
Returns:
|
||||
Self as the async iterator.
|
||||
"""
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> tuple[T_co, dict[str, Any]]:
|
||||
"""Get the next item from the receiver's data stream.
|
||||
|
||||
This method implements async iteration by calling the receive() method
|
||||
and returning its result.
|
||||
|
||||
Returns:
|
||||
tuple[T_co, dict[str, Any]]: The next data item and its associated
|
||||
metadata as returned by receive().
|
||||
|
||||
Raises:
|
||||
Any exceptions that might occur during receive().
|
||||
|
||||
StopAsyncIteration: When the receiver has no more data.
|
||||
Any other exceptions that might occur during receive().
|
||||
"""
|
||||
while True:
|
||||
data, meta = await self.receive()
|
||||
yield data, meta
|
||||
data, meta = await self.receive()
|
||||
return data, meta
|
||||
|
||||
async def __aenter__(self) -> "Receiver[T_co]":
|
||||
"""Initialize the source for use in an async context manager.
|
||||
|
||||
50
tests/core/test_receiver.py
Normal file
50
tests/core/test_receiver.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from heisskleber import Receiver
|
||||
|
||||
|
||||
class MockReceiver(Receiver):
|
||||
def __init__(self) -> None:
|
||||
self.n_called = 0
|
||||
|
||||
async def receive(self) -> tuple[bool, dict[str, Any]]:
|
||||
self.n_called += 1
|
||||
return True, {"msg": "Called MockReceiver", "count": self.n_called}
|
||||
|
||||
async def start(self) -> None:
|
||||
return
|
||||
|
||||
async def stop(self) -> None:
|
||||
return
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "MockReceiver"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_receiver_can_be_iterated_over() -> None:
|
||||
count = 1
|
||||
|
||||
async for data, meta in MockReceiver():
|
||||
assert data
|
||||
assert "msg" in meta
|
||||
assert meta["count"] == count
|
||||
count += 1
|
||||
if count == 3:
|
||||
break
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_receiver_call_anext() -> None:
|
||||
receiver = MockReceiver()
|
||||
|
||||
data, meta = await anext(receiver)
|
||||
|
||||
assert data
|
||||
assert meta["count"] == 1
|
||||
|
||||
data, meta = await anext(receiver)
|
||||
|
||||
assert meta["count"] == 2
|
||||
Reference in New Issue
Block a user