diff --git a/config/env/crystals/spacegroup.yaml b/config/env/crystals/spacegroup.yaml index 0766ddee7..76da40e9f 100644 --- a/config/env/crystals/spacegroup.yaml +++ b/config/env/crystals/spacegroup.yaml @@ -4,6 +4,9 @@ defaults: _target_: gflownet.envs.crystals.spacegroup.SpaceGroup id: spacegroup + +# Subset of space groups +space_groups_subset: null # Stoichiometry n_atoms: null diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 4dbb1622c..38ea7fb4e 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -2,8 +2,10 @@ Classes to represent crystal environments """ import itertools +from copy import deepcopy +from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -43,6 +45,19 @@ def _get_space_groups(): return SPACE_GROUPS +class Prop(Enum): + """ + Enumeration of the 3 properties of the SpaceGroup Environment: + - Crystal lattice system + - Point symmetry + - Space group + """ + + CLS = 0 + PS = 1 + SG = 2 + + class SpaceGroup(GFlowNetEnv): """ SpaceGroup environment for ionic conductivity. @@ -73,10 +88,20 @@ class SpaceGroup(GFlowNetEnv): the order of selection of properties. """ - def __init__(self, n_atoms: Optional[List[int]] = None, **kwargs): + def __init__( + self, + space_groups_subset: Optional[Iterable] = None, + n_atoms: Optional[List[int]] = None, + **kwargs, + ): """ Args ---- + space_groups_subset : iterable + A subset of space group (international) numbers to which to restrict the + state space. If None (default), the entire set of 230 space groups is + considered. + n_atoms : list of int (optional) A list with the number of atoms per element, used to compute constraints on the space group. 0's are removed from the list. If None, composition/space @@ -86,14 +111,18 @@ def __init__(self, n_atoms: Optional[List[int]] = None, **kwargs): self.crystal_lattice_systems = _get_crystal_lattice_systems() self.point_symmetries = _get_point_symmetries() self.space_groups = _get_space_groups() - self.n_crystal_lattice_systems = len(self.crystal_lattice_systems) - self.n_point_symmetries = len(self.point_symmetries) - self.n_space_groups = len(self.space_groups) + self._restrict_space_groups(space_groups_subset) # Set dictionary of compatibility with number of atoms self.set_n_atoms_compatibility_dict(n_atoms) # Indices in the state representation: crystal-lattice system (cls), point # symmetry (ps) and space group (sg) self.cls_idx, self.ps_idx, self.sg_idx = 0, 1, 2 + # Dictionary of all properties + self.properties = { + Prop.CLS: self.crystal_lattice_systems, + Prop.PS: self.point_symmetries, + Prop.SG: self.space_groups, + } # Indices of state types (see self.get_state_type) self.state_type_indices = [0, 1, 2, 3] # End-of-sequence action @@ -117,20 +146,13 @@ def get_action_space(self): state (see self.state_type_indices). """ actions = [] - for prop, n_idx in zip( - [self.cls_idx, self.ps_idx, self.sg_idx], - [ - self.n_crystal_lattice_systems, - self.n_point_symmetries, - self.n_space_groups, - ], - ): + for prop, indices in self.properties.items(): for s_from_type in self.state_type_indices: - if prop == self.cls_idx and s_from_type in [1, 3]: + if prop == Prop.CLS and s_from_type in [1, 3]: continue - if prop == self.ps_idx and s_from_type in [2, 3]: + if prop == Prop.PS and s_from_type in [2, 3]: continue - actions_prop = [(prop, idx + 1, s_from_type) for idx in range(n_idx)] + actions_prop = [(prop.value, idx, s_from_type) for idx in indices] actions += actions_prop actions += [self.eos] return actions @@ -162,14 +184,14 @@ def get_mask_invalid_actions_forward( # composition-compatibility constraints if cls_idx == 0 and ps_idx == 0: crystal_lattice_systems = [ - (self.cls_idx, idx + 1, state_type) - for idx in range(self.n_crystal_lattice_systems) - if self._is_compatible(cls_idx=idx + 1) + (self.cls_idx, idx, state_type) + for idx in self.crystal_lattice_systems + if self._is_compatible(cls_idx=idx) ] point_symmetries = [ - (self.ps_idx, idx + 1, state_type) - for idx in range(self.n_point_symmetries) - if self._is_compatible(ps_idx=idx + 1) + (self.ps_idx, idx, state_type) + for idx in self.point_symmetries + if self._is_compatible(ps_idx=idx) ] # Constraints after having selected crystal-lattice system if cls_idx != 0: @@ -188,9 +210,9 @@ def get_mask_invalid_actions_forward( ] else: space_groups_cls = [ - (self.sg_idx, idx + 1, state_type) - for idx in range(self.n_space_groups) - if self.n_atoms_compatibility_dict[idx + 1] + (self.sg_idx, idx, state_type) + for idx in self.space_groups + if self.n_atoms_compatibility_dict[idx] ] # Constraints after having selected point symmetry if ps_idx != 0: @@ -209,9 +231,9 @@ def get_mask_invalid_actions_forward( ] else: space_groups_ps = [ - (self.sg_idx, idx + 1, state_type) - for idx in range(self.n_space_groups) - if self.n_atoms_compatibility_dict[idx + 1] + (self.sg_idx, idx, state_type) + for idx in self.space_groups + if self.n_atoms_compatibility_dict[idx] ] # Merge space_groups constraints and determine valid space group actions space_groups = list(set(space_groups_cls).intersection(set(space_groups_ps))) @@ -658,11 +680,79 @@ def build_n_atoms_compatibility_dict(n_atoms: List[int], space_groups: List[int] assert all([sg > 0 and sg <= 230 for sg in space_groups]) return {sg: space_group_check_compatible(sg, n_atoms) for sg in space_groups} + def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None): + """ + Updates the dictionaries: + - self.space_groups + - self.crystal_lattice_systems + - self.point_symmetries + by eliminating the space groups that are not in the subset sg_subset passed as + an argument. + """ + if sg_subset is None: + return + sg_subset = set(sg_subset) + + # Update self.space_groups + self.space_groups = { + k: v for (k, v) in self.space_groups.items() if k in sg_subset + } + + # Update self.crystal_lattice_systems based on space groups + self.crystal_lattice_systems = deepcopy(self.crystal_lattice_systems) + cls_to_remove = [] + for cls in self.crystal_lattice_systems: + cls_space_groups = sg_subset.intersection( + set(self.crystal_lattice_systems[cls]["space_groups"]) + ) + if len(cls_space_groups) == 0: + cls_to_remove.append(cls) + else: + self.crystal_lattice_systems[cls]["space_groups"] = list( + cls_space_groups + ) + for cls in cls_to_remove: + del self.crystal_lattice_systems[cls] + + # Update self.point_symmetries based on space groups + self.point_symmetries = deepcopy(self.point_symmetries) + ps_to_remove = [] + for ps in self.point_symmetries: + ps_space_groups = sg_subset.intersection( + set(self.point_symmetries[ps]["space_groups"]) + ) + if len(ps_space_groups) == 0: + ps_to_remove.append(ps) + else: + self.point_symmetries[ps]["space_groups"] = list(ps_space_groups) + for ps in ps_to_remove: + del self.point_symmetries[ps] + + # Update point symmetries of remaining crystal lattice systems + point_symmetries = set(self.point_symmetries) + for cls in self.crystal_lattice_systems: + cls_point_symmetries = point_symmetries.intersection( + set(self.crystal_lattice_systems[cls]["point_symmetries"]) + ) + self.crystal_lattice_systems[cls]["point_symmetries"] = list( + cls_point_symmetries + ) + + # Update crystal lattice systems of remaining point symmetries + crystal_lattice_systems = set(self.crystal_lattice_systems) + for ps in self.point_symmetries: + ps_crystal_lattice_systems = crystal_lattice_systems.intersection( + set(self.point_symmetries[ps]["crystal_lattice_systems"]) + ) + self.point_symmetries[ps]["crystal_lattice_systems"] = list( + ps_crystal_lattice_systems + ) + def get_all_terminating_states( self, apply_stoichiometry_constraints: Optional[bool] = True ) -> List[List]: all_x = [] - for sg in range(1, self.n_space_groups + 1): + for sg in self.space_groups: if ( apply_stoichiometry_constraints and self.n_atoms_compatibility_dict[sg] is False diff --git a/tests/gflownet/envs/test_spacegroup.py b/tests/gflownet/envs/test_spacegroup.py index a7de2e237..049d081c5 100644 --- a/tests/gflownet/envs/test_spacegroup.py +++ b/tests/gflownet/envs/test_spacegroup.py @@ -8,6 +8,7 @@ from gflownet.envs.crystals.spacegroup import SpaceGroup N_ATOMS = [3, 7, 9] +SG_SUBSET = [1, 17, 39, 123, 230] @pytest.fixture @@ -20,12 +21,69 @@ def env_with_composition(): return SpaceGroup(n_atoms=N_ATOMS) +@pytest.fixture +def env_with_restricted_spacegroups(): + return SpaceGroup(space_groups_subset=SG_SUBSET) + + def test__environment__initializes_properly(): env = SpaceGroup() assert env.source == [0] * 3 assert env.state == [0] * 3 +def test__environment__space_groups_subset__initializes_properly(): + def count_distinct(my_dict, sub_key): + all_elements = [] + for sub_dict in my_dict.values(): + all_elements.extend(sub_dict[sub_key]) + + distinct_elements = set(all_elements) + return len(distinct_elements) + + env = SpaceGroup(space_groups_subset=[1, 2]) + nb_spacegroups = 2 + nb_cls = 1 + nb_ps = 2 + assert env.source == [0] * 3 + assert env.state == [0] * 3 + assert len(env.space_groups) == nb_spacegroups + assert len(env.crystal_lattice_systems) == nb_cls + assert len(env.point_symmetries) == nb_ps + assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups + assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps + assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups + assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls + + env = SpaceGroup(space_groups_subset=range(1, 15 + 1)) + nb_spacegroups = 15 + nb_cls = 2 + nb_ps = 3 + assert env.source == [0] * 3 + assert env.state == [0] * 3 + assert len(env.space_groups) == nb_spacegroups + assert len(env.crystal_lattice_systems) == nb_cls + assert len(env.point_symmetries) == nb_ps + assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups + assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps + assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups + assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls + + env = SpaceGroup(space_groups_subset=SG_SUBSET) + nb_spacegroups = len(SG_SUBSET) + nb_cls = 4 + nb_ps = 4 + assert env.source == [0] * 3 + assert env.state == [0] * 3 + assert len(env.space_groups) == nb_spacegroups + assert len(env.crystal_lattice_systems) == nb_cls + assert len(env.point_symmetries) == nb_ps + assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups + assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps + assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups + assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls + + def test__environment__action_space_has_eos(): env = SpaceGroup() assert env.eos in env.action_space @@ -88,6 +146,42 @@ def test__action_space__contains_expected(env, action, expected): assert (action in env.action_space) == expected +@pytest.mark.parametrize( + "action, expected", + [ + ( + (2, 1, 0), + True, + ), + ( + (2, 17, 0), + True, + ), + ( + (2, 39, 0), + True, + ), + ( + (2, 123, 0), + True, + ), + ( + (2, 230, 0), + True, + ), + ( + (2, 2, 0), + False, + ), + ], +) +def test__action_space__env_with_restricted_spacegroups__contains_expected( + env_with_restricted_spacegroups, action, expected +): + env = env_with_restricted_spacegroups + assert (action in env.action_space) == expected + + @pytest.mark.parametrize( "state, action, expected", [ @@ -231,7 +325,7 @@ def test__get_mask_invalid_actions_forward__incompatible_sg_are_invalid( env_with_composition.set_state(state=state, done=False) mask_f = env_with_composition.get_mask_invalid_actions_forward() state_type = env_with_composition.get_state_type(state) - for sg in range(1, env_with_composition.n_space_groups + 1): + for sg in env_with_composition.space_groups: sg_pyxtal = Group(sg) is_compatible = sg_pyxtal.check_compatible(N_ATOMS)[0] action = (env_with_composition.sg_idx, sg, state_type) @@ -240,10 +334,10 @@ def test__get_mask_invalid_actions_forward__incompatible_sg_are_invalid( def test__states_are_compatible_with_pymatgen(env): - for idx in range(env.n_space_groups): + for idx in env.space_groups: env = env.reset() - env.step((2, idx + 1, 0)) - sg_int = pmgg.sg_symbol_from_int_number(idx + 1) + env.step((2, idx, 0)) + sg_int = pmgg.sg_symbol_from_int_number(idx) sg = pmgg.SpaceGroup(sg_int) assert sg.int_number == env.state[env.sg_idx] assert sg.crystal_system == env.crystal_system @@ -280,5 +374,18 @@ def test__special_cases_composition_compatibility(n_atoms, cls_idx, ps_idx): assert valid is False -def test__all_env_common(env): +def test__all_common__env(env): + print("\n\nCommon tests for SpaceGroup without composition restrictions\n") return common.test__all_env_common(env) + + +def test__all_common__env_with_composition(env_with_composition): + print( + f"\n\nCommon tests for SpaceGroup with restrictions from composition {N_ATOMS}\n" + ) + return common.test__all_env_common(env_with_composition) + + +def test__all_common__env_with_restricted_spacegroups(env_with_restricted_spacegroups): + print(f"\n\nCommon tests for SpaceGroup with restricted space groups {SG_SUBSET}") + return common.test__all_env_common(env_with_restricted_spacegroups)