From df04bd7d1d0566d26420681ec9a7fa400b42087d Mon Sep 17 00:00:00 2001 From: Lysithea <52808607+zxysbsbzxy@users.noreply.github.com> Date: Wed, 25 Oct 2023 11:03:01 +0800 Subject: [PATCH] reformat func for further merging with pt version (#2946) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/utils/data_system.py | 101 ++++++++++++--------------- source/tests/test_deepmd_data_sys.py | 9 ++- 2 files changed, 52 insertions(+), 58 deletions(-) diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 0071da755c..69a6cbe112 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -353,28 +353,15 @@ def set_sys_probs(self, sys_probs=None, auto_prob_style: str = "prob_sys_size"): elif auto_prob_style == "prob_sys_size": probs = self.prob_nbatches elif auto_prob_style[:14] == "prob_sys_size;": - probs = self._prob_sys_size_ext(auto_prob_style) + probs = prob_sys_size_ext( + auto_prob_style, self.get_nsystems(), self.nbatches + ) else: raise RuntimeError("Unknown auto prob style: " + auto_prob_style) else: - probs = self._process_sys_probs(sys_probs) + probs = process_sys_probs(sys_probs, self.nbatches) self.sys_probs = probs - def _get_sys_probs(self, sys_probs, auto_prob_style): # depreciated - if sys_probs is None: - if auto_prob_style == "prob_uniform": - prob_v = 1.0 / float(self.nsystems) - prob = [prob_v for ii in range(self.nsystems)] - elif auto_prob_style == "prob_sys_size": - prob = self.prob_nbatches - elif auto_prob_style[:14] == "prob_sys_size;": - prob = self._prob_sys_size_ext(auto_prob_style) - else: - raise RuntimeError("unknown style " + auto_prob_style) - else: - prob = self._process_sys_probs(sys_probs) - return prob - def get_batch(self, sys_idx: Optional[int] = None) -> dict: # batch generation style altered by Ziyao Li: # one should specify the "sys_prob" and "auto_prob_style" params @@ -623,42 +610,44 @@ def _check_type_map_consistency(self, type_map_list): ret = ii return ret - def _process_sys_probs(self, sys_probs): - sys_probs = np.array(sys_probs) - type_filter = sys_probs >= 0 - assigned_sum_prob = np.sum(type_filter * sys_probs) - # 1e-8 is to handle floating point error; See #1917 - assert ( - assigned_sum_prob <= 1.0 + 1e-8 - ), "the sum of assigned probability should be less than 1" - rest_sum_prob = 1.0 - assigned_sum_prob - if not np.isclose(rest_sum_prob, 0): - rest_nbatch = (1 - type_filter) * self.nbatches - rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch) - ret_prob = rest_prob + type_filter * sys_probs - else: - ret_prob = sys_probs - assert np.isclose(np.sum(ret_prob), 1), "sum of probs should be 1" - return ret_prob - - def _prob_sys_size_ext(self, keywords): - block_str = keywords.split(";")[1:] - block_stt = [] - block_end = [] - block_weights = [] - for ii in block_str: - stt = int(ii.split(":")[0]) - end = int(ii.split(":")[1]) - weight = float(ii.split(":")[2]) - assert weight >= 0, "the weight of a block should be no less than 0" - block_stt.append(stt) - block_end.append(end) - block_weights.append(weight) - nblocks = len(block_str) - block_probs = np.array(block_weights) / np.sum(block_weights) - sys_probs = np.zeros([self.get_nsystems()]) - for ii in range(nblocks): - nbatch_block = self.nbatches[block_stt[ii] : block_end[ii]] - tmp_prob = [float(i) for i in nbatch_block] / np.sum(nbatch_block) - sys_probs[block_stt[ii] : block_end[ii]] = tmp_prob * block_probs[ii] - return sys_probs + +def process_sys_probs(sys_probs, nbatch): + sys_probs = np.array(sys_probs) + type_filter = sys_probs >= 0 + assigned_sum_prob = np.sum(type_filter * sys_probs) + # 1e-8 is to handle floating point error; See #1917 + assert ( + assigned_sum_prob <= 1.0 + 1e-8 + ), "the sum of assigned probability should be less than 1" + rest_sum_prob = 1.0 - assigned_sum_prob + if not np.isclose(rest_sum_prob, 0): + rest_nbatch = (1 - type_filter) * nbatch + rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch) + ret_prob = rest_prob + type_filter * sys_probs + else: + ret_prob = sys_probs + assert np.isclose(np.sum(ret_prob), 1), "sum of probs should be 1" + return ret_prob + + +def prob_sys_size_ext(keywords, nsystems, nbatch): + block_str = keywords.split(";")[1:] + block_stt = [] + block_end = [] + block_weights = [] + for ii in block_str: + stt = int(ii.split(":")[0]) + end = int(ii.split(":")[1]) + weight = float(ii.split(":")[2]) + assert weight >= 0, "the weight of a block should be no less than 0" + block_stt.append(stt) + block_end.append(end) + block_weights.append(weight) + nblocks = len(block_str) + block_probs = np.array(block_weights) / np.sum(block_weights) + sys_probs = np.zeros([nsystems]) + for ii in range(nblocks): + nbatch_block = nbatch[block_stt[ii] : block_end[ii]] + tmp_prob = [float(i) for i in nbatch_block] / np.sum(nbatch_block) + sys_probs[block_stt[ii] : block_end[ii]] = tmp_prob * block_probs[ii] + return sys_probs diff --git a/source/tests/test_deepmd_data_sys.py b/source/tests/test_deepmd_data_sys.py index 54b75cbed2..abfa7d7e48 100644 --- a/source/tests/test_deepmd_data_sys.py +++ b/source/tests/test_deepmd_data_sys.py @@ -13,6 +13,7 @@ ) from deepmd.utils.data_system import ( DeepmdDataSystem, + prob_sys_size_ext, ) if GLOBAL_NP_FLOAT_PRECISION == np.float32: @@ -310,7 +311,9 @@ def test_prob_sys_size_1(self): batch_size = 1 test_size = 1 ds = DeepmdDataSystem(self.sys_name, batch_size, test_size, 2.0) - prob = ds._prob_sys_size_ext("prob_sys_size; 0:2:2; 2:4:8") + prob = prob_sys_size_ext( + "prob_sys_size; 0:2:2; 2:4:8", ds.get_nsystems(), ds.get_nbatches() + ) self.assertAlmostEqual(np.sum(prob), 1) self.assertAlmostEqual(np.sum(prob[0:2]), 0.2) self.assertAlmostEqual(np.sum(prob[2:4]), 0.8) @@ -332,7 +335,9 @@ def test_prob_sys_size_2(self): batch_size = 1 test_size = 1 ds = DeepmdDataSystem(self.sys_name, batch_size, test_size, 2.0) - prob = ds._prob_sys_size_ext("prob_sys_size; 1:2:0.4; 2:4:1.6") + prob = prob_sys_size_ext( + "prob_sys_size; 1:2:0.4; 2:4:1.6", ds.get_nsystems(), ds.get_nbatches() + ) self.assertAlmostEqual(np.sum(prob), 1) self.assertAlmostEqual(np.sum(prob[1:2]), 0.2) self.assertAlmostEqual(np.sum(prob[2:4]), 0.8)