Skip to content

Commit

Permalink
Add tests for space group restriction
Browse files Browse the repository at this point in the history
  • Loading branch information
carriepl-mila committed Oct 18, 2023
1 parent 0a10e03 commit f886d58
Showing 1 changed file with 49 additions and 12 deletions.
61 changes: 49 additions & 12 deletions tests/gflownet/envs/test_spacegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,55 @@ def test__environment__initializes_properly():


def test__environment__space_groups_subset__initializes_properly():
env_sg_subset = SpaceGroup(space_groups_subset=[1, 2])
assert env_sg_subset.source == [0] * 3
assert env_sg_subset.state == [0] * 3
assert len(env_sg_subset.space_groups) == 2
env_sg_subset = SpaceGroup(space_groups_subset=range(1, 15 + 1))
assert env_sg_subset.source == [0] * 3
assert env_sg_subset.state == [0] * 3
assert len(env_sg_subset.space_groups) == 15
env_sg_subset = SpaceGroup(space_groups_subset=SG_SUBSET)
assert env_sg_subset.source == [0] * 3
assert env_sg_subset.state == [0] * 3
assert len(env_sg_subset.space_groups) == len(SG_SUBSET)
def count_distinct(my_dict, sub_key):
all_elements = []
for sub_dict in my_dict.values():
all_elements.extend(sub_dict[sub_key])

distinct_elements = set(all_elements)
return len(distinct_elements)

env = SpaceGroup(space_groups_subset=[1, 2])
nb_spacegroups = 2
nb_cls = 1
nb_ps = 2
assert env.source == [0] * 3
assert env.state == [0] * 3
assert len(env.space_groups) == nb_spacegroups
assert len(env.crystal_lattice_systems) == nb_cls
assert len(env.point_symmetries) == nb_ps
assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups
assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps
assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups
assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls

env = SpaceGroup(space_groups_subset=range(1, 15 + 1))
nb_spacegroups = 15
nb_cls = 2
nb_ps = 3
assert env.source == [0] * 3
assert env.state == [0] * 3
assert len(env.space_groups) == nb_spacegroups
assert len(env.crystal_lattice_systems) == nb_cls
assert len(env.point_symmetries) == nb_ps
assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups
assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps
assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups
assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls

env = SpaceGroup(space_groups_subset=SG_SUBSET)
nb_spacegroups = len(SG_SUBSET)
nb_cls = 4
nb_ps = 4
assert env.source == [0] * 3
assert env.state == [0] * 3
assert len(env.space_groups) == nb_spacegroups
assert len(env.crystal_lattice_systems) == nb_cls
assert len(env.point_symmetries) == nb_ps
assert count_distinct(env.crystal_lattice_systems, "space_groups") == nb_spacegroups
assert count_distinct(env.crystal_lattice_systems, "point_symmetries") == nb_ps
assert count_distinct(env.point_symmetries, "space_groups") == nb_spacegroups
assert count_distinct(env.point_symmetries, "crystal_lattice_systems") == nb_cls


def test__environment__action_space_has_eos():
Expand Down

0 comments on commit f886d58

Please sign in to comment.