Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify state conversions (Crystal environments) #269

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 29 additions & 71 deletions gflownet/envs/crystals/ccrystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,100 +857,58 @@ def get_logprobs(
)
return logprobs

def state2policy(self, state: Optional[List[int]] = None) -> Tensor:
"""
Prepares one state in "GFlowNet format" for the policy. Simply
a concatenation of all crystal components.
"""
state = self._get_state(state)
return self.statetorch2policy(
torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0)
)[0]

def statebatch2policy(
self, states: List[List]
def states2policy(
self, states: Union[List[List], TensorType["batch", "state_dim"]]
) -> TensorType["batch", "state_policy_dim"]:
"""
Prepares a batch of states in "GFlowNet format" for the policy. Simply
Prepares a batch of states in "environment format" for the policy model: simply
a concatenation of all crystal components.
"""
return self.statetorch2policy(
tfloat(states, device=self.device, float_type=self.float)
)

def statetorch2policy(
self, states: TensorType["batch", "state_dim"]
) -> TensorType["batch", "state_policy_dim"]:
"""
Prepares a tensor batch of states in "GFlowNet format" for the policy. Simply
a concatenation of all crystal components.
Args
----
states : list or tensor
A batch of states in environment format, either as a list of states or as a
single tensor.

Returns
-------
A tensor containing all the states in the batch.
"""
states = tfloat(states, device=self.device, float_type=self.float)
return torch.cat(
[
subenv.statetorch2policy(self._get_states_of_subenv(states, stage))
subenv.states2policy(self._get_states_of_subenv(states, stage))
for stage, subenv in self.subenvs.items()
],
dim=1,
)

def state2oracle(self, state: Optional[List[int]] = None) -> Tensor:
"""
Prepares one state in "GFlowNet format" for the oracle. Simply
a concatenation of all crystal components.
"""
state = self._get_state(state)
return self.statetorch2oracle(
torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0)
)

def statebatch2oracle(
self, states: List[List]
def states2proxy(
self, states: Union[List[List], TensorType["batch", "state_dim"]]
) -> TensorType["batch", "state_oracle_dim"]:
"""
Prepares a batch of states in "GFlowNet format" for the oracle. Simply
a concatenation of all crystal components.
"""
return self.statetorch2oracle(
tfloat(states, device=self.device, float_type=self.float)
)
Prepares a batch of states in "environment format" for a proxy: simply a
concatenation of all crystal components.

def statetorch2oracle(
self, states: TensorType["batch", "state_dim"]
) -> TensorType["batch", "state_oracle_dim"]:
"""
Prepares a tensor batch of states in "GFlowNet format" for the oracle. Simply
a concatenation of all crystal components.
Args
----
states : list or tensor
A batch of states in environment format, either as a list of states or as a
single tensor.

Returns
-------
A tensor containing all the states in the batch.
"""
states = tfloat(states, device=self.device, float_type=self.float)
return torch.cat(
[
subenv.statetorch2oracle(self._get_states_of_subenv(states, stage))
subenv.states2proxy(self._get_states_of_subenv(states, stage))
for stage, subenv in self.subenvs.items()
],
dim=1,
)

def state2proxy(self, state: Optional[List[int]] = None) -> Tensor:
"""
Returns state2oracle(state).
"""
return self.state2oracle(state)

def statebatch2proxy(
self, states: List[List]
) -> TensorType["batch", "state_oracle_dim"]:
"""
Returns statebatch2oracle(states).
"""
return self.statebatch2oracle(states)

def statetorch2proxy(
self, states: TensorType["batch", "state_dim"]
) -> TensorType["batch", "state_oracle_dim"]:
"""
Returns statetorch2oracle(states).
"""
return self.statetorch2oracle(states)

def set_state(self, state: List, done: Optional[bool] = False):
super().set_state(state, done)

Expand Down
15 changes: 12 additions & 3 deletions gflownet/envs/crystals/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,10 @@ def states2proxy(
self, states: Union[List[List], TensorType["batch", "state_dim"]]
) -> TensorType["batch", "state_proxy_dim"]:
"""
Prepares a batch of states in "environment format" for the proxy: simply
returns the states as they are, with dtype long.
Prepares a batch of states in "environment format" for the proxy: The output is
a tensor of dtype long with N_ELEMENTS_ORACLE + 1 columns, where the positions
of self.elements are filled with the number of atoms of each element in the
state.

Args
----
Expand All @@ -419,7 +421,14 @@ def states2proxy(
-------
A tensor containing all the states in the batch.
"""
return tlong(states, device=self.device)
states = tlong(states, device=self.device)
states_proxy = torch.zeros(
(states.shape[0], N_ELEMENTS_ORACLE + 1),
device=self.device,
dtype=torch.long,
)
states_proxy[:, tlong(self.elements, device=self.device)] = states
return states_proxy

def state2readable(self, state=None):
"""
Expand Down
18 changes: 9 additions & 9 deletions tests/gflownet/envs/test_ccrystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,22 @@ def test__pad_depad_action(env):
],
],
)
def test__statetorch2policy__is_concatenation_of_subenv_states(env, states):
def test__states2policy__is_concatenation_of_subenv_states(env, states):
# Get policy states from the batch of states converted into each subenv
states_dict = {stage: [] for stage in env.subenvs}
for state in states:
for stage in env.subenvs:
states_dict[stage].append(env._get_state_of_subenv(state, stage))
states_policy_dict = {
stage: subenv.statebatch2policy(states_dict[stage])
stage: subenv.states2policy(states_dict[stage])
for stage, subenv in env.subenvs.items()
}
states_policy_expected = torch.cat(
[el for el in states_policy_dict.values()], dim=1
)
# Get policy states from env.statetorch2policy
# Get policy states from env.states2policy
states_torch = tfloat(states, float_type=env.float, device=env.device)
states_policy = env.statetorch2policy(states_torch)
states_policy = env.states2policy(states_torch)
assert torch.all(torch.eq(states_policy, states_policy_expected))


Expand All @@ -191,20 +191,20 @@ def test__statetorch2policy__is_concatenation_of_subenv_states(env, states):
],
],
)
def test__statetorch2proxy__is_concatenation_of_subenv_states(env, states):
def test__states2proxy__is_concatenation_of_subenv_states(env, states):
# Get proxy states from the batch of states converted into each subenv
states_dict = {stage: [] for stage in env.subenvs}
for state in states:
for stage in env.subenvs:
states_dict[stage].append(env._get_state_of_subenv(state, stage))
states_proxy_dict = {
stage: subenv.statebatch2proxy(states_dict[stage])
stage: subenv.states2proxy(states_dict[stage])
for stage, subenv in env.subenvs.items()
}
states_proxy_expected = torch.cat([el for el in states_proxy_dict.values()], dim=1)
# Get proxy states from env.statetorch2proxy
# Get proxy states from env.states2proxy
states_torch = tfloat(states, float_type=env.float, device=env.device)
states_proxy = env.statetorch2proxy(states_torch)
states_proxy = env.states2proxy(states_torch)
assert torch.all(torch.eq(states_proxy, states_proxy_expected))


Expand Down Expand Up @@ -243,7 +243,7 @@ def test__state2readable__is_concatenation_of_subenv_states(env, states):
f"SpaceGroup = {readables[1]}; "
f"LatticeParameters = {readables[2]}"
)
# Get policy states from env.statetorch2policy
# Get readable states from env.state2readable
states_readable = [env.state2readable(state) for state in states]
for readable, readable_expected in zip(states_readable, states_readable_expected):
assert readable == readable_expected
Expand Down