Skip to content

Commit

Permalink
feat!: raise TypeError if config field does not match declared type
Browse files Browse the repository at this point in the history
This makes parsing stricter and could result in errors in some existing
configs. However, it allows for more precise deserialization, especially in the
case of union types.
  • Loading branch information
eginhard committed Jan 11, 2025
1 parent b41b0e1 commit b549b44
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 24 deletions.
50 changes: 40 additions & 10 deletions coqpit/coqpit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import operator
import typing
import warnings
from collections.abc import Callable, ItemsView, Iterable, Iterator, MutableMapping
from dataclasses import MISSING as _MISSING
from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace
Expand Down Expand Up @@ -182,6 +183,9 @@ def _deserialize_dict(x: dict[Any, Any]) -> dict[Any, Any]:
Returns:
Dict: deserialized dictionary.
"""
if not isinstance(x, dict):
msg = f"Value `{x}` is not a dictionary"
raise TypeError(msg)
out_dict: dict[Any, Any] = {}
for k, v in x.items():
if v is None: # if {'key':None}
Expand All @@ -204,6 +208,9 @@ def _deserialize_list(x: list[Any], field_type: FieldType) -> list[Any]:
Returns:
[List]: deserialized list.
"""
if not isinstance(x, list):
msg = f"Value `{x}` does not match field type `{field_type}`"
raise TypeError(msg)
field_args = typing.get_args(field_type)
if len(field_args) == 0:
return x
Expand Down Expand Up @@ -232,7 +239,7 @@ def _deserialize_union(x: Any, field_type: UnionType) -> Any:
try:
x = _deserialize(x, arg)
break
except ValueError:
except (TypeError, ValueError):
pass
return x

