Skip to content

Commit

Permalink
fix nequip tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Jul 10, 2024
1 parent 954ec54 commit a296fd7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/graph_pes/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,8 @@ def register_Zs(self, Zs: list[int]):
unique_Zs = sorted(set(Zs))
if len(unique_Zs) != self.n_elements:
raise ValueError(
f"Expected {self.n_elements} elements, got {unique_Zs}"
f"Expected {self.n_elements} elements, got "
f"{len(unique_Zs)}: {unique_Zs}"
)

for i, Z in enumerate(unique_Zs):
Expand Down
13 changes: 7 additions & 6 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
number_of_atoms,
number_of_edges,
)
from graph_pes.models import ALL_MODELS, LennardJones, Morse
from graph_pes.models import ALL_MODELS, LennardJones, Morse, NequIP

structures: list[Atoms] = read("tests/test.xyz", ":") # type: ignore
graphs = to_atomic_graphs(structures, cutoff=3)
Expand Down Expand Up @@ -65,17 +65,18 @@ def test_pre_fit():
ids=[m.__name__ for m in ALL_MODELS],
)
def test_model_serialisation(model: type[GraphPESModel], tmp_path):
m = model()
m.pre_fit(graphs)
kwargs = {} if model is not NequIP else {"n_elements": 1}
m1 = model(**kwargs)
m1.pre_fit(graphs)

torch.save(m.state_dict(), tmp_path / "model.pt")
torch.save(m1.state_dict(), tmp_path / "model.pt")

m2 = model()
m2 = model(**kwargs)
# check no errors occur
m2.load_state_dict(torch.load(tmp_path / "model.pt"))

# check predictions are the same
assert torch.allclose(m(graphs[0]), m2(graphs[0]))
assert torch.allclose(m1(graphs[0]), m2(graphs[0]))


def test_addition():
Expand Down

0 comments on commit a296fd7

Please sign in to comment.