Skip to content

Commit

Permalink
Update tests for MillerIndices
Browse files Browse the repository at this point in the history
  • Loading branch information
carriepl-mila committed Nov 20, 2023
1 parent e2a5710 commit fa079bf
Showing 1 changed file with 35 additions and 37 deletions.
72 changes: 35 additions & 37 deletions tests/gflownet/envs/test_miller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@


@pytest.fixture
def cubic():
return MillerIndices(is_cubic=True)
def hexa_rhombo():
return MillerIndices(is_hexagonal_rhombohedral=True)


@pytest.fixture
def nocubic():
return MillerIndices(is_cubic=False)
def no_hexa_rhombo():
return MillerIndices(is_hexagonal_rhombohedral=False)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -39,46 +39,44 @@ def nocubic():
),
],
)
def test__state2oracle__cubic__returns_expected(cubic, state, state2oracle):
env = cubic
assert state2oracle == env.state2oracle(state)
def test__state2oracle__returns_expected(
hexa_rhombo, no_hexa_rhombo, state, state2oracle
):
assert state2oracle == hexa_rhombo.state2oracle(state)
assert state2oracle == no_hexa_rhombo.state2oracle(state)


@pytest.mark.parametrize(
"state, state2oracle",
"env_input, state, action, is_action_valid",
[
(
[0, 0, 0, 0],
[-2.0, -2.0, -2.0, -2.0],
),
(
[4, 4, 4, 4],
[2.0, 2.0, 2.0, 2.0],
),
(
[2, 2, 2, 2],
[0.0, 0.0, 0.0, 0.0],
),
(
[0, 2, 0, 2],
[-2.0, 0.0, -2.0, 0.0],
),
(
[2, 1, 3, 0],
[0.0, -1.0, 1.0, -2.0],
),
("no_hexa_rhombo", [0, 0, 0], (0, 0, 0), True),
("no_hexa_rhombo", [2, 2, 2], (0, 0, 0), True),
("no_hexa_rhombo", [4, 4, 4], (0, 0, 0), True),
("no_hexa_rhombo", [3, 3, 2], (1, 0, 0), True),
("no_hexa_rhombo", [3, 3, 2], (0, 1, 0), True),
("no_hexa_rhombo", [3, 3, 2], (0, 0, 1), True),
("hexa_rhombo", [0, 0, 0], (0, 0, 0), False),
("hexa_rhombo", [2, 2, 2], (0, 0, 0), True),
("hexa_rhombo", [4, 4, 4], (0, 0, 0), False),
("hexa_rhombo", [3, 3, 2], (1, 0, 0), False),
("hexa_rhombo", [3, 3, 2], (0, 1, 0), False),
("hexa_rhombo", [3, 3, 2], (0, 0, 1), True),
],
)
def test__state2oracle__nocubic__returns_expected(nocubic, state, state2oracle):
env = nocubic
assert state2oracle == env.state2oracle(state)
def test_get_mask_invalid_actions_forward__masks_expected_actions(
env_input, state, action, is_action_valid, request
):
env = request.getfixturevalue(env_input)
env.set_state(state, done=False)
_, _, valid = env.step(action)
assert is_action_valid == valid


def test__all_env_common__cubic(cubic):
print("\n\nCommon tests for cubic Miller indices\n")
return common.test__all_env_common(cubic)
def test__all_env_common__hexagonal_rhombohedral(hexa_rhombo):
print("\n\nCommon tests for hexagonal or rhombohedral Miller indices\n")
return common.test__all_env_common(hexa_rhombo)


def test__all_env_common__nocubic(nocubic):
print("\n\nCommon tests for hexagonal or rhombohedral Miller indices\n")
return common.test__all_env_common(nocubic)
def test__all_env_common__no_hexagonal_rhombohedral(no_hexa_rhombo):
print("\n\nCommon tests for non-{hexagonal, rhombohedral} Miller indices\n")
return common.test__all_env_common(no_hexa_rhombo)

0 comments on commit fa079bf

Please sign in to comment.