Skip to content

Commit

Permalink
Merge pull request #7 from idiap/union
Browse files Browse the repository at this point in the history
feat!: raise TypeError if config field does not match declared type
  • Loading branch information
eginhard authored Feb 4, 2025
2 parents b41b0e1 + 60f2d26 commit 0b45b59
Show file tree
Hide file tree
Showing 10 changed files with 600 additions and 65 deletions.
5 changes: 1 addition & 4 deletions .github/actions/setup-uv/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ runs:
using: 'composite'
steps:
- name: Install uv
uses: astral-sh/setup-uv@v4
uses: astral-sh/setup-uv@v5
with:
version: "0.5.10"
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
python-version: ${{ matrix.python-version }}
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
uv.lock
tests/test_serialization.json

WadaSNR/
.idea/
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.3
rev: v0.9.3
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -13,7 +13,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.14.1
hooks:
- id: mypy
args: [--strict]
Expand Down
117 changes: 102 additions & 15 deletions coqpit/coqpit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@

import argparse
import contextlib
import functools
import json
import operator
import typing
import warnings
from collections.abc import Callable, ItemsView, Iterable, Iterator, MutableMapping
from dataclasses import MISSING as _MISSING
from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace
from pathlib import Path
from pprint import pprint
from types import GenericAlias, UnionType
from types import UnionType
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeGuard, TypeVar, Union, overload

from typing_extensions import Self, TypeIs
Expand Down Expand Up @@ -61,6 +63,27 @@ def _is_list(field_type: FieldType) -> TypeGuard[type]:
return field_type is list or typing.get_origin(field_type) is list


def _parse_list_union(field_type: FieldType) -> type | None:
    """Return `_T` if the input type has the form `_T | list[_T]`.

    Args:
        field_type: input type.

    Returns:
        type | None: `_T` if the input type is a union of exactly two members,
            `_T` and `list[_T]` (in either order), else None. Unions that
            contain an unparameterized `list` never match.
    """
    if not _is_union(field_type):
        return None

    args = typing.get_args(field_type)
    if len(args) != 2:  # noqa: PLR2004
        return None

    # Exactly one of the two members must be a list type.
    list_args = [arg for arg in args if _is_list(arg)]
    if len(list_args) != 1:
        return None
    list_arg = list_args[0]

    # A bare `list` has no item type to compare against; indexing its
    # (empty) type arguments would raise IndexError.
    item_types = typing.get_args(list_arg)
    if not item_types:
        return None

    # Reduce both members to their base type, keeping the declared order.
    base_types = [item_types[0] if arg is list_arg else arg for arg in args]
    if base_types[0] != base_types[1]:
        return None
    return typing.cast(type, base_types[0])


def _is_dict(field_type: FieldType) -> TypeGuard[type]:
"""Check if the input type is `dict`.
Expand Down Expand Up @@ -142,13 +165,12 @@ def _drop_none_type(field_type: FieldType) -> FieldType:
"""
if not _is_union(field_type):
return field_type
origin = typing.get_origin(field_type)
args = list(typing.get_args(field_type))
if type(None) in args:
args.remove(type(None))
if len(args) == 1:
return typing.cast(type, args[0])
return typing.cast("UnionType", GenericAlias(origin, args))
return typing.cast(UnionType, functools.reduce(lambda a, b: a | b, args))


def _serialize(x: Any) -> Any:
Expand Down Expand Up @@ -182,6 +204,9 @@ def _deserialize_dict(x: dict[Any, Any]) -> dict[Any, Any]:
Returns:
Dict: deserialized dictionary.
"""
if not isinstance(x, dict):
msg = f"Value `{x}` is not a dictionary"
raise TypeError(msg)
out_dict: dict[Any, Any] = {}
for k, v in x.items():
if v is None: # if {'key':None}
Expand All @@ -204,6 +229,9 @@ def _deserialize_list(x: list[Any], field_type: FieldType) -> list[Any]:
Returns:
[List]: deserialized list.
"""
if not isinstance(x, list):
msg = f"Value `{x}` does not match field type `{field_type}`"
raise TypeError(msg)
field_args = typing.get_args(field_type)
if len(field_args) == 0:
return x
Expand Down Expand Up @@ -232,7 +260,7 @@ def _deserialize_union(x: Any, field_type: UnionType) -> Any:
try:
x = _deserialize(x, arg)
break
except ValueError:
except (TypeError, ValueError):
pass
return x

