Skip to content

Commit

Permalink
Fix crystal tests
Browse files Browse the repository at this point in the history
  • Loading branch information
carriepl-mila committed Oct 24, 2023
1 parent 80a32cf commit 229bed5
Showing 1 changed file with 110 additions and 6 deletions.
116 changes: 110 additions & 6 deletions tests/gflownet/envs/test_crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,47 @@ def test__pad_depad_action(env):
[
[
(2, 1, 1, 1, 1, 1, 2, 3, 1, 2, 3, 4, 5, 6),
Tensor([1.0, 1.0, 1.0, 1.0, 3.0, 1.4, 1.8, 2.2, 78.0, 90.0, 102.0]),
Tensor(
[
# fmt: off
# Composition state
0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
3.0,
# Lattice parameter state
1.4, 1.8, 2.2, 78.0, 90.0, 102.0,
# fmt: on
]
),
],
[
(2, 4, 9, 0, 3, 0, 0, 105, 5, 3, 1, 0, 0, 9),
Tensor([4.0, 9.0, 0.0, 3.0, 105.0, 3.0, 2.2, 1.4, 30.0, 30.0, 138.0]),
Tensor(
[
# fmt: off
# Composition state
0.0, 4.0, 9.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
105.0,
# Lattice parameter state
3.0, 2.2, 1.4, 30.0, 30.0, 138.0,
# fmt: on
]
),
],
],
)
Expand All @@ -83,11 +119,47 @@ def test__state2oracle__returns_expected_value(env, state, expected):
[
[
(2, 1, 1, 1, 1, 1, 2, 3, 1, 2, 3, 4, 5, 6),
Tensor([1.0, 1.0, 1.0, 1.0, 3.0, 1.4, 1.8, 2.2, 78.0, 90.0, 102.0]),
Tensor(
[
# fmt: off
# Composition state
0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
3.0,
# Lattice parameter state
1.4, 1.8, 2.2, 78.0, 90.0, 102.0,
# fmt: on
]
),
],
[
(2, 4, 9, 0, 3, 0, 0, 105, 5, 3, 1, 0, 0, 9),
Tensor([4.0, 9.0, 0.0, 3.0, 105.0, 3.0, 2.2, 1.4, 30.0, 30.0, 138.0]),
Tensor(
[
# fmt: off
# Composition state
0.0, 4.0, 9.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
105.0,
# Lattice parameter state
3.0, 2.2, 1.4, 30.0, 30.0, 138.0,
# fmt: on
]
),
],
],
)
Expand All @@ -105,8 +177,40 @@ def test__state2proxy__returns_expected_value(env, state, expected):
],
Tensor(
[
[1.0, 1.0, 1.0, 1.0, 3.0, 1.4, 1.8, 2.2, 78.0, 90.0, 102.0],
[4.0, 9.0, 0.0, 3.0, 105.0, 3.0, 2.2, 1.4, 30.0, 30.0, 138.0],
[
# fmt: off
# Composition state
0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
3.0,
# Lattice parameter state
1.4, 1.8, 2.2, 78.0, 90.0, 102.0,
# fmt: on
],
[
# fmt: off
# Composition state
0.0, 4.0, 9.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
# Spacegroup state
105.0,
# Lattice parameter state
3.0, 2.2, 1.4, 30.0, 30.0, 138.0,
# fmt: on
],
]
),
],
Expand Down

0 comments on commit 229bed5

Please sign in to comment.