Skip to content

Commit

Permalink
Merge pull request #1304 from pyiron/fix_read_dict_from_hdf
Browse files Browse the repository at this point in the history
Fix read_dict_from_hdf() when using overlapping paths
  • Loading branch information
jan-janssen authored Feb 2, 2024
2 parents a179cc0 + 61774f5 commit 5162288
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 16 deletions.
39 changes: 27 additions & 12 deletions pyiron_base/storage/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,24 @@ def get_groups_hdf(hdf, h5_path):
except KeyError:
return []

def merge_dict(main_dict, add_dict):
    """
    Merge two dictionaries recursively

    The merge is performed in place: main_dict is modified and the same object is returned. Nested
    dictionaries are merged key by key only when both dictionaries hold a dictionary under the same
    key. In every other case the value from add_dict replaces the value in main_dict.

    Args:
        main_dict (dict): The primary dictionary which the secondary dictionary is merged into
        add_dict (dict): The secondary dictionary which is merged into the primary dictionary

    Returns:
        dict: The merged dictionary with all keys
    """
    for k, v in add_dict.items():
        # Descend only when both sides are dictionaries. If main_dict holds a non-dictionary value
        # under this key, recursing into it would fail on item assignment, so the incoming value
        # replaces it instead.
        if isinstance(v, dict) and isinstance(main_dict.get(k), dict):
            main_dict[k] = merge_dict(main_dict=main_dict[k], add_dict=v)
        else:
            main_dict[k] = v
    return main_dict

if recursive and len(group_paths) > 0:
raise ValueError(
"Loading subgroups can either be defined by the group paths ",
Expand All @@ -126,20 +144,17 @@ def get_groups_hdf(hdf, h5_path):
for g in get_groups_hdf(hdf=store, h5_path=h5_path)
]
for group_path in group_paths:
read_dict = resolve_nested_dict(
group_path=group_path,
data_dict=get_dict_from_nodes(
store=store,
h5_path=get_h5_path(h5_path=h5_path, name=group_path),
slash=slash,
output_dict = merge_dict(
main_dict=output_dict,
add_dict=resolve_nested_dict(
group_path=group_path,
data_dict=get_dict_from_nodes(
store=store,
h5_path=get_h5_path(h5_path=h5_path, name=group_path),
slash=slash,
),
),
)
for k, v in read_dict.items():
if k in output_dict.keys() and isinstance(v, dict):
for sk, vs in v.items():
output_dict[k][sk] = vs
else:
output_dict[k] = v
return output_dict


Expand Down
17 changes: 13 additions & 4 deletions tests/storage/test_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TestWriteHdfIO(TestCase):
def setUp(self):
self.file_name = "test_write_hdf5.h5"
self.h5_path = "data_hierarchical"
self.data_hierarchical = {"a": [1, 2], "b": 3, "c": {"d": 4, "e": 5}}
self.data_hierarchical = {"a": [1, 2], "b": 3, "c": {"d": 4, "e": {"f": 5}}}
_write_hdf(hdf_filehandle=self.file_name, data=self.data_hierarchical, h5_path=self.h5_path)

def tearDown(self):
Expand All @@ -37,15 +37,23 @@ def test_read_hierarchical(self):
def test_read_dict_hierarchical(self):
self.assertEqual({'key_b': 3}, read_dict_from_hdf(file_name=self.file_name, h5_path=self.h5_path))
self.assertEqual(
{'key_a': {'idx_0': 1, 'idx_1': 2}, 'key_b': 3, 'key_c': {'key_d': 4, 'key_e': 5}},
{'key_a': {'idx_0': 1, 'idx_1': 2}, 'key_b': 3, 'key_c': {'key_d': 4}},
read_dict_from_hdf(
file_name=self.file_name,
h5_path=self.h5_path,
group_paths=["key_a", "key_c"],
)
)
self.assertEqual(
{'key_a': {'idx_0': 1, 'idx_1': 2}, 'key_b': 3, 'key_c': {'key_d': 4, 'key_e': 5}},
{'key_a': {'idx_0': 1, 'idx_1': 2}, 'key_b': 3, 'key_c': {'key_d': 4, 'key_e': {'key_f': 5}}},
read_dict_from_hdf(
file_name=self.file_name,
h5_path=self.h5_path,
group_paths=["key_a", "key_c", "key_c/key_e"],
)
)
self.assertEqual(
{'key_a': {'idx_0': 1, 'idx_1': 2}, 'key_b': 3, 'key_c': {'key_d': 4, 'key_e': {'key_f': 5}}},
read_dict_from_hdf(
file_name=self.file_name,
h5_path=self.h5_path,
Expand All @@ -66,7 +74,8 @@ def test_hdf5_structure(self):
{'data_hierarchical/key_b': {'TITLE': 'int'}},
{'data_hierarchical/key_c': {'TITLE': 'dict'}},
{'data_hierarchical/key_c/key_d': {'TITLE': 'int'}},
{'data_hierarchical/key_c/key_e': {'TITLE': 'int'}}
{'data_hierarchical/key_c/key_e': {'TITLE': 'dict'}},
{'data_hierarchical/key_c/key_e/key_f': {'TITLE': 'int'}}
])

def test_list_groups(self):
Expand Down

0 comments on commit 5162288

Please sign in to comment.