diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 94170f35c..8c43a5dc6 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -727,36 +727,24 @@ def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None): del self.point_symmetries[ps] # Update self.crystal_lattice_systems based on point symmetries - cls_to_remove = [] point_symmetries = set(self.point_symmetries) for cls in self.crystal_lattice_systems: cls_point_symmetries = point_symmetries.intersection( set(self.crystal_lattice_systems[cls]["point_symmetries"]) ) - if len(cls_point_symmetries) == 0: - cls_to_remove.append(cls) - else: - self.crystal_lattice_systems[cls]["point_symmetries"] = list( - cls_point_symmetries - ) - for cls in cls_to_remove: - del self.crystal_lattice_systems[cls] + self.crystal_lattice_systems[cls]["point_symmetries"] = list( + cls_point_symmetries + ) - # Update self.point_symmetries based on point symmetries - ps_to_remove = [] + # Update self.point_symmetries based on crystal lattice systems crystal_lattice_systems = set(self.crystal_lattice_systems) for ps in self.point_symmetries: ps_crystal_lattice_systems = crystal_lattice_systems.intersection( set(self.point_symmetries[ps]["crystal_lattice_systems"]) ) - if len(ps_crystal_lattice_systems) == 0: - ps_to_remove.append(ps) - else: - self.point_symmetries[ps]["crystal_lattice_systems"] = list( - ps_crystal_lattice_systems - ) - for ps in ps_to_remove: - del self.point_symmetries[ps] + self.point_symmetries[ps]["crystal_lattice_systems"] = list( + ps_crystal_lattice_systems + ) def get_all_terminating_states( self, apply_stoichiometry_constraints: Optional[bool] = True