Skip to content

Commit

Permalink
Merge pull request #269 from alexhernandezgarcia/conversions-simplify…
Browse files Browse the repository at this point in the history
…-crystals

Simplify state conversions (Crystal environments)
  • Loading branch information
alexhernandezgarcia authored Dec 22, 2023
2 parents fe105d1 + c5763ea commit 768d059
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 83 deletions.
100 changes: 29 additions & 71 deletions gflownet/envs/crystals/ccrystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,100 +857,58 @@ def get_logprobs(
)
return logprobs

def state2policy(self, state: Optional[List[int]] = None) -> Tensor:
"""
Prepares one state in "GFlowNet format" for the policy. Simply
a concatenation of all crystal components.
"""
state = self._get_state(state)
return self.statetorch2policy(
torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0)
)[0]

def statebatch2policy(
self, states: List[List]
def states2policy(
self, states: Union[List[List], TensorType["batch", "state_dim"]]
) -> TensorType["batch", "state_policy_dim"]:
"""
Prepares a batch of states in "GFlowNet format" for the policy. Simply
Prepares a batch of states in "environment format" for the policy model: simply
a concatenation of all crystal components.
"""
return self.statetorch2policy(
tfloat(states, device=self.device, float_type=self.float)
)
def statetorch2policy(
self, states: TensorType["batch", "state_dim"]
) -> TensorType["batch", "state_policy_dim"]:
"""
Prepares a tensor batch of states in "GFlowNet format" for the policy. Simply
a concatenation of all crystal components.
Args
----
states : list or tensor
A batch of states in environment format, either as a list of states or as a
single tensor.
Returns
-------
A tensor containing all the states in the batch.
"""
states = tfloat(states, device=self.device, float_type=self.float)
return torch.cat(
[
subenv.statetorch2policy(self._get_states_of_subenv(states, stage))
subenv.states2policy(self._get_states_of_subenv(states, stage))
for stage, subenv in self.subenvs.items()
],
dim=1,
)

def state2oracle(self, state: Optional[List[int]] = None) -> Tensor:
"""
Prepares one state in "GFlowNet format" for the oracle. Simply
a concatenation of all crystal components.
"""
state = self._get_state(state)
return self.statetorch2oracle(
torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0)
)

