From bcbca9ca2e45f04728cbc066d30f855b2c90662c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 25 Oct 2023 11:41:31 -0400 Subject: [PATCH] do not add type of default if it is already in type (#40) --- dargs/dargs.py | 19 ++++++++++++++++++- dargs/sphinx.py | 8 +++++++- tests/test_checker.py | 3 +++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/dargs/dargs.py b/dargs/dargs.py index dc70b0d..6a7c87c 100644 --- a/dargs/dargs.py +++ b/dargs/dargs.py @@ -224,7 +224,11 @@ def _reorg_dtype(self): # check conner cases if self.sub_fields or self.sub_variants: self.dtype.add(list if self.repeat else dict) - if self.optional and self.default is not _Flags.NONE: + if ( + self.optional + and self.default is not _Flags.NONE + and all([not isinstance_annotation(self.default, tt) for tt in self.dtype]) + ): self.dtype.add(type(self.default)) # and make it compatible with `isinstance` self.dtype = tuple(self.dtype) @@ -968,6 +972,19 @@ def trim_by_pattern( argdict.pop(key) +def isinstance_annotation(value, dtype) -> bool: + """Same as isinstance(), but supports arbitrary type annotations.""" + try: + typeguard.check_type( + value, + dtype, + collection_check_strategy=typeguard.CollectionCheckStrategy.ALL_ITEMS, + ) + except typeguard.TypeCheckError as e: + return False + return True + + class ArgumentEncoder(json.JSONEncoder): """Extended JSON Encoder to encode Argument object: diff --git a/dargs/sphinx.py b/dargs/sphinx.py index 80a80e3..5cc9e7f 100644 --- a/dargs/sphinx.py +++ b/dargs/sphinx.py @@ -192,5 +192,11 @@ def _test_arguments() -> List[Argument]: return [ Argument(name="test1", dtype=int, doc="Argument 1"), Argument(name="test2", dtype=[float, None], doc="Argument 2"), - Argument(name="test3", dtype=List[str], doc="Argument 3"), + Argument( + name="test3", + dtype=List[str], + default=["test"], + optional=True, + doc="Argument 3", + ), ] diff --git a/tests/test_checker.py b/tests/test_checker.py index 088a8b0..62be7ec 100644 --- a/tests/test_checker.py +++ b/tests/test_checker.py @@ -31,6 +31,9 @@ def test_name_type(self): # list[int] ca = Argument("key1", List[float]) ca.check({"key1": [1, 2.0, 3]}) + with self.assertRaises(ArgumentTypeError): + ca.check({"key1": [1, 2.0, "3"]}) + ca = Argument("key1", List[float], default=[], optional=True) with self.assertRaises(ArgumentTypeError): ca.check({"key1": [1, 2.0, "3"]}) # optional case