Skip to content

Commit

Permalink
feat: allow repeat with dtype=dict (#67)
Browse files Browse the repository at this point in the history
This will be helpful to normalize `model_dict`, `loss_dict`, etc, in
DeePMD-kit.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced handling of data types based on the `repeat` flag, supporting
both lists of dicts and dicts of dicts.
- Improved documentation generation reflecting updated data type
handling.

- **Bug Fixes**
- Corrected traversal logic for values to ensure proper handling of
dictionaries and lists based on the `repeat` flag.

- **Tests**
- Updated and added new test cases to cover changes in data type
handling and traversal logic.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and coderabbitai[bot] authored Jun 12, 2024
1 parent a02640c commit d934085
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 12 deletions.
45 changes: 39 additions & 6 deletions dargs/dargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class Argument:
If given, `dtype` is assumed to be dict, and its items are determined
by the `Variant`s in the given list and the value of their flag keys.
repeat: bool, optional
If true, `dtype` is assume to be list of dict and each dict consists
If true, `dtype` is assume to be list of dict or dict of dict, and each dict consists
of sub fields and sub variants described above. Defaults to false.
optional: bool, optional
If true, consider the current argument to be optional in checking.
Expand Down Expand Up @@ -235,7 +235,17 @@ def _reorg_dtype(
}
# check conner cases
if self.sub_fields or self.sub_variants:
dtype.add(list if self.repeat else dict)
if not self.repeat:
dtype.add(dict)
else:
# convert dtypes to unsubscripted types
unsubscripted_dtype = {
get_origin(dt) if get_origin(dt) is not None else dt for dt in dtype
}
if dict not in unsubscripted_dtype:
# only add list (compatible with old behaviors) if no dict in dtype
dtype.add(list)

if (
self.optional
and self.default is not _Flags.NONE
Expand Down Expand Up @@ -347,11 +357,11 @@ def traverse_value(
# in the condition where there is no leading key
if path is None:
path = []
if isinstance(value, dict):
if not self.repeat and isinstance(value, dict):
self._traverse_sub(
value, key_hook, value_hook, sub_hook, variant_hook, path
)
if isinstance(value, list) and self.repeat:
elif self.repeat and isinstance(value, list):
for idx, item in enumerate(value):
self._traverse_sub(
item,
Expand All @@ -361,6 +371,16 @@ def traverse_value(
variant_hook,
[*path, str(idx)],
)
elif self.repeat and isinstance(value, dict):
for kk, item in value.items():
self._traverse_sub(
item,
key_hook,
value_hook,
sub_hook,
variant_hook,
[*path, kk],
)

def _traverse_sub(
self,
Expand Down Expand Up @@ -653,9 +673,22 @@ def gen_doc_body(self, path: list[str] | None = None, **kwargs) -> str:
body_list.append(self.doc + "\n")
if not self.fold_subdoc:
if self.repeat:
unsubscripted_dtype = {
get_origin(dt) if get_origin(dt) is not None else dt
for dt in self.dtype
}
allowed_types = []
allowed_element = []
if list in unsubscripted_dtype or dict in unsubscripted_dtype:
if list in unsubscripted_dtype:
allowed_types.append("list")
allowed_element.append("element")
if dict in unsubscripted_dtype:
allowed_types.append("dict")
allowed_element.append("key-value pair")
body_list.append(
"This argument takes a list with "
"each element containing the following: \n"
f"This argument takes a {' or '.join(allowed_types)} with "
f"each {' or '.join(allowed_element)} containing the following: \n"
)
if self.sub_fields:
# body_list.append("") # genetate a blank line
Expand Down
15 changes: 13 additions & 2 deletions dargs/sphinx.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _test_argument() -> Argument:
doc=doc_test,
sub_fields=[
Argument(
"test_repeat",
"test_repeat_list",
dtype=list,
repeat=True,
doc=doc_test,
Expand All @@ -199,7 +199,18 @@ def _test_argument() -> Argument:
"test_repeat_item", dtype=bool, doc=doc_test
),
],
)
),
Argument(
"test_repeat_dict",
dtype=dict,
repeat=True,
doc=doc_test,
sub_fields=[
Argument(
"test_repeat_item", dtype=bool, doc=doc_test
),
],
),
],
),
],
Expand Down
37 changes: 35 additions & 2 deletions tests/test_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def test_sub_fields(self):
with self.assertRaises(ValueError):
Argument("base", dict, [Argument("sub1", int), Argument("sub1", int)])

