Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: raise TypeError if config field does not match declared type #7

Merged
merged 5 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions .github/actions/setup-uv/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ runs:
using: 'composite'
steps:
- name: Install uv
uses: astral-sh/setup-uv@v4
uses: astral-sh/setup-uv@v5
with:
version: "0.5.10"
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
python-version: ${{ matrix.python-version }}
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
uv.lock
tests/test_serialization.json

WadaSNR/
.idea/
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.3
rev: v0.9.3
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -13,7 +13,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.14.1
hooks:
- id: mypy
args: [--strict]
Expand Down
117 changes: 102 additions & 15 deletions coqpit/coqpit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@

import argparse
import contextlib
import functools
import json
import operator
import typing
import warnings
from collections.abc import Callable, ItemsView, Iterable, Iterator, MutableMapping
from dataclasses import MISSING as _MISSING
from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace
from pathlib import Path
from pprint import pprint
from types import GenericAlias, UnionType
from types import UnionType
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeGuard, TypeVar, Union, overload

from typing_extensions import Self, TypeIs
Expand Down Expand Up @@ -61,6 +63,27 @@ def _is_list(field_type: FieldType) -> TypeGuard[type]:
return field_type is list or typing.get_origin(field_type) is list


def _parse_list_union(field_type: FieldType) -> type | None:
"""Check if the input type matches `_T | list[_T]`.

Args:
field_type: input type.

Returns:
bool: _T, if input type matches `_T | list[_T]`, else None
"""
if not _is_union(field_type):
return None

def _get_base_type(field_type: type) -> type:
return typing.get_args(field_type)[0] if _is_list(field_type) else field_type

args = typing.get_args(field_type)
is_list = [_is_list(arg) for arg in args]
base_types = [_get_base_type(arg) for arg in args]
return base_types[0] if len(args) == 2 and sum(is_list) == 1 and base_types[0] == base_types[1] else None # noqa: PLR2004


def _is_dict(field_type: FieldType) -> TypeGuard[type]:
"""Check if the input type is `dict`.

Expand Down Expand Up @@ -142,13 +165,12 @@ def _drop_none_type(field_type: FieldType) -> FieldType:
"""
if not _is_union(field_type):
return field_type
origin = typing.get_origin(field_type)
args = list(typing.get_args(field_type))
if type(None) in args:
args.remove(type(None))
if len(args) == 1:
return typing.cast(type, args[0])
return typing.cast("UnionType", GenericAlias(origin, args))
return typing.cast(UnionType, functools.reduce(lambda a, b: a | b, args))


def _serialize(x: Any) -> Any:
Expand Down Expand Up @@ -182,6 +204,9 @@ def _deserialize_dict(x: dict[Any, Any]) -> dict[Any, Any]:
Returns:
Dict: deserialized dictionary.
"""
if not isinstance(x, dict):
msg = f"Value `{x}` is not a dictionary"
raise TypeError(msg)
out_dict: dict[Any, Any] = {}
for k, v in x.items():
if v is None: # if {'key':None}
Expand All @@ -204,6 +229,9 @@ def _deserialize_list(x: list[Any], field_type: FieldType) -> list[Any]:
Returns:
[List]: deserialized list.
"""
if not isinstance(x, list):
msg = f"Value `{x}` does not match field type `{field_type}`"
raise TypeError(msg)
field_args = typing.get_args(field_type)
if len(field_args) == 0:
return x
Expand Down Expand Up @@ -232,7 +260,7 @@ def _deserialize_union(x: Any, field_type: UnionType) -> Any:
try:
x = _deserialize(x, arg)
break
except ValueError:
except (TypeError, ValueError):
pass
return x

