diff --git a/src/heisskleber/core/config.py b/src/heisskleber/core/config.py index 04ac42c..88257a0 100644 --- a/src/heisskleber/core/config.py +++ b/src/heisskleber/core/config.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass, fields from pathlib import Path -from typing import Any, TextIO, TypeVar, Union +from typing import Any, Literal, TextIO, TypeVar, Union, get_args, get_origin import yaml # type: ignore[import-untyped] @@ -48,6 +48,16 @@ def _parser(path: Path) -> dict[str, Any]: raise ValueError +def _check_type(value: Any, expected_type: Any) -> bool: + origin = get_origin(expected_type) + if origin is Literal: # Explicitly check literal + if value not in get_args(expected_type): + logger.exception("%s is not part of %s", value, get_args(expected_type)) + raise TypeError + return True + return isinstance(value, expected_type) + + @dataclass class BaseConf: """Default configuration class for generic configuration info.""" @@ -58,13 +68,11 @@ class BaseConf: value = getattr(self, field.name) if value is None: # Allow optional fields continue - if not isinstance(value, field.type): # Failed field comparison - raise TypeError - if ( # Failed Union comparison - hasattr(field.type, "__origin__") - and field.type.__origin__ is Union - and not any(isinstance(value, t) for t in field.type.__args__) - ): + if hasattr(field.type, "__origin__") and field.type.__origin__ is Union: + if not any(_check_type(value, t) for t in field.type.__args__): + raise TypeError + continue + if not _check_type(value, field.type): # Failed field comparison raise TypeError @classmethod diff --git a/tests/core/test_config.py b/tests/core/test_config.py index e5041a6..b88451b 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -1,8 +1,11 @@ +import logging from dataclasses import dataclass +from typing import Literal import pytest from heisskleber.core import BaseConf +from heisskleber.core.config import _check_type @dataclass @@ -53,3 +56,29 @@ def test_conf_from_file() -> None: assert test_conf.name == "Frodo" assert test_conf.age == 30 assert test_conf.speed == 0.5 + + +@dataclass +class ConfigWithLiteral(BaseConf): + direction: Literal["U", "D"] = "U" + + +def test_parses_literal() -> None: + test_dict = {"direction": "U"} + expected = ConfigWithLiteral(direction="U") + + config = ConfigWithLiteral.from_dict(test_dict) + + assert config == expected + + +def test_logs_literal_error(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.ERROR) + + expected = Literal[1, 2, 3] + + with pytest.raises(TypeError): + _check_type(4, expected) + + assert len(caplog.records) == 1 + assert "4 is not part of (1, 2, 3)" in caplog.text diff --git a/tests/serial/test_serial_config.py b/tests/serial/test_serial_config.py new file mode 100644 index 0000000..3fb6b6e --- /dev/null +++ b/tests/serial/test_serial_config.py @@ -0,0 +1,9 @@ +from heisskleber.serial import SerialConf + + +def test_serial_config() -> None: + config_dict = {"port": "/test/serial", "baudrate": 5000, "parity": "N", "stopbits": 1} + + config = SerialConf.from_dict(config_dict) + + assert config == SerialConf(port="/test/serial", baudrate=5000, parity="N", stopbits=1)