Skip to content

Commit

Permalink
Merge pull request #223 from alexhernandezgarcia/space-group-subset
Browse files Browse the repository at this point in the history
Space group env accepts an iterable of valid space groups
  • Loading branch information
alexhernandezgarcia authored Oct 20, 2023
2 parents a106516 + 8f68bdf commit 96f1eaf
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 34 deletions.
3 changes: 3 additions & 0 deletions config/env/crystals/spacegroup.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ defaults:
_target_: gflownet.envs.crystals.spacegroup.SpaceGroup

id: spacegroup

# Subset of space groups
space_groups_subset: null
# Stoichiometry
n_atoms: null

Expand Down
148 changes: 119 additions & 29 deletions gflownet/envs/crystals/spacegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
Classes to represent crystal environments
"""
import itertools
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -43,6 +45,19 @@ def _get_space_groups():
return SPACE_GROUPS


class Prop(Enum):
"""
Enumeration of the 3 properties of the SpaceGroup Environment:
- Crystal lattice system
- Point symmetry
- Space group
"""

CLS = 0
PS = 1
SG = 2


class SpaceGroup(GFlowNetEnv):
"""
SpaceGroup environment for ionic conductivity.
Expand Down Expand Up @@ -73,10 +88,20 @@ class SpaceGroup(GFlowNetEnv):
the order of selection of properties.
"""

def __init__(self, n_atoms: Optional[List[int]] = None, **kwargs):
def __init__(
self,
space_groups_subset: Optional[Iterable] = None,
n_atoms: Optional[List[int]] = None,
**kwargs,
):
"""
Args
----
space_groups_subset : iterable
A subset of space group (international) numbers to which to restrict the
state space. If None (default), the entire set of 230 space groups is
considered.
n_atoms : list of int (optional)
A list with the number of atoms per element, used to compute constraints on
the space group. 0's are removed from the list. If None, composition/space
Expand All @@ -86,14 +111,18 @@ def __init__(self, n_atoms: Optional[List[int]] = None, **kwargs):
self.crystal_lattice_systems = _get_crystal_lattice_systems()
self.point_symmetries = _get_point_symmetries()
self.space_groups = _get_space_groups()
self.n_crystal_lattice_systems = len(self.crystal_lattice_systems)
self.n_point_symmetries = len(self.point_symmetries)
self.n_space_groups = len(self.space_groups)
self._restrict_space_groups(space_groups_subset)
# Set dictionary of compatibility with number of atoms
self.set_n_atoms_compatibility_dict(n_atoms)
# Indices in the state representation: crystal-lattice system (cls), point
# symmetry (ps) and space group (sg)
self.cls_idx, self.ps_idx, self.sg_idx = 0, 1, 2
# Dictionary of all properties
self.properties = {
Prop.CLS: self.crystal_lattice_systems,
Prop.PS: self.point_symmetries,
Prop.SG: self.space_groups,
}
# Indices of state types (see self.get_state_type)
self.state_type_indices = [0, 1, 2, 3]
# End-of-sequence action
Expand All @@ -117,20 +146,13 @@ def get_action_space(self):
state (see self.state_type_indices).
"""
actions = []
for prop, n_idx in zip(
[self.cls_idx, self.ps_idx, self.sg_idx],
[
self.n_crystal_lattice_systems,
self.n_point_symmetries,
self.n_space_groups,
],
):
for prop, indices in self.properties.items():
for s_from_type in self.state_type_indices:
if prop == self.cls_idx and s_from_type in [1, 3]:
if prop == Prop.CLS and s_from_type in [1, 3]:
continue
if prop == self.ps_idx and s_from_type in [2, 3]:
if prop == Prop.PS and s_from_type in [2, 3]:
continue
actions_prop = [(prop, idx + 1, s_from_type) for idx in range(n_idx)]
actions_prop = [(prop.value, idx, s_from_type) for idx in indices]
actions += actions_prop
actions += [self.eos]
return actions
Expand Down Expand Up @@ -162,14 +184,14 @@ def get_mask_invalid_actions_forward(
# composition-compatibility constraints
if cls_idx == 0 and ps_idx == 0:
crystal_lattice_systems = [
(self.cls_idx, idx + 1, state_type)
for idx in range(self.n_crystal_lattice_systems)
if self._is_compatible(cls_idx=idx + 1)
(self.cls_idx, idx, state_type)
for idx in self.crystal_lattice_systems
if self._is_compatible(cls_idx=idx)
]
point_symmetries = [
(self.ps_idx, idx + 1, state_type)
for idx in range(self.n_point_symmetries)
if self._is_compatible(ps_idx=idx + 1)
(self.ps_idx, idx, state_type)
for idx in self.point_symmetries
if self._is_compatible(ps_idx=idx)
]
# Constraints after having selected crystal-lattice system
if cls_idx != 0:
Expand All @@ -188,9 +210,9 @@ def get_mask_invalid_actions_forward(
]
else:
space_groups_cls = [
(self.sg_idx, idx + 1, state_type)
for idx in range(self.n_space_groups)
if self.n_atoms_compatibility_dict[idx + 1]
(self.sg_idx, idx, state_type)
for idx in self.space_groups
if self.n_atoms_compatibility_dict[idx]
]
# Constraints after having selected point symmetry
if ps_idx != 0:
Expand All @@ -209,9 +231,9 @@ def get_mask_invalid_actions_forward(
]
else:
space_groups_ps = [
(self.sg_idx, idx + 1, state_type)
for idx in range(self.n_space_groups)
if self.n_atoms_compatibility_dict[idx + 1]
(self.sg_idx, idx, state_type)
for idx in self.space_groups
if self.n_atoms_compatibility_dict[idx]
]
# Merge space_groups constraints and determine valid space group actions
space_groups = list(set(space_groups_cls).intersection(set(space_groups_ps)))
Expand Down Expand Up @@ -658,11 +680,79 @@ def build_n_atoms_compatibility_dict(n_atoms: List[int], space_groups: List[int]
assert all([sg > 0 and sg <= 230 for sg in space_groups])
return {sg: space_group_check_compatible(sg, n_atoms) for sg in space_groups}

def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None):
"""
Updates the dictionaries:
- self.space_groups
- self.crystal_lattice_systems
- self.point_symmetries
by eliminating the space groups that are not in the subset sg_subset passed as
an argument.
"""
if sg_subset is None:
return
sg_subset = set(sg_subset)

# Update self.space_groups
self.space_groups = {
k: v for (k, v) in self.space_groups.items() if k in sg_subset
}

# Update self.crystal_lattice_systems based on space groups
self.crystal_lattice_systems = deepcopy(self.crystal_lattice_systems)
cls_to_remove = []
for cls in self.crystal_lattice_systems:
cls_space_groups = sg_subset.intersection(
set(self.crystal_lattice_systems[cls]["space_groups"])
)
if len(cls_space_groups) == 0:
cls_to_remove.append(cls)
else:
self.crystal_lattice_systems[cls]["space_groups"] = list(
cls_space_groups
)
for cls in cls_to_remove:
del self.crystal_lattice_systems[cls]

# Update self.point_symmetries based on space groups
self.point_symmetries = deepcopy(self.point_symmetries)
ps_to_remove = []
for ps in self.point_symmetries:
ps_space_groups = sg_subset.intersection(
set(self.point_symmetries[ps]["space_groups"])
)
if len(ps_space_groups) == 0:
ps_to_remove.append(ps)
else:
self.point_symmetries[ps]["space_groups"] = list(ps_space_groups)
for ps in ps_to_remove:
del self.point_symmetries[ps]

# Update point symmetries of remaining crystal lattice systems
point_symmetries = set(self.point_symmetries)
for cls in self.crystal_lattice_systems:
cls_point_symmetries = point_symmetries.intersection(
set(self.crystal_lattice_systems[cls]["point_symmetries"])
)
self.crystal_lattice_systems[cls]["point_symmetries"] = list(
cls_point_symmetries
)

# Update crystal lattice systems of remaining point symmetries
crystal_lattice_systems = set(self.crystal_lattice_systems)
for ps in self.point_symmetries:
ps_crystal_lattice_systems = crystal_lattice_systems.intersection(
set(self.point_symmetries[ps]["crystal_lattice_systems"])
)
self.point_symmetries[ps]["crystal_lattice_systems"] = list(
ps_crystal_lattice_systems
)

def get_all_terminating_states(
self, apply_stoichiometry_constraints: Optional[bool] = True
) -> List[List]:
all_x = []
for sg in range(1, self.n_space_groups + 1):
for sg in self.space_groups:
if (
apply_stoichiometry_constraints
and self.n_atoms_compatibility_dict[sg] is False
Expand Down
Loading

0 comments on commit 96f1eaf

Please sign in to comment.