def statebatch2oracle(
self, states: List[List]
def states2proxy(
self, states: Union[List[List], TensorType["batch", "state_dim"]]
) -> TensorType["batch", "state_oracle_dim"]:
"""
Prepares a batch of states in "GFlowNet format" for the oracle. Simply
a concatenation of all crystal components.
"""
return self.statetorch2oracle(
tfloat(states, device=self.device, float_type=self.float)
)
Prepares a batch of states in "environment format" for a proxy: simply a
concatenation of all crystal components.
def statetorch2oracle(
self, states: TensorType["batch", "state_dim"]
) -> TensorType["batch", "state_oracle_dim"]:
"""
Prepares one state in "GFlowNet format" for the oracle. Simply
a concatenation of all crystal components.
Args
----
states : list or tensor
A batch of states in environment format, either as a list of states or as a
single tensor.
Returns
-------
A tensor containing all the states in the batch.
"""
states = tfloat(states, device=self.device, float_type=self.float)
return torch.cat(
[
subenv.statetorch2oracle(self._get_states_of_subenv(states, stage))
subenv.states2proxy(self._get_states_of_subenv(states, stage))
for stage, subenv in self.subenvs.items()
],
dim=1,
)

def state2proxy(self, state: Optional[List[int]] = None) -> Tensor:
"""
Returns state2oracle(state).
"""
return self.state2oracle(state)

def statebatch2proxy(
self, states: List[List]
) -> TensorType["batch", "state_oracle_dim"]:
"""
Returns statebatch2oracle(states).
"""
return self.statebatch2oracle(states)

def statetorch2proxy(
self, states: TensorType["batch", "state_dim"]
) -> TensorType["batch", "state_oracle_dim"]:
"""
Returns statetorch2oracle(states).
"""
return self.statetorch2oracle(states)

def set_state(self, state: List, done: Optional[bool] = False):
super().set_state(state, done)

Expand Down
15 changes: 12 additions & 3 deletions gflownet/envs/crystals/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,10 @@ def states2proxy(
self, states: Union[List[List], TensorType["batch", "state_dim"]]
) -> TensorType["batch", "state_proxy_dim"]:
"""
Prepares a batch of states in "environment format" for the proxy: simply
returns the states as are with dtype long.
Prepares a batch of states in "environment format" for the proxy: The output is
a tensor of dtype long with N_ELEMENTS_ORACLE + 1 columns, where the positions
of self.elements are filled with the number of atoms of each element in the
state.
Args
----
Expand All @@ -419,7 +421,14 @@ def states2proxy(
-------
A tensor containing all the states in the batch.
"""
return tlong(states, device=self.device)
states = tlong(states, device=self.device)
states_proxy = torch.zeros(
(states.shape[0], N_ELEMENTS_ORACLE + 1),
device=self.device,
dtype=torch.long,
)
states_proxy[:, tlong(self.elements, device=self.device)] = states
return states_proxy

def state2readable(self, state=None):
"""
Expand Down
18 changes: 9 additions & 9 deletions tests/gflownet/envs/test_ccrystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,22 @@ def test__pad_depad_action(env):
],
],
)
def test__statetorch2policy__is_concatenation_of_subenv_states(env, states):
def test__states2policy__is_concatenation_of_subenv_states(env, states):
# Get policy states from the batch of states converted into each subenv
states_dict = {stage: [] for stage in env.subenvs}
for state in states:
for stage in env.subenvs:
states_dict[stage].append(env._get_state_of_subenv(state, stage))
states_policy_dict = {
stage: subenv.statebatch2policy(states_dict[stage])
stage: subenv.states2policy(states_dict[stage])
for stage, subenv in env.subenvs.items()
}
states_policy_expected = torch.cat(
[el for el in states_policy_dict.values()], dim=1
)
# Get policy states from env.statetorch2policy
# Get policy states from env.states2policy
states_torch = tfloat(states, float_type=env.float, device=env.device)
states_policy = env.statetorch2policy(states_torch)
states_policy = env.states2policy(states_torch)
assert torch.all(torch.eq(states_policy, states_policy_expected))


Expand All @@ -191,20 +191,20 @@ def test__statetorch2policy__is_concatenation_of_subenv_states(env, states):
],
],
)
def test__statetorch2proxy__is_concatenation_of_subenv_states(env, states):
def test__states2proxy__is_concatenation_of_subenv_states(env, states):
# Get proxy states from the batch of states converted into each subenv
states_dict = {stage: [] for stage in env.subenvs}
for state in states:
for stage in env.subenvs:
states_dict[stage].append(env._get_state_of_subenv(state, stage))
states_proxy_dict = {
stage: subenv.statebatch2proxy(states_dict[stage])
stage: subenv.states2proxy(states_dict[stage])
for stage, subenv in env.subenvs.items()
}
states_proxy_expected = torch.cat([el for el in states_proxy_dict.values()], dim=1)
# Get proxy states from env.statetorch2proxy
# Get proxy states from env.states2proxy
states_torch = tfloat(states, float_type=env.float, device=env.device)
states_proxy = env.statetorch2proxy(states_torch)
states_proxy = env.states2proxy(states_torch)
assert torch.all(torch.eq(states_proxy, states_proxy_expected))


Expand Down Expand Up @@ -243,7 +243,7 @@ def test__state2readable__is_concatenation_of_subenv_states(env, states):
f"SpaceGroup = {readables[1]}; "
f"LatticeParameters = {readables[2]}"
)
# Get policy states from env.statetorch2policy
# Get readable states from env.state2readable
states_readable = [env.state2readable(state) for state in states]
for readable, readable_expected in zip(states_readable, states_readable_expected):
assert readable == readable_expected
Expand Down

0 comments on commit 768d059

Please sign in to comment.