Expand All @@ -252,18 +280,30 @@ def _deserialize_primitive_types(
Returns:
Union[int, float, str, bool]: deserialized value.
"""
if isinstance(x, str | bool):
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)

type_mismatch = f"Value `{x}` does not match field type `{field_type}`"
if x is None and type(None) in typing.get_args(field_type):
return None
if isinstance(x, str):
if base_type is not str:
raise TypeError(type_mismatch)
return x
if isinstance(x, bool):
if base_type is not bool:
raise TypeError(type_mismatch)
return x
if isinstance(x, int | float):
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)
if x == float("inf") or x == float("-inf"):
# if value type is inf return regardless.
return x
if base_type is not int and base_type is not float:
raise TypeError(type_mismatch)
return base_type(x)
return None
raise TypeError(type_mismatch)


def _deserialize_path(x: Any, field_type: FieldType) -> Path | None:
Expand Down Expand Up @@ -299,8 +339,8 @@ def _deserialize(x: Any, field_type: FieldType) -> Any:
return _deserialize_path(x, field_type)
if _is_primitive_type(_drop_none_type(field_type)):
return _deserialize_primitive_types(x, field_type)
msg = f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type."
raise ValueError(msg)
msg = f"Type '{type(x)}' of value '{x}' does not match declared '{field_type}' field type."
raise TypeError(msg)


CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"]
Expand Down Expand Up @@ -433,7 +473,18 @@ def deserialize(self, data: dict[str, Any]) -> Self:
if value == MISSING:
msg = f"deserialized with unknown value for {field.name} in {self.__class__.__name__}"
raise ValueError(msg)
value = _deserialize(value, field.type)
try:
value = _deserialize(value, field.type)
except TypeError:
warnings.warn(
(
f"Type mismatch in {type(self).__name__}\n"
f"Failed to deserialize field: {field.name} ({field.type}) = {value}\n"
f"Replaced it with field's default value: {_default_value(field)}"
),
stacklevel=2,
)
value = _default_value(field)
init_kwargs[field.name] = value
for k, v in init_kwargs.items():
setattr(self, k, v)
Expand Down Expand Up @@ -516,6 +567,7 @@ def _add_argument( # noqa: C901, PLR0913, PLR0912, PLR0915
not has_default
and not _is_primitive_type(_drop_none_type(field_type))
and not _is_list(_drop_none_type(field_type))
and _parse_list_union(_drop_none_type(field_type)) is None
):
# aggregate types (fields with a Coqpit subclass as type) are not
# supported without None
Expand All @@ -531,7 +583,6 @@ def _add_argument( # noqa: C901, PLR0913, PLR0912, PLR0915
type=json.loads,
)
elif _is_list(_drop_none_type(field_type)):
# TODO: We need a more clear help msg for lists.
field_args = typing.get_args(_drop_none_type(field_type))
if len(field_args) > 1 and not relaxed_parser:
msg = "Coqpit does not support multi-type hinted 'List'"
Expand Down Expand Up @@ -571,7 +622,43 @@ def _add_argument( # noqa: C901, PLR0913, PLR0912, PLR0915
fv,
field_default_factory,
field_help="",
help_prefix=f"{help_prefix} - ",
help_prefix=f"{help_prefix} (item {idx})",
arg_prefix=f"{arg_prefix}",
relaxed_parser=relaxed_parser,
)
# Fields matching: _T | list[_T] ( | None)
elif (list_field_type := _parse_list_union(_drop_none_type(field_type))) is not None:
if not has_default or field_default_factory is list:
if not _is_primitive_type(list_field_type) and not relaxed_parser:
msg = " [!] Empty list with non primitive inner type is currently not supported."
raise NotImplementedError(msg)

# If the list's default value is None, the user can specify the entire list by passing multiple parameters
parser.add_argument(
f"--{arg_prefix}",
nargs="*",
type=list_field_type,
help=f"Coqpit Field: {help_prefix}",
)
# If a default value is defined, just enable editing the values from argparse
# TODO: allow inserting a new value/obj to the end of the list.
elif not isinstance(default, list):
parser.add_argument(
f"--{arg_prefix}",
default=default,
type=list_field_type,
help=f"Coqpit Field: {help_prefix}",
)
else:
for idx, fv in enumerate(default):
parser = _add_argument(
parser,
str(idx),
list_field_type,
fv,
field_default_factory,
field_help="",
help_prefix=f"{help_prefix} (item {idx})",
arg_prefix=f"{arg_prefix}",
relaxed_parser=relaxed_parser,
)
Expand Down
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "coqpit-config"
version = "0.1.2"
version = "0.2.0"
description = "Simple (maybe too simple), light-weight config management through python data-classes."
readme = "README.md"
requires-python = ">=3.10"
Expand Down Expand Up @@ -35,16 +35,19 @@ dependencies = [
[dependency-groups]
dev = [
"coverage>=7",
"mypy>=1.13.0",
"mypy>=1.14.1",
"pre-commit>=4",
"pytest>=8",
"ruff==0.8.3",
"ruff==0.9.3",
]

[project.urls]
Repository = "https://github.com/idiap/coqui-ai-coqpit"
Issues = "https://github.com/idiap/coqui-ai-coqpit/issues"

[tool.uv]
required-version = ">=0.5.0"

[tool.hatch.build]
exclude = [
"/.github",
Expand Down Expand Up @@ -74,6 +77,7 @@ convention = "google"
"tests/**" = [
"D",
"FA100",
"FBT001",
"PLR2004",
"S101",
"SLF001",
Expand Down
14 changes: 13 additions & 1 deletion tests/test_parse_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ class SimpleConfig(Coqpit):
default_factory=lambda: [SimplerConfig(val_a=100), SimplerConfig(val_a=999)],
metadata={"help": "list of SimplerConfig"},
)
int_list: list[int] = field(default_factory=lambda: [1, 2, 3], metadata={"help": "int"})
int_list: list[int] = field(default_factory=lambda: [1, 2, 3], metadata={"help": "int list"})
str_list: list[str] = field(default_factory=lambda: ["veni", "vidi", "vici"], metadata={"help": "str"})
empty_int_list: list[int] | None = field(default=None, metadata={"help": "int list without default value"})
empty_str_list: list[str] | None = field(default=None, metadata={"help": "str list without default value"})
list_with_default_factory: list[str] = field(
default_factory=list,
metadata={"help": "str list with default factory"},
)
int_or_list: int | list[int] = field(default_factory=lambda: [1, 2, 3])
float_or_list: float | list[float] = field(default=0.1)
str_or_list: str | list[str] | None = field(default=None)
bool_or_list: bool | list[bool] | None = field(default=None)

# TODO: not supported yet
# mylist_without_default: List[SimplerConfig] = field(default=None) noqa: ERA001
Expand All @@ -51,6 +55,10 @@ def test_parse_argparse() -> None:
args.extend(["--coqpit.list_with_default_factory", "blah"])
args.extend(["--coqpit.str_list.0", "neci"])
args.extend(["--coqpit.int_list.1", "4"])
args.extend(["--coqpit.int_or_list.0", "5"])
args.extend(["--coqpit.float_or_list", "3.4"])
args.extend(["--coqpit.str_or_list", "a", "b"])
args.extend(["--coqpit.bool_or_list", "true"])

# initial config
config = SimpleConfig()
Expand All @@ -68,6 +76,10 @@ def test_parse_argparse() -> None:
str_list=["neci", "vidi", "vici"],
int_list=[1, 4, 3],
list_with_default_factory=["blah"],
int_or_list=[5, 2, 3],
float_or_list=3.4,
str_or_list=["a", "b"],
bool_or_list=[True],
)

# create and init argparser with Coqpit
Expand Down
24 changes: 0 additions & 24 deletions tests/test_serialization.json

This file was deleted.

Loading