Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Mar 4, 2024
1 parent 058360b commit 6ad2093
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
cache: pip
cache-dependency-path: pyproject.toml
python-version: 3.9
- name: update pip
run: pip install --upgrade pip
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"ase",
"numpy",
"jaxtyping",
"rich",
]
requires-python = ">=3.8"

Expand Down
11 changes: 11 additions & 0 deletions src/graph_pes/models/zoo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .painn import PaiNN
from .pairwise import LennardJones
from .schnet import SchNet
from .tensornet import TensorNet

__all__ = [
"PaiNN",
"LennardJones",
"SchNet",
"TensorNet",
]
15 changes: 7 additions & 8 deletions tests/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
import pytest
import torch
from ase.build import molecule
from graph_pes.data import convert_to_atomic_graph
from graph_pes.models.painn import PaiNN
from graph_pes.models.pairwise import LennardJones
from graph_pes.models.schnet import SchNet
from graph_pes.models.tensornet import TensorNet
from graph_pes.data import batch_graphs, convert_to_atomic_graph
from graph_pes.models.zoo import LennardJones, PaiNN, SchNet, TensorNet

graph = convert_to_atomic_graph(molecule("CH3CH2OH"), cutoff=1.5)
batch = batch_graphs([graph, graph])
models = [
LennardJones(),
PaiNN(),
Expand All @@ -24,9 +22,10 @@
ids=[model.__class__.__name__ for model in models],
)
def test_model(model):
actual_energy = model(graph)
actual_energies = model(batch)
assert actual_energies.shape == (2,)

scripted_model: torch.jit.ScriptModule = torch.jit.script(model) # type: ignore
scripted_energy = scripted_model(graph)
scripted_energy = scripted_model(batch)

assert torch.allclose(actual_energy, scripted_energy)
assert torch.allclose(actual_energies, scripted_energy)

0 comments on commit 6ad2093

Please sign in to comment.