Expand All @@ -252,18 +280,30 @@ def _deserialize_primitive_types(
Returns:
Union[int, float, str, bool]: deserialized value.
"""
if isinstance(x, str | bool):
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)

type_mismatch = f"Value `{x}` does not match field type `{field_type}`"
if x is None and type(None) in typing.get_args(field_type):
return None
if isinstance(x, str):
if base_type is not str:
raise TypeError(type_mismatch)
return x
if isinstance(x, bool):
if base_type is not bool:
raise TypeError(type_mismatch)
return x
if isinstance(x, int | float):
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)
if x == float("inf") or x == float("-inf"):
# if value type is inf return regardless.
return x
if base_type is not int and base_type is not float:
raise TypeError(type_mismatch)
return base_type(x)
return None
raise TypeError(type_mismatch)


def _deserialize_path(x: Any, field_type: FieldType) -> Path | None:
Expand Down Expand Up @@ -299,8 +339,8 @@ def _deserialize(x: Any, field_type: FieldType) -> Any:
return _deserialize_path(x, field_type)
if _is_primitive_type(_drop_none_type(field_type)):
return _deserialize_primitive_types(x, field_type)
msg = f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type."
raise ValueError(msg)
msg = f"Type '{type(x)}' of value '{x}' does not match declared '{field_type}' field type."
raise TypeError(msg)


CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"]
Expand Down Expand Up @@ -433,7 +473,18 @@ def deserialize(self, data: dict[str, Any]) -> Self:
if value == MISSING:
msg = f"deserialized with unknown value for {field.name} in {self.__class__.__name__}"
raise ValueError(msg)
value = _deserialize(value, field.type)
try:
value = _deserialize(value, field.type)
except TypeError:
warnings.warn(
(
f"Type mismatch in {type(self).__name__}\n"
f"Failed to deserialize field: {field.name} ({field.type}) = {value}\n"
f"Replaced it with field's default value: {_default_value(field)}"
),
stacklevel=2,
)
value = _default_value(field)
init_kwargs[field.name] = value
for k, v in init_kwargs.items():
setattr(self, k, v)
Expand Down Expand Up @@ -516,6 +567,7 @@ def _add_argument( # noqa: C901, PLR0913, PLR0912, PLR0915
not has_default
and not _is_primitive_type(_drop_none_type(field_type))
and not _is_list(_drop_none_type(field_type))
and _parse_list_union(_drop_none_type(field_type)) is None
):
# aggregate types (fields with a Coqpit subclass as type) are not
# supported without None
Expand All @@ -531,7 +583,6 @@ def _add_argument( # noqa: C901, PLR0913, PLR0912, PLR0915
type=json.loads,
)
elif _is_list(_drop_none_type(field_type)):
# TODO: We need a more clear help msg for lists.
field_args = typing.get_args(_drop_none_type(field_type))
if len(field_args) > 1 and not relaxed_parser:
msg = "Coqpit does not support multi-type hinted 'List'"
Expand Down Expand Up @@ -571,7 +622,43 @@ def _add_argument( # noqa: C901, PLR0913, PLR0912, PLR0915
fv,
field_default_factory,
field_help="",
help_prefix=f"{help_prefix} - ",
help_prefix=f"{help_prefix} (item {idx})",
arg_prefix=f"{arg_prefix}",
relaxed_parser=relaxed_parser,
)
# Fields matching: _T | list[_T] ( | None)
elif (list_field_type := _parse_list_union(_drop_none_type(field_type))) is not None:
if not has_default or field_default_factory is list:
if not _is_primitive_type(list_field_type) and not relaxed_parser:
msg = " [!] Empty list with non primitive inner type is currently not supported."
raise NotImplementedError(msg)

# If the list's default value is None, the user can specify the entire list by passing multiple parameters
parser.add_argument(
f"--{arg_prefix}",
nargs="*",
type=list_field_type,
help=f"Coqpit Field: {help_prefix}",
)
# If a default value is defined, just enable editing the values from argparse
# TODO: allow inserting a new value/obj to the end of the list.
elif not isinstance(default, list):
parser.add_argument(
f"--{arg_prefix}",
default=default,
type=list_field_type,
help=f"Coqpit Field: {help_prefix}",
)
else:
for idx, fv in enumerate(default):
parser = _add_argument(
parser,
str(idx),
list_field_type,
fv,
field_default_factory,
field_help="",
help_prefix=f"{help_prefix} (item {idx})",
arg_prefix=f"{arg_prefix}",
relaxed_parser=relaxed_parser,
)
Expand Down
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "coqpit-config"
version = "0.1.2"
version = "0.2.0"
description = "Simple (maybe too simple), light-weight config management through python data-classes."
readme = "README.md"
requires-python = ">=3.10"
Expand Down Expand Up @@ -35,16 +35,19 @@ dependencies = [
[dependency-groups]
dev = [
"coverage>=7",
"mypy>=1.13.0",
"mypy>=1.14.1",
"pre-commit>=4",
"pytest>=8",
"ruff==0.8.3",
"ruff==0.9.3",
]

[project.urls]
Repository = "https://github.com/idiap/coqui-ai-coqpit"
Issues = "https://github.com/idiap/coqui-ai-coqpit/issues"

[tool.uv]
required-version = ">=0.5.0"

[tool.hatch.build]
exclude = [
"/.github",
Expand Down Expand Up @@ -74,6 +77,7 @@ convention = "google"
"tests/**" = [
"D",
"FA100",
"FBT001",
"PLR2004",
"S101",
"SLF001",
Expand Down
14 changes: 13 additions & 1 deletion tests/test_parse_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ class SimpleConfig(Coqpit):
default_factory=lambda: [SimplerConfig(val_a=100), SimplerConfig(val_a=999)],
metadata={"help": "list of SimplerConfig"},
)
int_list: list[int] = field(default_factory=lambda: [1, 2, 3], metadata={"help": "int"})
int_list: list[int] = field(default_factory=lambda: [1, 2, 3], metadata={"help": "int list"})
str_list: list[str] = field(default_factory=lambda: ["veni", "vidi", "vici"], metadata={"help": "str"})
empty_int_list: list[int] | None = field(default=None, metadata={"help": "int list without default value"})
empty_str_list: list[str] | None = field(default=None, metadata={"help": "str list without default value"})
list_with_default_factory: list[str] = field(
default_factory=list,
metadata={"help": "str list with default factory"},
)
int_or_list: int | list[int] = field(default_factory=lambda: [1, 2, 3])
float_or_list: float | list[float] = field(default=0.1)
str_or_list: str | list[str] | None = field(default=None)
bool_or_list: bool | list[bool] | None = field(default=None)

# TODO: not supported yet
# mylist_without_default: List[SimplerConfig] = field(default=None) noqa: ERA001
Expand All @@ -51,6 +55,10 @@ def test_parse_argparse() -> None:
args.extend(["--coqpit.list_with_default_factory", "blah"])
args.extend(["--coqpit.str_list.0", "neci"])
args.extend(["--coqpit.int_list.1", "4"])
args.extend(["--coqpit.int_or_list.0", "5"])
args.extend(["--coqpit.float_or_list", "3.4"])
args.extend(["--coqpit.str_or_list", "a", "b"])
args.extend(["--coqpit.bool_or_list", "true"])

# initial config
config = SimpleConfig()
Expand All @@ -68,6 +76,10 @@ def test_parse_argparse() -> None:
str_list=["neci", "vidi", "vici"],
int_list=[1, 4, 3],
list_with_default_factory=["blah"],
int_or_list=[5, 2, 3],
float_or_list=3.4,
str_or_list=["a", "b"],
bool_or_list=[True],
)

# create and init argparser with Coqpit
Expand Down
24 changes: 0 additions & 24 deletions tests/test_serialization.json

This file was deleted.

Loading

0 comments on commit 0b45b59

Please sign in to comment.