From d93408596907c96e2909054498db07f6a078ad6c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 12 Jun 2024 15:30:10 -0400 Subject: [PATCH] feat: allow repeat with dtype=dict (#67) This will be helpful to normalize `model_dict`, `loss_dict`, etc., in DeePMD-kit. ## Summary by CodeRabbit - **New Features** - Enhanced handling of data types based on the `repeat` flag, supporting both lists of dicts and dicts of dicts. - Improved documentation generation reflecting updated data type handling. - **Bug Fixes** - Corrected traversal logic for values to ensure proper handling of dictionaries and lists based on the `repeat` flag. - **Tests** - Updated and added new test cases to cover changes in data type handling and traversal logic. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- dargs/dargs.py | 45 ++++++++++++++++++++++++++++++++++------ dargs/sphinx.py | 15 ++++++++++++-- tests/test_checker.py | 37 +++++++++++++++++++++++++++++++-- tests/test_docgen.py | 30 ++++++++++++++++++++++++++- tests/test_normalizer.py | 12 ++++++++++- 5 files changed, 127 insertions(+), 12 deletions(-) diff --git a/dargs/dargs.py b/dargs/dargs.py index 2529560..b9a1c9b 100644 --- a/dargs/dargs.py +++ b/dargs/dargs.py @@ -109,7 +109,7 @@ class Argument: If given, `dtype` is assumed to be dict, and its items are determined by the `Variant`s in the given list and the value of their flag keys. repeat: bool, optional - If true, `dtype` is assume to be list of dict and each dict consists + If true, `dtype` is assume to be list of dict or dict of dict, and each dict consists of sub fields and sub variants described above. Defaults to false. optional: bool, optional If true, consider the current argument to be optional in checking. 
@@ -235,7 +235,17 @@ def _reorg_dtype( } # check conner cases if self.sub_fields or self.sub_variants: - dtype.add(list if self.repeat else dict) + if not self.repeat: + dtype.add(dict) + else: + # convert dtypes to unsubscripted types + unsubscripted_dtype = { + get_origin(dt) if get_origin(dt) is not None else dt for dt in dtype + } + if dict not in unsubscripted_dtype: + # only add list (compatible with old behaviors) if no dict in dtype + dtype.add(list) + if ( self.optional and self.default is not _Flags.NONE @@ -347,11 +357,11 @@ def traverse_value( # in the condition where there is no leading key if path is None: path = [] - if isinstance(value, dict): + if not self.repeat and isinstance(value, dict): self._traverse_sub( value, key_hook, value_hook, sub_hook, variant_hook, path ) - if isinstance(value, list) and self.repeat: + elif self.repeat and isinstance(value, list): for idx, item in enumerate(value): self._traverse_sub( item, @@ -361,6 +371,16 @@ def traverse_value( variant_hook, [*path, str(idx)], ) + elif self.repeat and isinstance(value, dict): + for kk, item in value.items(): + self._traverse_sub( + item, + key_hook, + value_hook, + sub_hook, + variant_hook, + [*path, kk], + ) def _traverse_sub( self, @@ -653,9 +673,22 @@ def gen_doc_body(self, path: list[str] | None = None, **kwargs) -> str: body_list.append(self.doc + "\n") if not self.fold_subdoc: if self.repeat: + unsubscripted_dtype = { + get_origin(dt) if get_origin(dt) is not None else dt + for dt in self.dtype + } + allowed_types = [] + allowed_element = [] + if list in unsubscripted_dtype or dict in unsubscripted_dtype: + if list in unsubscripted_dtype: + allowed_types.append("list") + allowed_element.append("element") + if dict in unsubscripted_dtype: + allowed_types.append("dict") + allowed_element.append("key-value pair") body_list.append( - "This argument takes a list with " - "each element containing the following: \n" + f"This argument takes a {' or '.join(allowed_types)} with " + 
f"each {' or '.join(allowed_element)} containing the following: \n" ) if self.sub_fields: # body_list.append("") # genetate a blank line diff --git a/dargs/sphinx.py b/dargs/sphinx.py index affb1fa..ac86f3e 100644 --- a/dargs/sphinx.py +++ b/dargs/sphinx.py @@ -190,7 +190,7 @@ def _test_argument() -> Argument: doc=doc_test, sub_fields=[ Argument( - "test_repeat", + "test_repeat_list", dtype=list, repeat=True, doc=doc_test, @@ -199,7 +199,18 @@ def _test_argument() -> Argument: "test_repeat_item", dtype=bool, doc=doc_test ), ], - ) + ), + Argument( + "test_repeat_dict", + dtype=dict, + repeat=True, + doc=doc_test, + sub_fields=[ + Argument( + "test_repeat_item", dtype=bool, doc=doc_test + ), + ], + ), ], ), ], diff --git a/tests/test_checker.py b/tests/test_checker.py index 62fea21..92926be 100644 --- a/tests/test_checker.py +++ b/tests/test_checker.py @@ -100,9 +100,9 @@ def test_sub_fields(self): with self.assertRaises(ValueError): Argument("base", dict, [Argument("sub1", int), Argument("sub1", int)]) - def test_sub_repeat(self): + def test_sub_repeat_list(self): ca = Argument( - "base", dict, [Argument("sub1", int), Argument("sub2", str)], repeat=True + "base", list, [Argument("sub1", int), Argument("sub2", str)], repeat=True ) test_dict1 = { "base": [{"sub1": 10, "sub2": "hello"}, {"sub1": 11, "sub2": "world"}] @@ -124,6 +124,39 @@ def test_sub_repeat(self): with self.assertRaises(ArgumentTypeError): ca.check(err_dict2) + def test_sub_repeat_dict(self): + ca = Argument( + "base", dict, [Argument("sub1", int), Argument("sub2", str)], repeat=True + ) + test_dict1 = { + "base": { + "item1": {"sub1": 10, "sub2": "hello"}, + "item2": {"sub1": 11, "sub2": "world"}, + } + } + ca.check(test_dict1) + ca.check_value(test_dict1["base"]) + err_dict1 = { + "base": { + "item1": {"sub1": 10, "sub2": "hello"}, + "item2": {"sub1": 11, "sub3": "world"}, + } + } + with self.assertRaises(ArgumentKeyError): + ca.check(err_dict1) + err_dict1["base"]["item2"]["sub2"] = "world too" + 
ca.check(err_dict1) # now should pass + with self.assertRaises(ArgumentKeyError): + ca.check(err_dict1, strict=True) # but should fail when strict + err_dict2 = { + "base": { + "item1": {"sub1": 10, "sub2": "hello"}, + "item2": {"sub1": 11, "sub2": None}, + } + } + with self.assertRaises(ArgumentTypeError): + ca.check(err_dict2) + def test_sub_variants(self): ca = Argument( "base", diff --git a/tests/test_docgen.py b/tests/test_docgen.py index edcfabe..0e448e3 100644 --- a/tests/test_docgen.py +++ b/tests/test_docgen.py @@ -41,7 +41,7 @@ def test_sub_fields(self): jsonstr = json.dumps(ca, cls=ArgumentEncoder) # print("\n\n"+docstr) - def test_sub_repeat(self): + def test_sub_repeat_list(self): ca = Argument( "base", list, @@ -70,6 +70,34 @@ def test_sub_repeat(self): jsonstr = json.dumps(ca, cls=ArgumentEncoder) # print("\n\n"+docstr) + def test_sub_repeat_dict(self): + ca = Argument( + "base", + dict, + [ + Argument("sub1", int, doc="sub doc." * 5), + Argument( + "sub2", + [None, str, dict], + [ + Argument("subsub1", int, doc="subsub doc." * 5, optional=True), + Argument( + "subsub2", + dict, + [Argument("subsubsub1", int, doc="subsubsub doc." * 5)], + doc="subsub doc." * 5, + repeat=True, + ), + ], + doc="sub doc." * 5, + ), + ], + doc="Base doc. 
" * 10, + repeat=True, + ) + docstr = ca.gen_doc() + jsonstr = json.dumps(ca, cls=ArgumentEncoder) + def test_sub_variants(self): ca = Argument( "base", diff --git a/tests/test_normalizer.py b/tests/test_normalizer.py index 2282d9f..8a03571 100644 --- a/tests/test_normalizer.py +++ b/tests/test_normalizer.py @@ -93,6 +93,13 @@ def test_complicated(self): repeat=True, alias=["sub2a"], ), + Argument( + "sub2_dict", + dict, + [Argument("ss1", int, optional=True, default=21, alias=["ss1a"])], + repeat=True, + alias=["sub2a_dict"], + ), ], [ Variant( @@ -145,11 +152,12 @@ def test_complicated(self): ) ], ) - beg1 = {"base": {"sub2": [{}, {}]}} + beg1 = {"base": {"sub2": [{}, {}], "sub2_dict": {"item1": {}, "item2": {}}}} ref1 = { "base": { "sub1": 1, "sub2": [{"ss1": 21}, {"ss1": 21}], + "sub2_dict": {"item1": {"ss1": 21}, "item2": {"ss1": 21}}, "vnt_flag": "type1", "shared": -1, "vnt1": 111, @@ -161,6 +169,7 @@ def test_complicated(self): "base": { "sub1a": 2, "sub2a": [{"ss1a": 22}, {"_comment1": None}], + "sub2a_dict": {"item1": {"ss1a": 22}, "item2": {"_comment1": None}}, "vnt_flag": "type3", "sharedb": -3, "vnt2a": 223, @@ -172,6 +181,7 @@ def test_complicated(self): "base": { "sub1": 2, "sub2": [{"ss1": 22}, {"ss1": 21}], + "sub2_dict": {"item1": {"ss1": 22}, "item2": {"ss1": 21}}, "vnt_flag": "type2", "shared": -3, "vnt2": 223,