diff --git a/dataclass_click/dataclass_click.py b/dataclass_click/dataclass_click.py index 9e24f96..4c1cd02 100644 --- a/dataclass_click/dataclass_click.py +++ b/dataclass_click/dataclass_click.py @@ -79,7 +79,7 @@ def dataclass_click( arg_class can be any class type as long as annotations can be extracted with inspect. Either the arg_class constructor must accept kwarg arguments to match annotated field names (default for a @dataclass), or a factory - function (callable object) must be passed that accepts those kwarguments and returns an object of arg_class. + function (callable object) must be passed that accepts those kwargs and returns an object of arg_class. Note that newer annotation types such as PEP 655 ``Required[]`` and ``NotRequired[]`` annotations not well-supported: ``Annotated`` must be the outermost annotation and other such annotations like ``Required`` and @@ -170,7 +170,7 @@ def _patch_click_types( for key, annotation in annotations.items(): hint: typing.Type[Any] _, hint = _strip_optional(type_hints[key]) - if "type" not in annotation.kwargs and not annotation.kwargs.get("is_flag", False): + if "type" not in annotation.kwargs and not annotation.kwargs.get("is_flag", False): if hint in complete_type_inferences: annotation.kwargs["type"] = complete_type_inferences[hint] else: @@ -285,4 +285,4 @@ def register_type_inference( These will be added to the resulting decorator in the order the attribute appears on the dataclass -Arguments are almost identical to click.option(), but do not include a name to give to the python argument""" \ No newline at end of file +Arguments are almost identical to click.option(), but do not include a name to give to the python argument""" diff --git a/pyproject.toml b/pyproject.toml index 9db9842..57b753e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,11 @@ source = [ based_on_style = "pep8" column_limit = 120 split_before_first_argument = true +each_dict_entry_on_separate_line = false [tool.mypy] -packages = ["dataclass_click"] +packages = [ + "dataclass_click", + "tests", +] check_untyped_defs = true diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index e7c36ad..adee171 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -1,11 +1,14 @@ from dataclasses import dataclass +from decimal import Decimal from typing import Annotated, Any -import io + import click import pytest from click.testing import CliRunner -from dataclass_click import dataclass_click, option +from dataclass_click import _dataclass_click, dataclass_click, option, register_type_inference + +CallRecord = tuple[tuple[Any, ...], dict[str, Any]] def quick_run(command, *args: str, expect_exit_code: int = 0) -> None: @@ -29,7 +32,7 @@ class Config: def main(*args, **kwargs): results.append((args, kwargs)) - results = [] + results: list[CallRecord] = [] quick_run(main, "--foo", "a", "--bar", "b", "--baz", "c") assert results == [((Config(foo="a"), ), {"bar": "b", "baz": "c"})] @@ -38,14 +41,14 @@ def test_types_can_be_inferred(inferrable_type, example_value_for_inferrable_typ @dataclass class Config: - foo: Annotated[inferrable_type, option("--foo")] + foo: Annotated[inferrable_type, option("--foo")] # type: ignore @click.command() @dataclass_click(Config) def main(*args, **kwargs): results.append((args, kwargs)) - results = [] + results: list[CallRecord] = [] if hasattr(example_value_for_inferrable_type, "isoformat"): str_value = example_value_for_inferrable_type.isoformat() else: @@ -104,7 +107,7 @@ class Config: def main(*args, **kwargs): results.append((args, kwargs)) - results = [] + results: list[CallRecord] = [] quick_run(main, "--foo", "10") assert results == [((Config(foo=10), ), {})] @@ -121,18 +124,20 @@ class Config: def main(*args, **kwargs): results.append((args, kwargs)) - results = [] + results: list[CallRecord] = [] quick_run(main, "--foo", "10") assert results == [((Config(baz=10), ), {})] -@pytest.mark.parametrize(["args", "expect"], [ - ({}, 2), - ({"required": True}, 2), - ({"required": False}, 0), - ({"default": 10}, 0), - ({"default": 10, "required": False}, 0) -], ids=["neither", "required-true", "required-false", "default", "both"]) +@pytest.mark.parametrize( + ["args", "expect"], [ + ({}, 2), + ({"required": True}, 2), + ({"required": False}, 0), + ({"default": 10}, 0), + ({"default": 10, "required": False}, 0), + ], + ids=["neither", "required-true", "required-false", "default", "both"]) def test_inferred_required(args: dict[str, Any], expect: int): @dataclass @@ -147,14 +152,13 @@ def main(*args, **kwargs): quick_run(main, expect_exit_code=expect) -@pytest.mark.parametrize(["args", "expect"], [ - ({}, 0), - ({"required": True}, 2), - ({"required": False}, 0), - ({"default": 10}, 0), - ({"default": 10, "required": False}, 0) -], ids=["neither", "required-true", "required-false", "default", "both"]) -def test_inferred_required(args: dict[str, Any], expect: int): +@pytest.mark.parametrize( + ["args", "expect"], [ + ({}, 0), ({"required": True}, 2), ({"required": False}, 0), ({"default": 10}, 0), + ({"default": 10, "required": False}, 0) + ], + ids=["neither", "required-true", "required-false", "default", "both"]) +def test_inferred_not_required(args: dict[str, Any], expect: int): @dataclass class Config: @@ -168,7 +172,38 @@ def main(*args, **kwargs): quick_run(main, expect_exit_code=expect) +class DecimalParamType(click.ParamType): + + def convert(self, value: Any, param: click.Parameter | None, ctx: click.Context | None) -> Decimal: + if isinstance(value, Decimal): + return value + return Decimal(value) + + +def test_patch_type_inference(monkeypatch): + monkeypatch.setattr(_dataclass_click, "_TYPE_INFERENCE", _dataclass_click._TYPE_INFERENCE.copy()) + + @dataclass + class Config: + imply_required: Annotated[Decimal, option()] + + with pytest.raises(TypeError): + + @click.command() + @dataclass_click(Config) + def main(*args, **kwargs): + pass + + register_type_inference(Decimal, DecimalParamType()) + + @click.command() + @dataclass_click(Config) + def main_2(*args, **kwargs): + pass + + def test_dataclass_can_be_used_twice(): + @dataclass class Config: imply_required: Annotated[int, option()] @@ -181,4 +216,20 @@ def main_1(*args, **kwargs): @click.command() @dataclass_click(Config) def main_2(*args, **kwargs): - pass \ No newline at end of file + pass + + +def test_keyword_name(): + + @dataclass + class Config: + bar: Annotated[int | None, option()] + + @click.command() + @dataclass_click(Config, kw_name="foo") + def main(*args, **kwargs): + results.append((args, kwargs)) + + results: list[CallRecord] = [] + quick_run(main) + assert results == [((), {"foo": Config(bar=None)})]