Skip to content

Commit

Permalink
Repair states2proxy in Composition
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Dec 9, 2023
1 parent b07e258 commit c5763ea
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions gflownet/envs/crystals/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,10 @@ def states2proxy(
self, states: Union[List[List], TensorType["batch", "state_dim"]]
) -> TensorType["batch", "state_proxy_dim"]:
"""
Prepares a batch of states in "environment format" for the proxy: simply
returns the states as are with dtype long.
Prepares a batch of states in "environment format" for the proxy: The output is
a tensor of dtype long with N_ELEMENTS_ORACLE + 1 columns, where the positions
of self.elements are filled with the number of atoms of each element in the
state.
Args
----
Expand All @@ -419,7 +421,14 @@ def states2proxy(
-------
A tensor containing all the states in the batch.
"""
return tlong(states, device=self.device)
states = tlong(states, device=self.device)
states_proxy = torch.zeros(
(states.shape[0], N_ELEMENTS_ORACLE + 1),
device=self.device,
dtype=torch.long,
)
states_proxy[:, tlong(self.elements, device=self.device)] = states
return states_proxy

def state2readable(self, state=None):
"""
Expand Down

0 comments on commit c5763ea

Please sign in to comment.