From f886d58a48f9f7db1bd332ab1c94da8253f4a209 Mon Sep 17 00:00:00 2001 From: Pierre Luc Carrier Date: Wed, 18 Oct 2023 14:13:55 -0400 Subject: [PATCH] Add tests for space group restriction --- tests/gflownet/envs/test_spacegroup.py | 61 +++++++++++++++++++++----- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/tests/gflownet/envs/test_spacegroup.py b/tests/gflownet/envs/test_spacegroup.py index 50e82d61c..049d081c5 100644 --- a/tests/gflownet/envs/test_spacegroup.py +++ b/tests/gflownet/envs/test_spacegroup.py @@ -33,18 +33,55 @@ def test__environment__initializes_properly(): def test__environment__space_groups_subset__initializes_properly(): - env_sg_subset = SpaceGroup(space_groups_subset=[1, 2]) - assert env_sg_subset.source == [0] * 3 - assert env_sg_subset.state == [0] * 3 - assert len(env_sg_subset.space_groups) == 2 - env_sg_subset = SpaceGroup(space_groups_subset=range(1, 15 + 1)) - assert env_sg_subset.source == [0] * 3 - assert env_sg_subset.state == [0] * 3 - assert len(env_sg_subset.space_groups) == 15 - env_sg_subset = SpaceGroup(space_groups_subset=SG_SUBSET) - assert env_sg_subset.source == [0] * 3 - assert env_sg_subset.state == [0] * 3 - assert len(env_sg_subset.space_groups) == len(SG_SUBSET) + def count_distinct(my_dict, sub_key): + all_elements = [] + for sub_dict in my_dict.values(): + all_elements.extend(sub_dict[sub_key]) + + distinct_elements = set(all_elements) + return len(distinct_elements) + + env = SpaceGroup(space_groups_subset=[1, 2]) + nb_spacegroups = 2 + nb_cls = 1 + nb_ps = 2 + assert env.source == [0] * 3 + assert env.state == [0] * 3 + assert len(env.space_groups) == nb_spacegroups + assert len(env.crystal_lattice_systems) == nb_cls + assert len(env.point_symmetries) == nb_ps + assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups + assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps + assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups + assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls + + env = SpaceGroup(space_groups_subset=range(1, 15 + 1)) + nb_spacegroups = 15 + nb_cls = 2 + nb_ps = 3 + assert env.source == [0] * 3 + assert env.state == [0] * 3 + assert len(env.space_groups) == nb_spacegroups + assert len(env.crystal_lattice_systems) == nb_cls + assert len(env.point_symmetries) == nb_ps + assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups + assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps + assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups + assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls + + env = SpaceGroup(space_groups_subset=SG_SUBSET) + nb_spacegroups = len(SG_SUBSET) + nb_cls = 4 + nb_ps = 4 + assert env.source == [0] * 3 + assert env.state == [0] * 3 + assert len(env.space_groups) == nb_spacegroups + assert len(env.crystal_lattice_systems) == nb_cls + assert len(env.point_symmetries) == nb_ps + assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups + assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps + assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups + assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls def test__environment__action_space_has_eos():