Expand All @@ -252,18 +259,30 @@ def _deserialize_primitive_types(
Returns:
Union[int, float, str, bool]: deserialized value.
"""
if isinstance(x, str | bool):
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)

type_mismatch = f"Value `{x}` does not match field type `{field_type}`"
if x is None and type(None) in typing.get_args(field_type):
return None
if isinstance(x, str):
if base_type is not str:
raise TypeError(type_mismatch)
return x
if isinstance(x, bool):
if base_type is not bool:
raise TypeError(type_mismatch)
return x
if isinstance(x, int | float):
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)
if x == float("inf") or x == float("-inf"):
# if value type is inf return regardless.
return x
if base_type is not int and base_type is not float:
raise TypeError(type_mismatch)
return base_type(x)
return None
raise TypeError(type_mismatch)


def _deserialize_path(x: Any, field_type: FieldType) -> Path | None:
Expand Down Expand Up @@ -299,8 +318,8 @@ def _deserialize(x: Any, field_type: FieldType) -> Any:
return _deserialize_path(x, field_type)
if _is_primitive_type(_drop_none_type(field_type)):
return _deserialize_primitive_types(x, field_type)
msg = f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type."
raise ValueError(msg)
msg = f"Type '{type(x)}' of value '{x}' does not match declared '{field_type}' field type."
raise TypeError(msg)


CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"]
Expand Down Expand Up @@ -433,7 +452,18 @@ def deserialize(self, data: dict[str, Any]) -> Self:
if value == MISSING:
msg = f"deserialized with unknown value for {field.name} in {self.__class__.__name__}"
raise ValueError(msg)
value = _deserialize(value, field.type)
try:
value = _deserialize(value, field.type)
except TypeError:
warnings.warn(
(
f"Type mismatch in {type(self).__name__}\n"
f"Failed to deserialize field: {field.name} ({field.type}) = {value}\n"
f"Replaced it with field's default value: {_default_value(field)}"
),
stacklevel=2,
)
value = _default_value(field)
init_kwargs[field.name] = value
for k, v in init_kwargs.items():
setattr(self, k, v)
Expand Down
80 changes: 66 additions & 14 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass, field
from pathlib import Path
from types import UnionType
from typing import Any

import pytest

from coqpit.coqpit import Coqpit, _deserialize_list, _deserialize_primitive_types
from coqpit.coqpit import Coqpit, FieldType, _deserialize_list, _deserialize_primitive_types, _deserialize_union


@dataclass
Expand Down Expand Up @@ -65,36 +67,86 @@ def test_serialization() -> None:
def test_deserialize_list() -> None:
assert _deserialize_list([1, 2, 3], list) == [1, 2, 3]
assert _deserialize_list([1, 2, 3], list[int]) == [1, 2, 3]
assert _deserialize_list([[1, 2, 3]], list[list[int]]) == [[1, 2, 3]]
assert _deserialize_list([1.0, 2.0, 3.0], list[float]) == [1.0, 2.0, 3.0]
assert _deserialize_list([1, 2, 3], list[float]) == [1.0, 2.0, 3.0]
assert _deserialize_list([1, 2, 3], list[str]) == ["1", "2", "3"]
assert _deserialize_list(["1", "2", "3"], list[str]) == ["1", "2", "3"]

with pytest.raises(TypeError, match="does not match field type"):
_deserialize_list([1, 2, 3], list[list[int]])

def test_deserialize_primitive_type() -> None:
cases = (

@pytest.mark.parametrize(
("value", "field_type", "expected"),
[
(True, bool, True),
(False, bool, False),
("a", str, "a"),
("3", str, "3"),
(3, int, 3),
(3, float, 3.0),
(3, str, "3"),
(3.0, str, "3.0"),
(3, bool, True),
("a", str | None, "a"),
("3", str | None, "3"),
(3, int | None, 3),
(3, float | None, 3.0),
(None, str | None, None),
(None, int | None, None),
(None, float | None, None),
(None, str | None, None),
(None, bool | None, None),
(float("inf"), float, float("inf")),
(float("inf"), int, float("inf")),
(float("-inf"), float, float("-inf")),
(float("-inf"), int, float("-inf")),
)
for value, field_type, expected in cases:
assert _deserialize_primitive_types(value, field_type) == expected

with pytest.raises(TypeError):
_deserialize_primitive_types(3, Coqpit)
],
)
def test_deserialize_primitive_type(
value: str | bool | float | None,
field_type: FieldType,
expected: str | bool | float | None,
) -> None:
assert _deserialize_primitive_types(value, field_type) == expected


# Negative cases for `_deserialize_primitive_types`: every (value, field_type)
# pair listed below is expected to raise a TypeError whose message matches
# "does not match field type".
@pytest.mark.parametrize(
("value", "field_type"),
[
# Numeric values paired with str / bool field types.
(3, str),
(3, str | None),
(3.0, str),
(3, bool),
# String values paired with int / float / bool field types, including
# strings that merely look like a number or a boolean.
("1", int),
("2.0", float),
("True", bool),
("True", bool | None),
("", bool | None),
# List values paired with primitive field types.
([1, 2], str),
([1, 2, 3], int),
],
)
def test_deserialize_primitive_type_mismatch(
value: str | bool | float | None,
field_type: FieldType,
) -> None:
# The `match` pattern pins the error message as well as the exception type.
with pytest.raises(TypeError, match="does not match field type"):
_deserialize_primitive_types(value, field_type)


# Positive cases for `_deserialize_union`: each value is expected to come back
# equal to `expected`, which in every case listed here is the input value itself.
@pytest.mark.parametrize(
("value", "field_type", "expected"),
[
# Primitive unions, exercised with both orderings of the union members.
("a", int | str, "a"),
("a", str | int, "a"),
(1, int | str, 1),
(1, str | int, 1),
# Unions mixing primitive members with list members.
(1, str | int | list[int], 1),
([1, 2], str | int | list[int], [1, 2]),
([1, 2], list[int] | int | str, [1, 2]),
([1, 2], dict | list, [1, 2]),
# Unions of flat and nested list types, again in both member orders.
(["a", "b"], list[str] | list[list[str]], ["a", "b"]),
(["a", "b"], list[list[str]] | list[str], ["a", "b"]),
([["a", "b"]], list[str] | list[list[str]], [["a", "b"]]),
([["a", "b"]], list[list[str]] | list[str], [["a", "b"]]),
],
)
def test_deserialize_union(value: Any, field_type: UnionType, expected: Any) -> None:
assert _deserialize_union(value, field_type) == expected

0 comments on commit b549b44

Please sign in to comment.