Skip to content

Commit

Permalink
Merge pull request #246 from alexhernandezgarcia/ccrystal_fix_tests
Browse files Browse the repository at this point in the history
Ccrystal fix tests
  • Loading branch information
alexhernandezgarcia authored Oct 27, 2023
2 parents f4464ae + 229bed5 commit 01a1c08
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 15 deletions.
8 changes: 5 additions & 3 deletions gflownet/envs/crystals/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,8 @@ def state2oracle(self, state: List = None) -> Tensor:
if state is None:
state = self.state
return self.statetorch2oracle(
torch.unsqueeze(tfloat(states, device=self.device), 0)
)
torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0)
)[0]

def statetorch2oracle(
self, states: TensorType["batch", "state_dim"]
Expand All @@ -445,12 +445,14 @@ def statetorch2oracle(
----
oracle_states : Tensor
"""
states_float = states.to(self.float)

states_oracle = torch.zeros(
(states.shape[0], N_ELEMENTS_ORACLE + 1),
device=self.device,
dtype=self.float,
)
states_oracle[:, tlong(self.elements, device=self.device)] = states
states_oracle[:, tlong(self.elements, device=self.device)] = states_float
return states_oracle

def statebatch2oracle(
Expand Down
8 changes: 5 additions & 3 deletions tests/gflownet/envs/test_ccrystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
def env():
return CCrystal(
composition_kwargs={"elements": 4},
do_composition_to_sg_constraints=False,
space_group_kwargs={"space_groups_subset": list(range(1, 15 + 1)) + [105]},
)

Expand All @@ -46,7 +47,7 @@ def env():
def env_with_stoichiometry_sg_check():
return CCrystal(
composition_kwargs={"elements": 4},
do_stoichiometry_sg_check=True,
do_composition_to_sg_constraints=True,
space_group_kwargs={"space_groups_subset": SG_SUBSET_ALL_CLS_PS},
)

Expand Down Expand Up @@ -341,7 +342,7 @@ def test__state_of_subenv__returns_expected(
),
(
"env_with_stoichiometry_sg_check",
[2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
[2, 4, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
[True, True, False],
True,
True,
Expand Down Expand Up @@ -381,10 +382,11 @@ def test__set_state__sets_state_subenvs_dones_and_constraints(

# Check composition constraints
if has_composition_constraints:
n_atoms = [n for n in env.subenvs[Stage.COMPOSITION].state if n > 0]
n_atoms_compatibility_dict = env.subenvs[
Stage.SPACE_GROUP
].build_n_atoms_compatibility_dict(
env.subenvs[Stage.COMPOSITION].state,
n_atoms,
env.subenvs[Stage.SPACE_GROUP].space_groups.keys(),
)
assert (
Expand Down
30 changes: 27 additions & 3 deletions tests/gflownet/envs/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,39 @@ def test__environment__initializes_properly(elements):
[
(
[0, 0, 2, 0],
[0, 0, 2, 0],
[
# fmt: off
0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# fmt: on
],
),
(
[3, 0, 0, 0],
[3, 0, 0, 0],
[
# fmt: off
0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# fmt: on
],
),
(
[0, 1, 0, 1],
[0, 1, 0, 1],
[
# fmt: off
0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# fmt: on
],
),
],
)
Expand Down
116 changes: 110 additions & 6 deletions tests/gflownet/envs/test_crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,47 @@ def test__pad_depad_action(env):
[
[
(2, 1, 1, 1, 1, 1, 2, 3, 1, 2, 3, 4, 5, 6),
Tensor([1.0, 1.0, 1.0, 1.0, 3.0, 1.4, 1.8, 2.2, 78.0, 90.0, 102.0]),
Tensor(
[
# fmt: off
# Composition state
0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
3.0,
# Lattice parameter state
1.4, 1.8, 2.2, 78.0, 90.0, 102.0,
# fmt: on
]
),
],
[
(2, 4, 9, 0, 3, 0, 0, 105, 5, 3, 1, 0, 0, 9),
Tensor([4.0, 9.0, 0.0, 3.0, 105.0, 3.0, 2.2, 1.4, 30.0, 30.0, 138.0]),
Tensor(
[
# fmt: off
# Composition state
0.0, 4.0, 9.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
105.0,
# Lattice parameter state
3.0, 2.2, 1.4, 30.0, 30.0, 138.0,
# fmt: on
]
),
],
],
)
Expand All @@ -83,11 +119,47 @@ def test__state2oracle__returns_expected_value(env, state, expected):
[
[
(2, 1, 1, 1, 1, 1, 2, 3, 1, 2, 3, 4, 5, 6),
Tensor([1.0, 1.0, 1.0, 1.0, 3.0, 1.4, 1.8, 2.2, 78.0, 90.0, 102.0]),
Tensor(
[
# fmt: off
# Composition state
0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
3.0,
# Lattice parameter state
1.4, 1.8, 2.2, 78.0, 90.0, 102.0,
# fmt: on
]
),
],
[
(2, 4, 9, 0, 3, 0, 0, 105, 5, 3, 1, 0, 0, 9),
Tensor([4.0, 9.0, 0.0, 3.0, 105.0, 3.0, 2.2, 1.4, 30.0, 30.0, 138.0]),
Tensor(
[
# fmt: off
# Composition state
0.0, 4.0, 9.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
105.0,
# Lattice parameter state
3.0, 2.2, 1.4, 30.0, 30.0, 138.0,
# fmt: on
]
),
],
],
)
Expand All @@ -105,8 +177,40 @@ def test__state2proxy__returns_expected_value(env, state, expected):
],
Tensor(
[
[1.0, 1.0, 1.0, 1.0, 3.0, 1.4, 1.8, 2.2, 78.0, 90.0, 102.0],
[4.0, 9.0, 0.0, 3.0, 105.0, 3.0, 2.2, 1.4, 30.0, 30.0, 138.0],
[
# fmt: off
# Composition state
0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
3.0,
# Lattice parameter state
1.4, 1.8, 2.2, 78.0, 90.0, 102.0,
# fmt: on
],
[
# fmt: off
# Composition state
0.0, 4.0, 9.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
105.0,
# Lattice parameter state
3.0, 2.2, 1.4, 30.0, 30.0, 138.0,
# fmt: on
],
]
),
],
Expand Down

0 comments on commit 01a1c08

Please sign in to comment.