Skip to content

Commit

Permalink
reformat func for further merging with pt version (deepmodeling#2946)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
CaRoLZhangxy and pre-commit-ci[bot] authored Oct 25, 2023
1 parent 6d973ef commit df04bd7
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 58 deletions.
101 changes: 45 additions & 56 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,28 +353,15 @@ def set_sys_probs(self, sys_probs=None, auto_prob_style: str = "prob_sys_size"):
elif auto_prob_style == "prob_sys_size":
probs = self.prob_nbatches
elif auto_prob_style[:14] == "prob_sys_size;":
probs = self._prob_sys_size_ext(auto_prob_style)
probs = prob_sys_size_ext(
auto_prob_style, self.get_nsystems(), self.nbatches
)
else:
raise RuntimeError("Unknown auto prob style: " + auto_prob_style)
else:
probs = self._process_sys_probs(sys_probs)
probs = process_sys_probs(sys_probs, self.nbatches)
self.sys_probs = probs

def _get_sys_probs(self, sys_probs, auto_prob_style): # depreciated
if sys_probs is None:
if auto_prob_style == "prob_uniform":
prob_v = 1.0 / float(self.nsystems)
prob = [prob_v for ii in range(self.nsystems)]
elif auto_prob_style == "prob_sys_size":
prob = self.prob_nbatches
elif auto_prob_style[:14] == "prob_sys_size;":
prob = self._prob_sys_size_ext(auto_prob_style)
else:
raise RuntimeError("unknown style " + auto_prob_style)
else:
prob = self._process_sys_probs(sys_probs)
return prob

def get_batch(self, sys_idx: Optional[int] = None) -> dict:
# batch generation style altered by Ziyao Li:
# one should specify the "sys_prob" and "auto_prob_style" params
Expand Down Expand Up @@ -623,42 +610,44 @@ def _check_type_map_consistency(self, type_map_list):
ret = ii
return ret

def _process_sys_probs(self, sys_probs):
sys_probs = np.array(sys_probs)
type_filter = sys_probs >= 0
assigned_sum_prob = np.sum(type_filter * sys_probs)
# 1e-8 is to handle floating point error; See #1917
assert (
assigned_sum_prob <= 1.0 + 1e-8
), "the sum of assigned probability should be less than 1"
rest_sum_prob = 1.0 - assigned_sum_prob
if not np.isclose(rest_sum_prob, 0):
rest_nbatch = (1 - type_filter) * self.nbatches
rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch)
ret_prob = rest_prob + type_filter * sys_probs
else:
ret_prob = sys_probs
assert np.isclose(np.sum(ret_prob), 1), "sum of probs should be 1"
return ret_prob

def _prob_sys_size_ext(self, keywords):
block_str = keywords.split(";")[1:]
block_stt = []
block_end = []
block_weights = []
for ii in block_str:
stt = int(ii.split(":")[0])
end = int(ii.split(":")[1])
weight = float(ii.split(":")[2])
assert weight >= 0, "the weight of a block should be no less than 0"
block_stt.append(stt)
block_end.append(end)
block_weights.append(weight)
nblocks = len(block_str)
block_probs = np.array(block_weights) / np.sum(block_weights)
sys_probs = np.zeros([self.get_nsystems()])
for ii in range(nblocks):
nbatch_block = self.nbatches[block_stt[ii] : block_end[ii]]
tmp_prob = [float(i) for i in nbatch_block] / np.sum(nbatch_block)
sys_probs[block_stt[ii] : block_end[ii]] = tmp_prob * block_probs[ii]
return sys_probs

def process_sys_probs(sys_probs, nbatch):
sys_probs = np.array(sys_probs)
type_filter = sys_probs >= 0
assigned_sum_prob = np.sum(type_filter * sys_probs)
# 1e-8 is to handle floating point error; See #1917
assert (
assigned_sum_prob <= 1.0 + 1e-8
), "the sum of assigned probability should be less than 1"
rest_sum_prob = 1.0 - assigned_sum_prob
if not np.isclose(rest_sum_prob, 0):
rest_nbatch = (1 - type_filter) * nbatch
rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch)
ret_prob = rest_prob + type_filter * sys_probs
else:
ret_prob = sys_probs
assert np.isclose(np.sum(ret_prob), 1), "sum of probs should be 1"
return ret_prob


def prob_sys_size_ext(keywords, nsystems, nbatch):
block_str = keywords.split(";")[1:]
block_stt = []
block_end = []
block_weights = []
for ii in block_str:
stt = int(ii.split(":")[0])
end = int(ii.split(":")[1])
weight = float(ii.split(":")[2])
assert weight >= 0, "the weight of a block should be no less than 0"
block_stt.append(stt)
block_end.append(end)
block_weights.append(weight)
nblocks = len(block_str)
block_probs = np.array(block_weights) / np.sum(block_weights)
sys_probs = np.zeros([nsystems])
for ii in range(nblocks):
nbatch_block = nbatch[block_stt[ii] : block_end[ii]]
tmp_prob = [float(i) for i in nbatch_block] / np.sum(nbatch_block)
sys_probs[block_stt[ii] : block_end[ii]] = tmp_prob * block_probs[ii]
return sys_probs
9 changes: 7 additions & 2 deletions source/tests/test_deepmd_data_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
prob_sys_size_ext,
)

if GLOBAL_NP_FLOAT_PRECISION == np.float32:
Expand Down Expand Up @@ -310,7 +311,9 @@ def test_prob_sys_size_1(self):
batch_size = 1
test_size = 1
ds = DeepmdDataSystem(self.sys_name, batch_size, test_size, 2.0)
prob = ds._prob_sys_size_ext("prob_sys_size; 0:2:2; 2:4:8")
prob = prob_sys_size_ext(
"prob_sys_size; 0:2:2; 2:4:8", ds.get_nsystems(), ds.get_nbatches()
)
self.assertAlmostEqual(np.sum(prob), 1)
self.assertAlmostEqual(np.sum(prob[0:2]), 0.2)
self.assertAlmostEqual(np.sum(prob[2:4]), 0.8)
Expand All @@ -332,7 +335,9 @@ def test_prob_sys_size_2(self):
batch_size = 1
test_size = 1
ds = DeepmdDataSystem(self.sys_name, batch_size, test_size, 2.0)
prob = ds._prob_sys_size_ext("prob_sys_size; 1:2:0.4; 2:4:1.6")
prob = prob_sys_size_ext(
"prob_sys_size; 1:2:0.4; 2:4:1.6", ds.get_nsystems(), ds.get_nbatches()
)
self.assertAlmostEqual(np.sum(prob), 1)
self.assertAlmostEqual(np.sum(prob[1:2]), 0.2)
self.assertAlmostEqual(np.sum(prob[2:4]), 0.8)
Expand Down

0 comments on commit df04bd7

Please sign in to comment.