diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 8582d32e..4d22a032 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -9,6 +9,8 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 with: + cache: pip + cache-dependency-path: pyproject.toml python-version: 3.9 - name: update pip run: pip install --upgrade pip diff --git a/pyproject.toml b/pyproject.toml index 0bccdb58..42d4b4c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "ase", "numpy", "jaxtyping", + "rich", ] requires-python = ">=3.8" diff --git a/src/graph_pes/models/zoo.py b/src/graph_pes/models/zoo.py new file mode 100644 index 00000000..c40a1133 --- /dev/null +++ b/src/graph_pes/models/zoo.py @@ -0,0 +1,11 @@ +from .painn import PaiNN +from .pairwise import LennardJones +from .schnet import SchNet +from .tensornet import TensorNet + +__all__ = [ + "PaiNN", + "LennardJones", + "SchNet", + "TensorNet", +] diff --git a/tests/test_torchscript.py b/tests/test_torchscript.py index da5e8306..e2e47102 100644 --- a/tests/test_torchscript.py +++ b/tests/test_torchscript.py @@ -3,13 +3,11 @@ import pytest import torch from ase.build import molecule -from graph_pes.data import convert_to_atomic_graph -from graph_pes.models.painn import PaiNN -from graph_pes.models.pairwise import LennardJones -from graph_pes.models.schnet import SchNet -from graph_pes.models.tensornet import TensorNet +from graph_pes.data import batch_graphs, convert_to_atomic_graph +from graph_pes.models.zoo import LennardJones, PaiNN, SchNet, TensorNet graph = convert_to_atomic_graph(molecule("CH3CH2OH"), cutoff=1.5) +batch = batch_graphs([graph, graph]) models = [ LennardJones(), PaiNN(), @@ -24,9 +22,10 @@ ids=[model.__class__.__name__ for model in models], ) def test_model(model): - actual_energy = model(graph) + actual_energies = model(batch) + assert actual_energies.shape == (2,) scripted_model: torch.jit.ScriptModule = torch.jit.script(model) # type: ignore - scripted_energy = scripted_model(graph) + scripted_energy = scripted_model(batch) - assert torch.allclose(actual_energy, scripted_energy) + assert torch.allclose(actual_energies, scripted_energy)