From fa079bfa37ca6357a6a44c25a6fe34a04ac15584 Mon Sep 17 00:00:00 2001 From: Pierre Luc Carrier Date: Mon, 20 Nov 2023 14:58:32 -0500 Subject: [PATCH] Update tests for MillerIndices --- tests/gflownet/envs/test_miller.py | 72 +++++++++++++++--------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/tests/gflownet/envs/test_miller.py b/tests/gflownet/envs/test_miller.py index 2d5aa1ea6..9d3bd6fbc 100644 --- a/tests/gflownet/envs/test_miller.py +++ b/tests/gflownet/envs/test_miller.py @@ -5,13 +5,13 @@ @pytest.fixture -def cubic(): - return MillerIndices(is_cubic=True) +def hexa_rhombo(): + return MillerIndices(is_hexagonal_rhombohedral=True) @pytest.fixture -def nocubic(): - return MillerIndices(is_cubic=False) +def no_hexa_rhombo(): + return MillerIndices(is_hexagonal_rhombohedral=False) @pytest.mark.parametrize( @@ -39,46 +39,44 @@ def nocubic(): ), ], ) -def test__state2oracle__cubic__returns_expected(cubic, state, state2oracle): - env = cubic - assert state2oracle == env.state2oracle(state) +def test__state2oracle__returns_expected( + hexa_rhombo, no_hexa_rhombo, state, state2oracle +): + assert state2oracle == hexa_rhombo.state2oracle(state) + assert state2oracle == no_hexa_rhombo.state2oracle(state) @pytest.mark.parametrize( - "state, state2oracle", + "env_input, state, action, is_action_valid", [ - ( - [0, 0, 0, 0], - [-2.0, -2.0, -2.0, -2.0], - ), - ( - [4, 4, 4, 4], - [2.0, 2.0, 2.0, 2.0], - ), - ( - [2, 2, 2, 2], - [0.0, 0.0, 0.0, 0.0], - ), - ( - [0, 2, 0, 2], - [-2.0, 0.0, -2.0, 0.0], - ), - ( - [2, 1, 3, 0], - [0.0, -1.0, 1.0, -2.0], - ), + ("no_hexa_rhombo", [0, 0, 0], (0, 0, 0), True), + ("no_hexa_rhombo", [2, 2, 2], (0, 0, 0), True), + ("no_hexa_rhombo", [4, 4, 4], (0, 0, 0), True), + ("no_hexa_rhombo", [3, 3, 2], (1, 0, 0), True), + ("no_hexa_rhombo", [3, 3, 2], (0, 1, 0), True), + ("no_hexa_rhombo", [3, 3, 2], (0, 0, 1), True), + ("hexa_rhombo", [0, 0, 0], (0, 0, 0), False), + ("hexa_rhombo", [2, 2, 2], (0, 0, 0), True), + ("hexa_rhombo", [4, 4, 4], (0, 0, 0), False), + ("hexa_rhombo", [3, 3, 2], (1, 0, 0), False), + ("hexa_rhombo", [3, 3, 2], (0, 1, 0), False), + ("hexa_rhombo", [3, 3, 2], (0, 0, 1), True), ], ) -def test__state2oracle__nocubic__returns_expected(nocubic, state, state2oracle): - env = nocubic - assert state2oracle == env.state2oracle(state) +def test_get_mask_invalid_actions_forward__masks_expected_actions( + env_input, state, action, is_action_valid, request +): + env = request.getfixturevalue(env_input) + env.set_state(state, done=False) + _, _, valid = env.step(action) + assert is_action_valid == valid -def test__all_env_common__cubic(cubic): - print("\n\nCommon tests for cubic Miller indices\n") - return common.test__all_env_common(cubic) +def test__all_env_common__hexagonal_rhombohedral(hexa_rhombo): + print("\n\nCommon tests for hexagonal or rhombohedral Miller indices\n") + return common.test__all_env_common(hexa_rhombo) -def test__all_env_common__nocubic(nocubic): - print("\n\nCommon tests for hexagonal or rhombohedral Miller indices\n") - return common.test__all_env_common(nocubic) +def test__all_env_common__no_hexagonal_rhombohedral(no_hexa_rhombo): + print("\n\nCommon tests for non-{hexagonal, rhombohedral} Miller indices\n") + return common.test__all_env_common(no_hexa_rhombo)