From e2e60dcba083549bae5cf70975e03c8171b27409 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sat, 23 Sep 2023 19:04:42 -0400 Subject: [PATCH 1/9] Space group now accepts an iterable of valid space groups to restrict the set of space groups; plus other changes and new tests motivated by this --- config/env/crystals/spacegroup.yaml | 3 + gflownet/envs/crystals/spacegroup.py | 156 ++++++++++++++++++++----- tests/gflownet/envs/test_spacegroup.py | 80 ++++++++++++- 3 files changed, 205 insertions(+), 34 deletions(-) diff --git a/config/env/crystals/spacegroup.yaml b/config/env/crystals/spacegroup.yaml index 0766ddee7..76da40e9f 100644 --- a/config/env/crystals/spacegroup.yaml +++ b/config/env/crystals/spacegroup.yaml @@ -4,6 +4,9 @@ defaults: _target_: gflownet.envs.crystals.spacegroup.SpaceGroup id: spacegroup + +# Subset of space groups +space_groups_subset: null # Stoichiometry n_atoms: null diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 5e7b25810..8de313991 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -2,8 +2,10 @@ Classes to represent crystal environments """ import itertools +from copy import deepcopy +from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -43,6 +45,19 @@ def _get_space_groups(): return SPACE_GROUPS +class Prop(Enum): + """ + Enumeration of the 3 properties of the SpaceGroup Environment: + - Crystal lattice system + - Point symmetry + - Space group + """ + + CLS = 0 + PS = 1 + SG = 2 + + class SpaceGroup(GFlowNetEnv): """ SpaceGroup environment for ionic conductivity. @@ -73,7 +88,12 @@ class SpaceGroup(GFlowNetEnv): the order of selection of properties. """ - def __init__(self, n_atoms: Optional[List[int]] = None, **kwargs): + def __init__( + self, + space_groups_subset: Optional[Iterable] = None, + n_atoms: Optional[List[int]] = None, + **kwargs, + ): """ Args ---- @@ -86,14 +106,18 @@ def __init__(self, n_atoms: Optional[List[int]] = None, **kwargs): self.crystal_lattice_systems = _get_crystal_lattice_systems() self.point_symmetries = _get_point_symmetries() self.space_groups = _get_space_groups() - self.n_crystal_lattice_systems = len(self.crystal_lattice_systems) - self.n_point_symmetries = len(self.point_symmetries) - self.n_space_groups = len(self.space_groups) + self._restrict_space_groups(space_groups_subset) # Set dictionary of compatibility with number of atoms self.set_n_atoms_compatibility_dict(n_atoms) # Indices in the state representation: crystal-lattice system (cls), point # symmetry (ps) and space group (sg) self.cls_idx, self.ps_idx, self.sg_idx = 0, 1, 2 + # Dictionary of all properties + self.properties = { + Prop.CLS: self.crystal_lattice_systems, + Prop.PS: self.point_symmetries, + Prop.SG: self.space_groups, + } # Indices of state types (see self.get_state_type) self.state_type_indices = [0, 1, 2, 3] # End-of-sequence action @@ -117,20 +141,13 @@ def get_action_space(self): state (see self.state_type_indices). """ actions = [] - for prop, n_idx in zip( - [self.cls_idx, self.ps_idx, self.sg_idx], - [ - self.n_crystal_lattice_systems, - self.n_point_symmetries, - self.n_space_groups, - ], - ): + for prop, indices in self.properties.items(): for s_from_type in self.state_type_indices: - if prop == self.cls_idx and s_from_type in [1, 3]: + if prop == Prop.CLS and s_from_type in [1, 3]: continue - if prop == self.ps_idx and s_from_type in [2, 3]: + if prop == Prop.PS and s_from_type in [2, 3]: continue - actions_prop = [(prop, idx + 1, s_from_type) for idx in range(n_idx)] + actions_prop = [(prop.value, idx, s_from_type) for idx in indices] actions += actions_prop actions += [self.eos] return actions @@ -162,14 +179,14 @@ def get_mask_invalid_actions_forward( # composition-compatibility constraints if cls_idx == 0 and ps_idx == 0: crystal_lattice_systems = [ - (self.cls_idx, idx + 1, state_type) - for idx in range(self.n_crystal_lattice_systems) - if self._is_compatible(cls_idx=idx + 1) + (self.cls_idx, idx, state_type) + for idx in self.crystal_lattice_systems + if self._is_compatible(cls_idx=idx) ] point_symmetries = [ - (self.ps_idx, idx + 1, state_type) - for idx in range(self.n_point_symmetries) - if self._is_compatible(ps_idx=idx + 1) + (self.ps_idx, idx, state_type) + for idx in self.point_symmetries + if self._is_compatible(ps_idx=idx) ] # Constraints after having selected crystal-lattice system if cls_idx != 0: @@ -188,9 +205,9 @@ def get_mask_invalid_actions_forward( ] else: space_groups_cls = [ - (self.sg_idx, idx + 1, state_type) - for idx in range(self.n_space_groups) - if self.n_atoms_compatibility_dict[idx + 1] + (self.sg_idx, idx, state_type) + for idx in self.space_groups + if self.n_atoms_compatibility_dict[idx] ] # Constraints after having selected point symmetry if ps_idx != 0: @@ -209,9 +226,9 @@ def get_mask_invalid_actions_forward( ] else: space_groups_ps = [ - (self.sg_idx, idx + 1, state_type) - for idx in range(self.n_space_groups) - if self.n_atoms_compatibility_dict[idx + 1] + (self.sg_idx, idx, state_type) + for idx in self.space_groups + if self.n_atoms_compatibility_dict[idx] ] # Merge space_groups constraints and determine valid space group actions space_groups = list(set(space_groups_cls).intersection(set(space_groups_ps))) @@ -656,11 +673,92 @@ def build_n_atoms_compatibility_dict(n_atoms: List[int], space_groups: List[int] assert all([sg > 0 and sg <= 230 for sg in space_groups]) return {sg: Group(sg).check_compatible(n_atoms)[0] for sg in space_groups} + def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None): + """ + Updates the dictionaries: + - self.space_groups + - self.crystal_lattice_systems + - self.point_symmetries + by eliminating the space groups that are not in the subset sg_subset passed as + an argument. + """ + if sg_subset is None: + return + sg_subset = set(sg_subset) + + # Update self.space_groups + self.space_groups = deepcopy(self.space_groups) + sg_to_remove = [sg for sg in self.space_groups if sg not in sg_subset] + for sg in sg_to_remove: + del self.space_groups[sg] + + # Update self.crystal_lattice_systems based on space groups + self.crystal_lattice_systems = deepcopy(self.crystal_lattice_systems) + cls_to_remove = [] + for cls in self.crystal_lattice_systems: + cls_space_groups = sg_subset.intersection( + set(self.crystal_lattice_systems[cls]["space_groups"]) + ) + if len(cls_space_groups) == 0: + cls_to_remove.append(cls) + else: + self.crystal_lattice_systems[cls]["space_groups"] = list( + cls_space_groups + ) + for cls in cls_to_remove: + del self.crystal_lattice_systems[cls] + + # Update self.point_symmetries based on space groups + self.point_symmetries = deepcopy(self.point_symmetries) + ps_to_remove = [] + for ps in self.point_symmetries: + ps_space_groups = sg_subset.intersection( + set(self.point_symmetries[ps]["space_groups"]) + ) + if len(ps_space_groups) == 0: + ps_to_remove.append(ps) + else: + self.point_symmetries[ps]["space_groups"] = list(ps_space_groups) + for ps in ps_to_remove: + del self.point_symmetries[ps] + + # Update self.crystal_lattice_systems based on point symmetries + cls_to_remove = [] + point_symmetries = set(self.point_symmetries) + for cls in self.crystal_lattice_systems: + cls_point_symmetries = point_symmetries.intersection( + set(self.crystal_lattice_systems[cls]["point_symmetries"]) + ) + if len(cls_point_symmetries) == 0: + cls_to_remove.append(cls) + else: + self.crystal_lattice_systems[cls]["point_symmetries"] = list( + cls_point_symmetries + ) + for cls in cls_to_remove: + del self.crystal_lattice_systems[cls] + + # Update self.point_symmetries based on point symmetries + ps_to_remove = [] + crystal_lattice_systems = set(self.crystal_lattice_systems) + for ps in self.point_symmetries: + ps_crystal_lattice_systems = crystal_lattice_systems.intersection( + set(self.point_symmetries[ps]["crystal_lattice_systems"]) + ) + if len(ps_crystal_lattice_systems) == 0: + ps_to_remove.append(ps) + else: + self.point_symmetries[ps]["crystal_lattice_systems"] = list( + ps_crystal_lattice_systems + ) + for ps in ps_to_remove: + del self.point_symmetries[ps] + def get_all_terminating_states( self, apply_stoichiometry_constraints: Optional[bool] = True ) -> List[List]: all_x = [] - for sg in range(1, self.n_space_groups + 1): + for sg in self.space_groups: if ( apply_stoichiometry_constraints and self.n_atoms_compatibility_dict[sg] is False diff --git a/tests/gflownet/envs/test_spacegroup.py b/tests/gflownet/envs/test_spacegroup.py index a7de2e237..50e82d61c 100644 --- a/tests/gflownet/envs/test_spacegroup.py +++ b/tests/gflownet/envs/test_spacegroup.py @@ -8,6 +8,7 @@ from gflownet.envs.crystals.spacegroup import SpaceGroup N_ATOMS = [3, 7, 9] +SG_SUBSET = [1, 17, 39, 123, 230] @pytest.fixture @@ -20,12 +21,32 @@ def env_with_composition(): return SpaceGroup(n_atoms=N_ATOMS) +@pytest.fixture +def env_with_restricted_spacegroups(): + return SpaceGroup(space_groups_subset=SG_SUBSET) + + def test__environment__initializes_properly(): env = SpaceGroup() assert env.source == [0] * 3 assert env.state == [0] * 3 +def test__environment__space_groups_subset__initializes_properly(): + env_sg_subset = SpaceGroup(space_groups_subset=[1, 2]) + assert env_sg_subset.source == [0] * 3 + assert env_sg_subset.state == [0] * 3 + assert len(env_sg_subset.space_groups) == 2 + env_sg_subset = SpaceGroup(space_groups_subset=range(1, 15 + 1)) + assert env_sg_subset.source == [0] * 3 + assert env_sg_subset.state == [0] * 3 + assert len(env_sg_subset.space_groups) == 15 + env_sg_subset = SpaceGroup(space_groups_subset=SG_SUBSET) + assert env_sg_subset.source == [0] * 3 + assert env_sg_subset.state == [0] * 3 + assert len(env_sg_subset.space_groups) == len(SG_SUBSET) + + def test__environment__action_space_has_eos(): env = SpaceGroup() assert env.eos in env.action_space @@ -88,6 +109,42 @@ def test__action_space__contains_expected(env, action, expected): assert (action in env.action_space) == expected +@pytest.mark.parametrize( + "action, expected", + [ + ( + (2, 1, 0), + True, + ), + ( + (2, 17, 0), + True, + ), + ( + (2, 39, 0), + True, + ), + ( + (2, 123, 0), + True, + ), + ( + (2, 230, 0), + True, + ), + ( + (2, 2, 0), + False, + ), + ], +) +def test__action_space__env_with_restricted_spacegroups__contains_expected( + env_with_restricted_spacegroups, action, expected +): + env = env_with_restricted_spacegroups + assert (action in env.action_space) == expected + + @pytest.mark.parametrize( "state, action, expected", [ @@ -231,7 +288,7 @@ def test__get_mask_invalid_actions_forward__incompatible_sg_are_invalid( env_with_composition.set_state(state=state, done=False) mask_f = env_with_composition.get_mask_invalid_actions_forward() state_type = env_with_composition.get_state_type(state) - for sg in range(1, env_with_composition.n_space_groups + 1): + for sg in env_with_composition.space_groups: sg_pyxtal = Group(sg) is_compatible = sg_pyxtal.check_compatible(N_ATOMS)[0] action = (env_with_composition.sg_idx, sg, state_type) @@ -240,10 +297,10 @@ def test__get_mask_invalid_actions_forward__incompatible_sg_are_invalid( def test__states_are_compatible_with_pymatgen(env): - for idx in range(env.n_space_groups): + for idx in env.space_groups: env = env.reset() - env.step((2, idx + 1, 0)) - sg_int = pmgg.sg_symbol_from_int_number(idx + 1) + env.step((2, idx, 0)) + sg_int = pmgg.sg_symbol_from_int_number(idx) sg = pmgg.SpaceGroup(sg_int) assert sg.int_number == env.state[env.sg_idx] assert sg.crystal_system == env.crystal_system @@ -280,5 +337,18 @@ def test__special_cases_composition_compatibility(n_atoms, cls_idx, ps_idx): assert valid is False -def test__all_env_common(env): +def test__all_common__env(env): + print("\n\nCommon tests for SpaceGroup without composition restrictions\n") return common.test__all_env_common(env) + + +def test__all_common__env_with_composition(env_with_composition): + print( + f"\n\nCommon tests for SpaceGroup with restrictions from composition {N_ATOMS}\n" + ) + return common.test__all_env_common(env_with_composition) + + +def test__all_common__env_with_restricted_spacegroups(env_with_restricted_spacegroups): + print(f"\n\nCommon tests for SpaceGroup with restricted space groups {SG_SUBSET}") + return common.test__all_env_common(env_with_restricted_spacegroups) From 75aa91e788ab5d58f308d26b85a9c266366418a1 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 18 Oct 2023 09:45:47 -0400 Subject: [PATCH 2/9] Simplify code to update space groups based on subset --- gflownet/envs/crystals/spacegroup.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 8de313991..95e5f781c 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -687,10 +687,7 @@ def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None): sg_subset = set(sg_subset) # Update self.space_groups - self.space_groups = deepcopy(self.space_groups) - sg_to_remove = [sg for sg in self.space_groups if sg not in sg_subset] - for sg in sg_to_remove: - del self.space_groups[sg] + self.space_groups = [sg for sg in self.space_groups if sg in sg_subset] # Update self.crystal_lattice_systems based on space groups self.crystal_lattice_systems = deepcopy(self.crystal_lattice_systems) From 4cd042a6981f6c69ac5eacb26ebaa26164e72037 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 18 Oct 2023 09:55:29 -0400 Subject: [PATCH 3/9] Docstring for attribute space_groups_subset --- gflownet/envs/crystals/spacegroup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 95e5f781c..cd4839cff 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -97,6 +97,11 @@ def __init__( """ Args ---- + space_groups_subset : iterable + A subset of space group (international) numbers to which to restrict the + state space. If None (default), the entire set of 230 space groups is + considered. + n_atoms : list of int (optional) A list with the number of atoms per element, used to compute constraints on the space group. 0's are removed from the list. If None, composition/space From 0a10e03aefc87468583b41dbd4ff48fe0a70f361 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 18 Oct 2023 11:29:42 -0400 Subject: [PATCH 4/9] Fix: keep self.space_groups as a dict when updating it. --- gflownet/envs/crystals/spacegroup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index cd4839cff..94170f35c 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -692,7 +692,9 @@ def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None): sg_subset = set(sg_subset) # Update self.space_groups - self.space_groups = [sg for sg in self.space_groups if sg in sg_subset] + self.space_groups = { + k: v for (k, v) in self.space_groups.items() if k in sg_subset + } # Update self.crystal_lattice_systems based on space groups self.crystal_lattice_systems = deepcopy(self.crystal_lattice_systems) From f886d58a48f9f7db1bd332ab1c94da8253f4a209 Mon Sep 17 00:00:00 2001 From: Pierre Luc Carrier Date: Wed, 18 Oct 2023 14:13:55 -0400 Subject: [PATCH 5/9] Add tests for space group restriction --- tests/gflownet/envs/test_spacegroup.py | 61 +++++++++++++++++++++----- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/tests/gflownet/envs/test_spacegroup.py b/tests/gflownet/envs/test_spacegroup.py index 50e82d61c..049d081c5 100644 --- a/tests/gflownet/envs/test_spacegroup.py +++ b/tests/gflownet/envs/test_spacegroup.py @@ -33,18 +33,55 @@ def test__environment__initializes_properly(): def test__environment__space_groups_subset__initializes_properly(): - env_sg_subset = SpaceGroup(space_groups_subset=[1, 2]) - assert env_sg_subset.source == [0] * 3 - assert env_sg_subset.state == [0] * 3 - assert len(env_sg_subset.space_groups) == 2 - env_sg_subset = SpaceGroup(space_groups_subset=range(1, 15 + 1)) - assert env_sg_subset.source == [0] * 3 - assert env_sg_subset.state == [0] * 3 - assert len(env_sg_subset.space_groups) == 15 - env_sg_subset = SpaceGroup(space_groups_subset=SG_SUBSET) - assert env_sg_subset.source == [0] * 3 - assert env_sg_subset.state == [0] * 3 - assert len(env_sg_subset.space_groups) == len(SG_SUBSET) + def count_distinct(my_dict, sub_key): + all_elements = [] + for sub_dict in my_dict.values(): + all_elements.extend(sub_dict[sub_key]) + + distinct_elements = set(all_elements) + return len(distinct_elements) + + env = SpaceGroup(space_groups_subset=[1, 2]) + nb_spacegroups = 2 + nb_cls = 1 + nb_ps = 2 + assert env.source == [0] * 3 + assert env.state == [0] * 3 + assert len(env.space_groups) == nb_spacegroups + assert len(env.crystal_lattice_systems) == nb_cls + assert len(env.point_symmetries) == nb_ps + assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups + assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps + assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups + assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls + + env = SpaceGroup(space_groups_subset=range(1, 15 + 1)) + nb_spacegroups = 15 + nb_cls = 2 + nb_ps = 3 + assert env.source == [0] * 3 + assert env.state == [0] * 3 + assert len(env.space_groups) == nb_spacegroups + assert len(env.crystal_lattice_systems) == nb_cls + assert len(env.point_symmetries) == nb_ps + assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups + assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps + assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups + assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls + + env = SpaceGroup(space_groups_subset=SG_SUBSET) + nb_spacegroups = len(SG_SUBSET) + nb_cls = 4 + nb_ps = 4 + assert env.source == [0] * 3 + assert env.state == [0] * 3 + assert len(env.space_groups) == nb_spacegroups + assert len(env.crystal_lattice_systems) == nb_cls + assert len(env.point_symmetries) == nb_ps + assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups + assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps + assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups + assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls def test__environment__action_space_has_eos(): From c195a60a57b2b7ca2b514c1c4b462d555fb2b4d4 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Fri, 20 Oct 2023 16:31:19 -0400 Subject: [PATCH 6/9] Remove useless pieces of code in update of space groups. --- gflownet/envs/crystals/spacegroup.py | 32 ---------------------------- 1 file changed, 32 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 94170f35c..7ddbb0a26 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -726,38 +726,6 @@ def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None): for ps in ps_to_remove: del self.point_symmetries[ps] - # Update self.crystal_lattice_systems based on point symmetries - cls_to_remove = [] - point_symmetries = set(self.point_symmetries) - for cls in self.crystal_lattice_systems: - cls_point_symmetries = point_symmetries.intersection( - set(self.crystal_lattice_systems[cls]["point_symmetries"]) - ) - if len(cls_point_symmetries) == 0: - cls_to_remove.append(cls) - else: - self.crystal_lattice_systems[cls]["point_symmetries"] = list( - cls_point_symmetries - ) - for cls in cls_to_remove: - del self.crystal_lattice_systems[cls] - - # Update self.point_symmetries based on point symmetries - ps_to_remove = [] - crystal_lattice_systems = set(self.crystal_lattice_systems) - for ps in self.point_symmetries: - ps_crystal_lattice_systems = crystal_lattice_systems.intersection( - set(self.point_symmetries[ps]["crystal_lattice_systems"]) - ) - if len(ps_crystal_lattice_systems) == 0: - ps_to_remove.append(ps) - else: - self.point_symmetries[ps]["crystal_lattice_systems"] = list( - ps_crystal_lattice_systems - ) - for ps in ps_to_remove: - del self.point_symmetries[ps] - def get_all_terminating_states( self, apply_stoichiometry_constraints: Optional[bool] = True ) -> List[List]: From dd3635064e0161bb18cc53620326f96d8279db9e Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Fri, 20 Oct 2023 16:42:37 -0400 Subject: [PATCH 7/9] Revert "Remove useless pieces of code in update of space groups." This reverts commit c195a60a57b2b7ca2b514c1c4b462d555fb2b4d4. --- gflownet/envs/crystals/spacegroup.py | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 7ddbb0a26..94170f35c 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -726,6 +726,38 @@ def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None): for ps in ps_to_remove: del self.point_symmetries[ps] + # Update self.crystal_lattice_systems based on point symmetries + cls_to_remove = [] + point_symmetries = set(self.point_symmetries) + for cls in self.crystal_lattice_systems: + cls_point_symmetries = point_symmetries.intersection( + set(self.crystal_lattice_systems[cls]["point_symmetries"]) + ) + if len(cls_point_symmetries) == 0: + cls_to_remove.append(cls) + else: + self.crystal_lattice_systems[cls]["point_symmetries"] = list( + cls_point_symmetries + ) + for cls in cls_to_remove: + del self.crystal_lattice_systems[cls] + + # Update self.point_symmetries based on point symmetries + ps_to_remove = [] + crystal_lattice_systems = set(self.crystal_lattice_systems) + for ps in self.point_symmetries: + ps_crystal_lattice_systems = crystal_lattice_systems.intersection( + set(self.point_symmetries[ps]["crystal_lattice_systems"]) + ) + if len(ps_crystal_lattice_systems) == 0: + ps_to_remove.append(ps) + else: + self.point_symmetries[ps]["crystal_lattice_systems"] = list( + ps_crystal_lattice_systems + ) + for ps in ps_to_remove: + del self.point_symmetries[ps] + def get_all_terminating_states( self, apply_stoichiometry_constraints: Optional[bool] = True ) -> List[List]: From d7b27f678a986611ae3e9ac49955355f0fc839fd Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Fri, 20 Oct 2023 16:59:57 -0400 Subject: [PATCH 8/9] Remove useless pieces of code in update of space groups, but keep the things that should stay... --- gflownet/envs/crystals/spacegroup.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 94170f35c..8c43a5dc6 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -727,36 +727,24 @@ def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None): del self.point_symmetries[ps] # Update self.crystal_lattice_systems based on point symmetries - cls_to_remove = [] point_symmetries = set(self.point_symmetries) for cls in self.crystal_lattice_systems: cls_point_symmetries = point_symmetries.intersection( set(self.crystal_lattice_systems[cls]["point_symmetries"]) ) - if len(cls_point_symmetries) == 0: - cls_to_remove.append(cls) - else: - self.crystal_lattice_systems[cls]["point_symmetries"] = list( - cls_point_symmetries - ) - for cls in cls_to_remove: - del self.crystal_lattice_systems[cls] + self.crystal_lattice_systems[cls]["point_symmetries"] = list( + cls_point_symmetries + ) - # Update self.point_symmetries based on point symmetries - ps_to_remove = [] + # Update self.point_symmetries based on crystal lattice systems crystal_lattice_systems = set(self.crystal_lattice_systems) for ps in self.point_symmetries: ps_crystal_lattice_systems = crystal_lattice_systems.intersection( set(self.point_symmetries[ps]["crystal_lattice_systems"]) ) - if len(ps_crystal_lattice_systems) == 0: - ps_to_remove.append(ps) - else: - self.point_symmetries[ps]["crystal_lattice_systems"] = list( - ps_crystal_lattice_systems - ) - for ps in ps_to_remove: - del self.point_symmetries[ps] + self.point_symmetries[ps]["crystal_lattice_systems"] = list( + ps_crystal_lattice_systems + ) def get_all_terminating_states( self, apply_stoichiometry_constraints: Optional[bool] = True From 8f68bdf0cc2ba871c6471e50efd33331d63c3385 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Fri, 20 Oct 2023 17:21:49 -0400 Subject: [PATCH 9/9] Minor update of commments --- gflownet/envs/crystals/spacegroup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 8c43a5dc6..66d7ed793 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -726,7 +726,7 @@ def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None): for ps in ps_to_remove: del self.point_symmetries[ps] - # Update self.crystal_lattice_systems based on point symmetries + # Update point symmetries of remaining crystal lattice systems point_symmetries = set(self.point_symmetries) for cls in self.crystal_lattice_systems: cls_point_symmetries = point_symmetries.intersection( @@ -736,7 +736,7 @@ def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None): cls_point_symmetries ) - # Update self.point_symmetries based on crystal lattice systems + # Update crystal lattice systems of remaining point symmetries crystal_lattice_systems = set(self.crystal_lattice_systems) for ps in self.point_symmetries: ps_crystal_lattice_systems = crystal_lattice_systems.intersection(