mirror of
https://github.com/OMGeeky/flucto-heisskleber.git
synced 2025-12-26 16:07:50 +01:00
Correctly check literal values in configuration. (#205)
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import logging
|
||||
from dataclasses import dataclass, fields
|
||||
from pathlib import Path
|
||||
from typing import Any, TextIO, TypeVar, Union
|
||||
from typing import Any, Literal, TextIO, TypeVar, Union, get_args, get_origin
|
||||
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
@@ -48,6 +48,16 @@ def _parser(path: Path) -> dict[str, Any]:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def _check_type(value: Any, expected_type: Any) -> bool:
|
||||
origin = get_origin(expected_type)
|
||||
if origin is Literal: # Explicitly check literal
|
||||
if value not in get_args(expected_type):
|
||||
logger.exception("%s is not part of %s", value, get_args(expected_type))
|
||||
raise TypeError
|
||||
return True
|
||||
return isinstance(value, expected_type)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseConf:
|
||||
"""Default configuration class for generic configuration info."""
|
||||
@@ -58,13 +68,11 @@ class BaseConf:
|
||||
value = getattr(self, field.name)
|
||||
if value is None: # Allow optional fields
|
||||
continue
|
||||
if not isinstance(value, field.type): # Failed field comparison
|
||||
raise TypeError
|
||||
if ( # Failed Union comparison
|
||||
hasattr(field.type, "__origin__")
|
||||
and field.type.__origin__ is Union
|
||||
and not any(isinstance(value, t) for t in field.type.__args__)
|
||||
):
|
||||
if hasattr(field.type, "__origin__") and field.type.__origin__ is Union:
|
||||
if not any(_check_type(value, t) for t in field.type.__args__):
|
||||
raise TypeError
|
||||
continue
|
||||
if not _check_type(value, field.type): # Failed field comparison
|
||||
raise TypeError
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from heisskleber.core import BaseConf
|
||||
from heisskleber.core.config import _check_type
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,3 +56,29 @@ def test_conf_from_file() -> None:
|
||||
assert test_conf.name == "Frodo"
|
||||
assert test_conf.age == 30
|
||||
assert test_conf.speed == 0.5
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigWithLiteral(BaseConf):
|
||||
direction: Literal["U", "D"] = "U"
|
||||
|
||||
|
||||
def test_parses_literal() -> None:
|
||||
test_dict = {"direction": "U"}
|
||||
expected = ConfigWithLiteral(direction="U")
|
||||
|
||||
config = ConfigWithLiteral.from_dict(test_dict)
|
||||
|
||||
assert config == expected
|
||||
|
||||
|
||||
def test_logs_literal_error(caplog: pytest.LogCaptureFixture) -> None:
|
||||
caplog.set_level(logging.ERROR)
|
||||
|
||||
expected = Literal[1, 2, 3]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
_check_type(4, expected)
|
||||
|
||||
assert len(caplog.records) == 1
|
||||
assert "4 is not part of (1, 2, 3)" in caplog.text
|
||||
|
||||
9
tests/serial/test_serial_config.py
Normal file
9
tests/serial/test_serial_config.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from heisskleber.serial import SerialConf
|
||||
|
||||
|
||||
def test_serial_config() -> None:
|
||||
config_dict = {"port": "/test/serial", "baudrate": 5000, "parity": "N", "stopbits": 1}
|
||||
|
||||
config = SerialConf.from_dict(config_dict)
|
||||
|
||||
assert config == SerialConf(port="/test/serial", baudrate=5000, parity="N", stopbits=1)
|
||||
Reference in New Issue
Block a user