def test_sub_repeat(self):
def test_sub_repeat_list(self):
ca = Argument(
"base", dict, [Argument("sub1", int), Argument("sub2", str)], repeat=True
"base", list, [Argument("sub1", int), Argument("sub2", str)], repeat=True
)
test_dict1 = {
"base": [{"sub1": 10, "sub2": "hello"}, {"sub1": 11, "sub2": "world"}]
Expand All @@ -124,6 +124,39 @@ def test_sub_repeat(self):
with self.assertRaises(ArgumentTypeError):
ca.check(err_dict2)

def test_sub_repeat_dict(self):
ca = Argument(
"base", dict, [Argument("sub1", int), Argument("sub2", str)], repeat=True
)
test_dict1 = {
"base": {
"item1": {"sub1": 10, "sub2": "hello"},
"item2": {"sub1": 11, "sub2": "world"},
}
}
ca.check(test_dict1)
ca.check_value(test_dict1["base"])
err_dict1 = {
"base": {
"item1": {"sub1": 10, "sub2": "hello"},
"item2": {"sub1": 11, "sub3": "world"},
}
}
with self.assertRaises(ArgumentKeyError):
ca.check(err_dict1)
err_dict1["base"]["item2"]["sub2"] = "world too"
ca.check(err_dict1) # now should pass
with self.assertRaises(ArgumentKeyError):
ca.check(err_dict1, strict=True) # but should fail when strict
err_dict2 = {
"base": {
"item1": {"sub1": 10, "sub2": "hello"},
"item2": {"sub1": 11, "sub2": None},
}
}
with self.assertRaises(ArgumentTypeError):
ca.check(err_dict2)

def test_sub_variants(self):
ca = Argument(
"base",
Expand Down
30 changes: 29 additions & 1 deletion tests/test_docgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_sub_fields(self):
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
# print("\n\n"+docstr)

def test_sub_repeat(self):
def test_sub_repeat_list(self):
ca = Argument(
"base",
list,
Expand Down Expand Up @@ -70,6 +70,34 @@ def test_sub_repeat(self):
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
# print("\n\n"+docstr)

def test_sub_repeat_dict(self):
ca = Argument(
"base",
dict,
[
Argument("sub1", int, doc="sub doc." * 5),
Argument(
"sub2",
[None, str, dict],
[
Argument("subsub1", int, doc="subsub doc." * 5, optional=True),
Argument(
"subsub2",
dict,
[Argument("subsubsub1", int, doc="subsubsub doc." * 5)],
doc="subsub doc." * 5,
repeat=True,
),
],
doc="sub doc." * 5,
),
],
doc="Base doc. " * 10,
repeat=True,
)
docstr = ca.gen_doc()
jsonstr = json.dumps(ca, cls=ArgumentEncoder)

def test_sub_variants(self):
ca = Argument(
"base",
Expand Down
12 changes: 11 additions & 1 deletion tests/test_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ def test_complicated(self):
repeat=True,
alias=["sub2a"],
),
Argument(
"sub2_dict",
dict,
[Argument("ss1", int, optional=True, default=21, alias=["ss1a"])],
repeat=True,
alias=["sub2a_dict"],
),
],
[
Variant(
Expand Down Expand Up @@ -145,11 +152,12 @@ def test_complicated(self):
)
],
)
beg1 = {"base": {"sub2": [{}, {}]}}
beg1 = {"base": {"sub2": [{}, {}], "sub2_dict": {"item1": {}, "item2": {}}}}
ref1 = {
"base": {
"sub1": 1,
"sub2": [{"ss1": 21}, {"ss1": 21}],
"sub2_dict": {"item1": {"ss1": 21}, "item2": {"ss1": 21}},
"vnt_flag": "type1",
"shared": -1,
"vnt1": 111,
Expand All @@ -161,6 +169,7 @@ def test_complicated(self):
"base": {
"sub1a": 2,
"sub2a": [{"ss1a": 22}, {"_comment1": None}],
"sub2a_dict": {"item1": {"ss1a": 22}, "item2": {"_comment1": None}},
"vnt_flag": "type3",
"sharedb": -3,
"vnt2a": 223,
Expand All @@ -172,6 +181,7 @@ def test_complicated(self):
"base": {
"sub1": 2,
"sub2": [{"ss1": 22}, {"ss1": 21}],
"sub2_dict": {"item1": {"ss1": 22}, "item2": {"ss1": 21}},
"vnt_flag": "type2",
"shared": -3,
"vnt2": 223,
Expand Down

0 comments on commit d934085

Please sign in to comment.