From 261b8fb5e851aedba8fa85c05c7d3f5ea95c7a69 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 28 Nov 2023 13:13:02 -0500 Subject: [PATCH 1/3] Simplify state conversions in CCrystal. --- gflownet/envs/crystals/ccrystal.py | 100 ++++++++------------------- tests/gflownet/envs/test_ccrystal.py | 18 ++--- 2 files changed, 38 insertions(+), 80 deletions(-) diff --git a/gflownet/envs/crystals/ccrystal.py b/gflownet/envs/crystals/ccrystal.py index 0336d8a43..c6eadb364 100644 --- a/gflownet/envs/crystals/ccrystal.py +++ b/gflownet/envs/crystals/ccrystal.py @@ -857,100 +857,58 @@ def get_logprobs( ) return logprobs - def state2policy(self, state: Optional[List[int]] = None) -> Tensor: - """ - Prepares one state in "GFlowNet format" for the policy. Simply - a concatenation of all crystal components. - """ - state = self._get_state(state) - return self.statetorch2policy( - torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0) - )[0] - - def statebatch2policy( - self, states: List[List] + def states2policy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_policy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the policy. Simply + Prepares a batch of states in "environment format" for the policy model: simply a concatenation of all crystal components. - """ - return self.statetorch2policy( - tfloat(states, device=self.device, float_type=self.float) - ) - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_policy_dim"]: - """ - Prepares a tensor batch of states in "GFlowNet format" for the policy. Simply - a concatenation of all crystal components. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ + states = tfloat(states, device=self.device, float_type=self.float) return torch.cat( [ - subenv.statetorch2policy(self._get_states_of_subenv(states, stage)) + subenv.states2policy(self._get_states_of_subenv(states, stage)) for stage, subenv in self.subenvs.items() ], dim=1, ) - def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: - """ - Prepares one state in "GFlowNet format" for the oracle. Simply - a concatenation of all crystal components. - """ - state = self._get_state(state) - return self.statetorch2oracle( - torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0) - ) - - def statebatch2oracle( - self, states: List[List] + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_oracle_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the oracle. Simply - a concatenation of all crystal components. - """ - return self.statetorch2oracle( - tfloat(states, device=self.device, float_type=self.float) - ) + Prepares a batch of states in "environment format" for a proxy: simply a + concatenation of all crystal components. - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares one state in "GFlowNet format" for the oracle. Simply - a concatenation of all crystal components. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ + states = tfloat(states, device=self.device, float_type=self.float) return torch.cat( [ - subenv.statetorch2oracle(self._get_states_of_subenv(states, stage)) + subenv.states2oracle(self._get_states_of_subenv(states, stage)) for stage, subenv in self.subenvs.items() ], dim=1, ) - def state2proxy(self, state: Optional[List[int]] = None) -> Tensor: - """ - Returns state2oracle(state). - """ - return self.state2oracle(state) - - def statebatch2proxy( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Returns statebatch2oracle(states). - """ - return self.statebatch2oracle(states) - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Returns statetorch2oracle(states). - """ - return self.statetorch2oracle(states) - def set_state(self, state: List, done: Optional[bool] = False): super().set_state(state, done) diff --git a/tests/gflownet/envs/test_ccrystal.py b/tests/gflownet/envs/test_ccrystal.py index e88a348ae..4d0cdf00b 100644 --- a/tests/gflownet/envs/test_ccrystal.py +++ b/tests/gflownet/envs/test_ccrystal.py @@ -152,22 +152,22 @@ def test__pad_depad_action(env): ], ], ) -def test__statetorch2policy__is_concatenation_of_subenv_states(env, states): +def test__states2policy__is_concatenation_of_subenv_states(env, states): # Get policy states from the batch of states converted into each subenv states_dict = {stage: [] for stage in env.subenvs} for state in states: for stage in env.subenvs: states_dict[stage].append(env._get_state_of_subenv(state, stage)) states_policy_dict = { - stage: subenv.statebatch2policy(states_dict[stage]) + stage: subenv.states2policy(states_dict[stage]) for stage, subenv in env.subenvs.items() } states_policy_expected = torch.cat( [el for el in states_policy_dict.values()], dim=1 ) - # Get policy states from env.statetorch2policy + # Get policy states from env.states2policy states_torch = tfloat(states, float_type=env.float, device=env.device) - states_policy = env.statetorch2policy(states_torch) + states_policy = env.states2policy(states_torch) assert torch.all(torch.eq(states_policy, states_policy_expected)) @@ -191,20 +191,20 @@ def test__statetorch2policy__is_concatenation_of_subenv_states(env, states): ], ], ) -def test__statetorch2proxy__is_concatenation_of_subenv_states(env, states): +def test__states2proxy__is_concatenation_of_subenv_states(env, states): # Get proxy states from the batch of states converted into each subenv states_dict = {stage: [] for stage in env.subenvs} for state in states: for stage in env.subenvs: states_dict[stage].append(env._get_state_of_subenv(state, stage)) states_proxy_dict = { - stage: subenv.statebatch2proxy(states_dict[stage]) + stage: subenv.states2proxy(states_dict[stage]) for stage, subenv in env.subenvs.items() } states_proxy_expected = torch.cat([el for el in states_proxy_dict.values()], dim=1) - # Get proxy states from env.statetorch2proxy + # Get proxy states from env.states2proxy states_torch = tfloat(states, float_type=env.float, device=env.device) - states_proxy = env.statetorch2proxy(states_torch) + states_proxy = env.states2proxy(states_torch) assert torch.all(torch.eq(states_proxy, states_proxy_expected)) @@ -243,7 +243,7 @@ def test__state2readable__is_concatenation_of_subenv_states(env, states): f"SpaceGroup = {readables[1]}; " f"LatticeParameters = {readables[2]}" ) - # Get policy states from env.statetorch2policy + # Get policy states from env.states2policy states_readable = [env.state2readable(state) for state in states] for readable, readable_expected in zip(states_readable, states_readable_expected): assert readable == readable_expected From b07e25825cd9fc5c215e28447d05ed729acaaca3 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 28 Nov 2023 13:28:21 -0500 Subject: [PATCH 2/3] Simplify state conversions in Composition and missing bit in CCrystal. --- gflownet/envs/crystals/ccrystal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/crystals/ccrystal.py b/gflownet/envs/crystals/ccrystal.py index c6eadb364..49cfdda8f 100644 --- a/gflownet/envs/crystals/ccrystal.py +++ b/gflownet/envs/crystals/ccrystal.py @@ -903,7 +903,7 @@ def states2proxy( states = tfloat(states, device=self.device, float_type=self.float) return torch.cat( [ - subenv.states2oracle(self._get_states_of_subenv(states, stage)) + subenv.states2proxy(self._get_states_of_subenv(states, stage)) for stage, subenv in self.subenvs.items() ], dim=1, From c5763ea98c54a1050374a12d8b82a71dae8e90ac Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sat, 9 Dec 2023 15:01:56 -0500 Subject: [PATCH 3/3] Repair states2proxy in Composition --- gflownet/envs/crystals/composition.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 371c002b0..f109faadb 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -406,8 +406,10 @@ def states2proxy( self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "environment format" for the proxy: simply - returns the states as are with dtype long. + Prepares a batch of states in "environment format" for the proxy: The output is + a tensor of dtype long with N_ELEMENTS_ORACLE + 1 columns, where the positions + of self.elements are filled with the number of atoms of each element in the + state. Args ---- @@ -419,7 +421,14 @@ def states2proxy( ------- A tensor containing all the states in the batch. """ - return tlong(states, device=self.device) + states = tlong(states, device=self.device) + states_proxy = torch.zeros( + (states.shape[0], N_ELEMENTS_ORACLE + 1), + device=self.device, + dtype=torch.long, + ) + states_proxy[:, tlong(self.elements, device=self.device)] = states + return states_proxy def state2readable(self, state=None): """