Skip to content

Commit

Permalink
StateType not Enum anymore
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Oct 18, 2023
1 parent f26b700 commit bf43239
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions gflownet/envs/crystals/spacegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ class Prop:
CLS = 0
PS = 1
SG = 2
ALL = (CLS, PS, SG)


class StateType(Enum):
class StateType:
"""
Enumeration of the 4 types of state from which transitions can originate:
Encodes the 4 types of state from which transitions can originate:
0: Source - both crystal-lattice system and point symmetry are unset (== 0)
1: CLS - crystal-lattice system is set (!= 0); point symmetry is unset
2: PS - crystal-lattice system is unset; point symmetry is set
Expand All @@ -71,6 +72,13 @@ class StateType(Enum):
CLS = 1
PS = 2
CLS_PS = 3
ALL = (SOURCE, CLS, PS, CLS_PS)

def get_state_type(state: List[int]) -> int:
"""
Returns the value of the type of the state passed as an argument.
"""
return sum([int(s > 0) * f for s, f in zip(state, (1, 2))])


class SpaceGroup(GFlowNetEnv):
Expand Down Expand Up @@ -153,12 +161,12 @@ def get_action_space(self):
Prop.SG: self.space_groups,
}
for prop, indices in properties.items():
for state_type in StateType:
for state_type in StateType.ALL:
if prop == Prop.CLS and state_type in [StateType.CLS, StateType.CLS_PS]:
continue
if prop == Prop.PS and state_type in [StateType.PS, StateType.CLS_PS]:
continue
actions_prop = [(prop, idx, state_type.value) for idx in indices]
actions_prop = [(prop, idx, state_type) for idx in indices]
actions += actions_prop
actions += [self.eos]
return actions
Expand Down Expand Up @@ -595,16 +603,11 @@ def point_group(self) -> str:

def get_state_type(self, state: List[int] = None) -> int:
"""
Returns the index of the type of the state passed as an argument. The state
type is one of the following (StateType):
0: both crystal-lattice system and point symmetry are unset (== 0)
1: crystal-lattice system is set (!= 0); point symmetry is unset
2: crystal-lattice system is unset; point symmetry is set
3: both crystal-lattice system and point symmetry are set
Returns the value of the type of the state passed as an argument.
"""
if state is None:
state = self.state
return sum([int(s > 0) * f for s, f in zip(state, (1, 2))])
return StateType.get_state_type(state)

def set_n_atoms_compatibility_dict(self, n_atoms: List):
"""
Expand Down

0 comments on commit bf43239

Please sign in to comment.