Skip to content

Commit

Permalink
Added extra unit tests and type hinting (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
couling authored Feb 13, 2024
1 parent 18c6668 commit 562e548
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 27 deletions.
6 changes: 3 additions & 3 deletions dataclass_click/dataclass_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def dataclass_click(
arg_class can be any class type as long as annotations can be extracted with inspect. Either the arg_class
constructor must accept kwarg arguments to match annotated field names (default for a @dataclass), or a factory
function (callable object) must be passed that accepts those kwarguments and returns an object of arg_class.
function (callable object) must be passed that accepts those kwargs and returns an object of arg_class.
Note that newer annotation types such as PEP 655 ``Required[]`` and ``NotRequired[]`` annotations not
well-supported: ``Annotated`` must be the outermost annotation and other such annotations like ``Required`` and
Expand Down Expand Up @@ -170,7 +170,7 @@ def _patch_click_types(
for key, annotation in annotations.items():
hint: typing.Type[Any]
_, hint = _strip_optional(type_hints[key])
if "type" not in annotation.kwargs and not annotation.kwargs.get("is_flag", False):
if "type" not in annotation.kwargs and not annotation.kwargs.get("is_flag", False):
if hint in complete_type_inferences:
annotation.kwargs["type"] = complete_type_inferences[hint]
else:
Expand Down Expand Up @@ -285,4 +285,4 @@ def register_type_inference(
These will be added to the resulting decorator in the order the attribute appears on the dataclass
Arguments are almost identical to click.option(), but do not include a name to give to the python argument"""
Arguments are almost identical to click.option(), but do not include a name to give to the python argument"""
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ source = [
based_on_style = "pep8"
column_limit = 120
split_before_first_argument = true
each_dict_entry_on_separate_line = false

[tool.mypy]
packages = ["dataclass_click"]
packages = [
"dataclass_click",
"tests",
]
check_untyped_defs = true
97 changes: 74 additions & 23 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from dataclasses import dataclass
from decimal import Decimal
from typing import Annotated, Any
import io

import click
import pytest
from click.testing import CliRunner

from dataclass_click import dataclass_click, option
from dataclass_click import _dataclass_click, dataclass_click, option, register_type_inference

CallRecord = tuple[tuple[Any, ...], dict[str, Any]]


def quick_run(command, *args: str, expect_exit_code: int = 0) -> None:
Expand All @@ -29,7 +32,7 @@ class Config:
def main(*args, **kwargs):
results.append((args, kwargs))

results = []
results: list[CallRecord] = []
quick_run(main, "--foo", "a", "--bar", "b", "--baz", "c")
assert results == [((Config(foo="a"), ), {"bar": "b", "baz": "c"})]

Expand All @@ -38,14 +41,14 @@ def test_types_can_be_inferred(inferrable_type, example_value_for_inferrable_typ

@dataclass
class Config:
foo: Annotated[inferrable_type, option("--foo")]
foo: Annotated[inferrable_type, option("--foo")] # type: ignore

@click.command()
@dataclass_click(Config)
def main(*args, **kwargs):
results.append((args, kwargs))

results = []
results: list[CallRecord] = []
if hasattr(example_value_for_inferrable_type, "isoformat"):
str_value = example_value_for_inferrable_type.isoformat()
else:
Expand Down Expand Up @@ -104,7 +107,7 @@ class Config:
def main(*args, **kwargs):
results.append((args, kwargs))

results = []
results: list[CallRecord] = []
quick_run(main, "--foo", "10")
assert results == [((Config(foo=10), ), {})]

Expand All @@ -121,18 +124,20 @@ class Config:
def main(*args, **kwargs):
results.append((args, kwargs))

results = []
results: list[CallRecord] = []
quick_run(main, "--foo", "10")
assert results == [((Config(baz=10), ), {})]


@pytest.mark.parametrize(["args", "expect"], [
({}, 2),
({"required": True}, 2),
({"required": False}, 0),
({"default": 10}, 0),
({"default": 10, "required": False}, 0)
], ids=["neither", "required-true", "required-false", "default", "both"])
@pytest.mark.parametrize(
["args", "expect"], [
({}, 2),
({"required": True}, 2),
({"required": False}, 0),
({"default": 10}, 0),
({"default": 10, "required": False}, 0),
],
ids=["neither", "required-true", "required-false", "default", "both"])
def test_inferred_required(args: dict[str, Any], expect: int):

@dataclass
Expand All @@ -147,14 +152,13 @@ def main(*args, **kwargs):
quick_run(main, expect_exit_code=expect)


@pytest.mark.parametrize(["args", "expect"], [
({}, 0),
({"required": True}, 2),
({"required": False}, 0),
({"default": 10}, 0),
({"default": 10, "required": False}, 0)
], ids=["neither", "required-true", "required-false", "default", "both"])
def test_inferred_required(args: dict[str, Any], expect: int):
@pytest.mark.parametrize(
["args", "expect"], [
({}, 0), ({"required": True}, 2), ({"required": False}, 0), ({"default": 10}, 0),
({"default": 10, "required": False}, 0)
],
ids=["neither", "required-true", "required-false", "default", "both"])
def test_inferred_not_required(args: dict[str, Any], expect: int):

@dataclass
class Config:
Expand All @@ -168,7 +172,38 @@ def main(*args, **kwargs):
quick_run(main, expect_exit_code=expect)


class DecimalParamType(click.ParamType):

def convert(self, value: Any, param: click.Parameter | None, ctx: click.Context | None) -> Decimal:
if isinstance(value, Decimal):
return value
return Decimal(value)


def test_patch_type_inference(monkeypatch):
monkeypatch.setattr(_dataclass_click, "_TYPE_INFERENCE", _dataclass_click._TYPE_INFERENCE.copy())

@dataclass
class Config:
imply_required: Annotated[Decimal, option()]

with pytest.raises(TypeError):

@click.command()
@dataclass_click(Config)
def main(*args, **kwargs):
pass

register_type_inference(Decimal, DecimalParamType())

@click.command()
@dataclass_click(Config)
def main_2(*args, **kwargs):
pass


def test_dataclass_can_be_used_twice():

@dataclass
class Config:
imply_required: Annotated[int, option()]
Expand All @@ -181,4 +216,20 @@ def main_1(*args, **kwargs):
@click.command()
@dataclass_click(Config)
def main_2(*args, **kwargs):
pass
pass


def test_keyword_name():

@dataclass
class Config:
bar: Annotated[int | None, option()]

@click.command()
@dataclass_click(Config, kw_name="foo")
def main(*args, **kwargs):
results.append((args, kwargs))

results: list[CallRecord] = []
quick_run(main)
assert results == [((), {"foo": Config(bar=None)})]

0 comments on commit 562e548

Please sign in to comment.