From 51c26cd44a7ab24749027e7de66585d2a5d9ed49 Mon Sep 17 00:00:00 2001 From: Philip Couling Date: Tue, 13 Feb 2024 18:56:37 +0000 Subject: [PATCH] Fixed inheritance (#6) --- dataclass_click/dataclass_click.py | 3 +-- tests/test_end_to_end.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/dataclass_click/dataclass_click.py b/dataclass_click/dataclass_click.py index 4c1cd02..366fbd2 100644 --- a/dataclass_click/dataclass_click.py +++ b/dataclass_click/dataclass_click.py @@ -9,7 +9,6 @@ import dataclasses import functools -import inspect import operator import types import typing @@ -225,7 +224,7 @@ def _collect_click_annotations(arg_class: typing.Type[Arg]) -> dict[str, _Delaye :param arg_class: Dataclass to analyze :return: A dictionary _DelayedCall keyed by attribute names""" annotations: dict[str, _DelayedCall] = {} - for key, value in inspect.get_annotations(arg_class).items(): + for key, value in typing.get_type_hints(arg_class, include_extras=True).items(): if typing.get_origin(value) is typing.Annotated: for annotation in typing.get_args(value): if isinstance(annotation, _DelayedCall): diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index adee171..1be7755 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -233,3 +233,23 @@ def main(*args, **kwargs): results: list[CallRecord] = [] quick_run(main) assert results == [((), {"foo": Config(bar=None)})] + + +def test_inheritance(): + + @dataclass() + class Parent: + foo: Annotated[int | None, option()] + + @dataclass + class Config(Parent): + bar: Annotated[int | None, option()] + + @click.command() + @dataclass_click(Config) + def main(*args, **kwargs): + results.append((args, kwargs)) + + results: list[CallRecord] = [] + quick_run(main, "--foo", "10", "--bar", "20") + assert results == [((Config(foo=10, bar=20), ), {})]