Correctly check literal values in configuration. (#205)

This commit is contained in:
Felix Weiler
2025-02-15 16:29:40 +01:00
committed by GitHub
parent 50e2fda5f1
commit 7971d09147
3 changed files with 54 additions and 8 deletions

View File

@@ -3,7 +3,7 @@
import logging
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, TextIO, TypeVar, Union
from typing import Any, Literal, TextIO, TypeVar, Union, get_args, get_origin
import yaml # type: ignore[import-untyped]
@@ -48,6 +48,16 @@ def _parser(path: Path) -> dict[str, Any]:
raise ValueError
def _check_type(value: Any, expected_type: Any) -> bool:
origin = get_origin(expected_type)
if origin is Literal: # Explicitly check literal
if value not in get_args(expected_type):
logger.exception("%s is not part of %s", value, get_args(expected_type))
raise TypeError
return True
return isinstance(value, expected_type)
@dataclass
class BaseConf:
"""Default configuration class for generic configuration info."""
@@ -58,13 +68,11 @@ class BaseConf:
value = getattr(self, field.name)
if value is None: # Allow optional fields
continue
if not isinstance(value, field.type): # Failed field comparison
raise TypeError
if ( # Failed Union comparison
hasattr(field.type, "__origin__")
and field.type.__origin__ is Union
and not any(isinstance(value, t) for t in field.type.__args__)
):
if hasattr(field.type, "__origin__") and field.type.__origin__ is Union:
if not any(_check_type(value, t) for t in field.type.__args__):
raise TypeError
continue
if not _check_type(value, field.type): # Failed field comparison
raise TypeError
@classmethod

View File

@@ -1,8 +1,11 @@
import logging
from dataclasses import dataclass
from typing import Literal
import pytest
from heisskleber.core import BaseConf
from heisskleber.core.config import _check_type
@dataclass
@@ -53,3 +56,29 @@ def test_conf_from_file() -> None:
assert test_conf.name == "Frodo"
assert test_conf.age == 30
assert test_conf.speed == 0.5
@dataclass
class ConfigWithLiteral(BaseConf):
direction: Literal["U", "D"] = "U"
def test_parses_literal() -> None:
test_dict = {"direction": "U"}
expected = ConfigWithLiteral(direction="U")
config = ConfigWithLiteral.from_dict(test_dict)
assert config == expected
def test_logs_literal_error(caplog: pytest.LogCaptureFixture) -> None:
caplog.set_level(logging.ERROR)
expected = Literal[1, 2, 3]
with pytest.raises(TypeError):
_check_type(4, expected)
assert len(caplog.records) == 1
assert "4 is not part of (1, 2, 3)" in caplog.text

View File

@@ -0,0 +1,9 @@
from heisskleber.serial import SerialConf
def test_serial_config() -> None:
config_dict = {"port": "/test/serial", "baudrate": 5000, "parity": "N", "stopbits": 1}
config = SerialConf.from_dict(config_dict)
assert config == SerialConf(port="/test/serial", baudrate=5000, parity="N", stopbits=1)