Skip to content

Commit

Permalink
Merge pull request #239 from alexhernandezgarcia/speedup_neutral_char…
Browse files Browse the repository at this point in the history
…ge_check

Speedup neutral charge check
  • Loading branch information
carriepl authored Oct 6, 2023
2 parents bf4938e + 1f43fe1 commit f2f229b
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions gflownet/envs/crystals/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,16 +617,28 @@ def _can_produce_neutral_charge(self, state: Optional[List[int]] = None) -> bool
for i, num in enumerate(state)
if num > 0
]
sum_diff_elem = []

for n, c in nums_charges:
charge_sums = []
for c_i in itertools.product(c, repeat=n):
charge_sums.append(sum(c_i))
sum_diff_elem.append(np.unique(charge_sums))

poss_charge_sum = [
sum(combo) == 0 for combo in itertools.product(*sum_diff_elem)
]
# Process all atoms one by one, gradually accumulating a set of all possible
# charge totals so far
poss_charge_sum = set([0])
while len(nums_charges) > 0:
num, charges = nums_charges[0]

# Compute all possible charge totals that can be obtained by combining
# all the previous charge totals with all the possible charges for the
# current atom
new_poss_charge_sum = set()
for old_charge_sum in poss_charge_sum:
for element_charge in charges:
new_poss_charge_sum.add(old_charge_sum + element_charge)
poss_charge_sum = new_poss_charge_sum

# Remove the atom that was processed from nums_charges
if num == 1:
# Remove element from nums_charges
del nums_charges[0]
else:
# Remove one atom from this element
nums_charges[0] = (num - 1, charges)

return any(poss_charge_sum)
return 0 in poss_charge_sum

0 comments on commit f2f229b

Please sign in to comment.