Skip to content

Commit

Permalink
add tests for deploy (#7)
Browse files Browse the repository at this point in the history
* add test

* fix pairwise models

* use `e3nn` jit

* don't use non ascii chars

* tests

* fix nequip tests

* improve test

* test coverage
  • Loading branch information
jla-gardner authored Jul 11, 2024
1 parent d85cf6c commit da7f627
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 55 deletions.
16 changes: 14 additions & 2 deletions src/graph_pes/deploy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from contextlib import contextmanager
from __future__ import annotations

from pathlib import Path

import e3nn
import torch

from graph_pes.core import GraphPESModel
Expand Down Expand Up @@ -60,7 +63,7 @@ def forward(self, graph: AtomicGraph) -> dict[str, torch.Tensor]:
graph[keys._POSITIONS].requires_grad_(True)
change_to_cell.requires_grad_(True)

local_energies = self.model.predict_local_energies(graph)
local_energies = self.model.predict_local_energies(graph).squeeze()
props["local_energies"] = local_energies
total_energy = torch.sum(local_energies)
props["total_energy"] = total_energy
Expand All @@ -76,3 +79,12 @@ def forward(self, graph: AtomicGraph) -> dict[str, torch.Tensor]:
for key in props:
props[key] = props[key].double()
return props

def __call__(self, graph: AtomicGraph) -> dict[str, torch.Tensor]:
return super().__call__(graph)


def deploy_model(model: GraphPESModel, cutoff: float, path: str | Path):
lammps_model = LAMMPSModel(model, cutoff)
scripted_model = e3nn.util.jit.script(lammps_model)
torch.jit.save(scripted_model, path)
2 changes: 2 additions & 0 deletions src/graph_pes/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from graph_pes.core import GraphPESModel

from .e3nn.nequip import NequIP
from .offsets import FixedOffset, LearnableOffset
from .painn import PaiNN
from .pairwise import LennardJones, LennardJonesMixture, Morse
Expand All @@ -19,6 +20,7 @@
"TensorNet",
"Morse",
"LennardJonesMixture",
"NequIP",
"FixedOffset",
"LearnableOffset",
"load_model",
Expand Down
4 changes: 2 additions & 2 deletions src/graph_pes/models/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
super().__init__(n_features, cutoff, trainable)

c = torch.exp(-torch.tensor(cutoff))
self.β = nn.Parameter(
self.beta = nn.Parameter(
torch.ones(n_features) / (2 * (1 - c) / n_features) ** 2,
requires_grad=trainable,
)
Expand All @@ -329,7 +329,7 @@ def __init__(self, n_features: int, cutoff: float, trainable: bool = True):

def expand(self, r: torch.Tensor) -> torch.Tensor:
offsets = torch.exp(-r) - self.centers
return torch.exp(-self.β * offsets**2)
return torch.exp(-self.beta * offsets**2)


class Envelope(nn.Module):
Expand Down
38 changes: 21 additions & 17 deletions src/graph_pes/models/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
PolynomialEnvelope(cutoff),
)

self.φ = MLP(
self.Phi = MLP(
[internal_dim, internal_dim, internal_dim * 3],
activation=nn.SiLU(),
)
Expand All @@ -74,23 +74,25 @@ def forward(
unit_vectors = neighbour_vectors(graph) / d # (E, 3)

# continous filter message creation
x_ij = self.filter_generator(d) * self.φ(scalar_embeddings)[neighbours]
x_ij = (
self.filter_generator(d) * self.Phi(scalar_embeddings)[neighbours]
)
a, b, c = torch.split(x_ij, self.internal_dim, dim=-1) # (E, D)

# simple sum over neighbours to get scalar messages
Δs = torch.zeros_like(scalar_embeddings) # (N, D)
Δs.scatter_add_(0, neighbours.view(-1, 1).expand_as(a), a)
delta_s = torch.zeros_like(scalar_embeddings) # (N, D)
delta_s.scatter_add_(0, neighbours.view(-1, 1).expand_as(a), a)

# create vector messages
v_ij = b.unsqueeze(-1) * unit_vectors.unsqueeze(1) # (E, D, 3)
v_ij = v_ij + c.unsqueeze(-1) * vector_embeddings[neighbours]

Δv = torch.zeros_like(vector_embeddings) # (N, D, 3)
Δv.scatter_add_(
delta_v = torch.zeros_like(vector_embeddings) # (N, D, 3)
delta_v.scatter_add_(
0, neighbours.unsqueeze(-1).unsqueeze(-1).expand_as(v_ij), v_ij
)

return Δv, Δs # (N, D, 3), (N, D)
return delta_v, delta_s # (N, D, 3), (N, D)


class VectorLinear(nn.Module):
Expand Down Expand Up @@ -144,13 +146,13 @@ def forward(
a, b, c = torch.split(m, self.internal_dim, dim=-1) # (N, D)

# vector update:
Δv = u * a.unsqueeze(-1) # (N, D, 3)
delta_v = u * a.unsqueeze(-1) # (N, D, 3)

# scalar update:
dot = torch.sum(u * v, dim=-1) # (N, D)
Δs = b + c * dot # (N, D)
delta_s = b + c * dot # (N, D)

return Δv, Δs
return delta_v, delta_s


class PaiNN(GraphPESModel):
Expand Down Expand Up @@ -217,12 +219,14 @@ def predict_local_energies(self, graph: AtomicGraph) -> Tensor:
scalar_embeddings = self.z_embedding(graph["atomic_numbers"])

for interaction, update in zip(self.interactions, self.updates):
Δv, Δs = interaction(vector_embeddings, scalar_embeddings, graph)
vector_embeddings = vector_embeddings + Δv
scalar_embeddings = scalar_embeddings + Δs

Δv, Δs = update(vector_embeddings, scalar_embeddings)
vector_embeddings = vector_embeddings + Δv
scalar_embeddings = scalar_embeddings + Δs
delta_v, delta_s = interaction(
vector_embeddings, scalar_embeddings, graph
)
vector_embeddings = vector_embeddings + delta_v
scalar_embeddings = scalar_embeddings + delta_s

delta_v, delta_s = update(vector_embeddings, scalar_embeddings)
vector_embeddings = vector_embeddings + delta_v
scalar_embeddings = scalar_embeddings + delta_s

return self.read_out(scalar_embeddings)
17 changes: 13 additions & 4 deletions src/graph_pes/models/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def sigma(self):
def interaction(
self,
r: torch.Tensor,
Z_i: Optional[torch.Tensor] = None,
Z_j: Optional[torch.Tensor] = None,
Z_i: Optional[torch.Tensor] = None, # noqa: UP007
Z_j: Optional[torch.Tensor] = None, # noqa: UP007
):
"""
Evaluate the pair potential.
Expand Down Expand Up @@ -226,7 +226,12 @@ def a(self):
def r0(self):
return self._log_r0.exp()

def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None):
def interaction(
self,
r: torch.Tensor,
Z_i: Optional[torch.Tensor] = None, # noqa: UP007
Z_j: Optional[torch.Tensor] = None, # noqa: UP007
):
"""
Evaluate the pair potential.
Expand Down Expand Up @@ -313,7 +318,11 @@ def interaction(self, r: Tensor, Z_i: Tensor, Z_j: Tensor) -> Tensor:

sigma_j = self.sigma[Z_j].squeeze() # (E)
sigma_i = self.sigma[Z_i].squeeze() # (E)
nu = self.nu[Z_i, Z_j].squeeze() if self.modulate_distances else 1
nu = (
self.nu[Z_i, Z_j].squeeze()
if self.modulate_distances
else torch.tensor(1)
)
sigma = torch.where(
cross_interaction,
nu * (sigma_i + sigma_j) / 2,
Expand Down
3 changes: 2 additions & 1 deletion src/graph_pes/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,8 @@ def register_Zs(self, Zs: list[int]):
unique_Zs = sorted(set(Zs))
if len(unique_Zs) != self.n_elements:
raise ValueError(
f"Expected {self.n_elements} elements, got {unique_Zs}"
f"Expected {self.n_elements} elements, got "
f"{len(unique_Zs)}: {unique_Zs}"
)

for i, Z in enumerate(unique_Zs):
Expand Down
62 changes: 62 additions & 0 deletions tests/test_deploy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

from pathlib import Path

import pytest
import torch
from ase.build import molecule
from graph_pes.core import GraphPESModel
from graph_pes.data.io import to_atomic_graph
from graph_pes.deploy import deploy_model
from graph_pes.graphs.operations import number_of_atoms
from graph_pes.models import ALL_MODELS, NequIP

CUTOFF = 1.5
graph = to_atomic_graph(molecule("CH3CH2OH"), cutoff=CUTOFF)


@pytest.mark.parametrize(
"model_klass",
ALL_MODELS,
ids=[model.__name__ for model in ALL_MODELS],
)
def test_deploy(model_klass: type[GraphPESModel], tmp_path: Path):
# 1. instantiate the model
kwargs = {"n_elements": 3} if model_klass is NequIP else {}
model = model_klass(**kwargs)
model.pre_fit([graph]) # required by some models before making predictions

# 2. deploy the model
save_path = tmp_path / "model.pt"
deploy_model(model, cutoff=CUTOFF, path=save_path)

# 3. load the model back in
loaded_model = torch.jit.load(save_path)
assert isinstance(loaded_model, torch.jit.ScriptModule)
assert loaded_model.get_cutoff() == CUTOFF

# 4. test outputs
outputs = loaded_model(
# mock the graph that would be passed through from LAMMPS
{
**graph,
"compute_virial": torch.tensor(True),
"debug": torch.tensor(False),
}
)
assert isinstance(outputs, dict)
assert set(outputs.keys()) == {
"total_energy",
"local_energies",
"forces",
"virial",
}
assert outputs["total_energy"].shape == torch.Size([])
assert outputs["local_energies"].shape == (number_of_atoms(graph),)
assert outputs["forces"].shape == graph["_positions"].shape
assert outputs["virial"].shape == (3, 3)

# 5. test that the deployment process hasn't changed the model's predictions
with torch.no_grad():
original_energy = model(graph).double()
assert torch.allclose(original_energy, outputs["total_energy"])
49 changes: 49 additions & 0 deletions tests/test_lammps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
import torch
from ase.build import molecule
from graph_pes.core import get_predictions
from graph_pes.data.io import to_atomic_graph
from graph_pes.deploy import LAMMPSModel
from graph_pes.graphs import keys
from graph_pes.models import LennardJones


@pytest.mark.parametrize(
"compute_virial",
[True, False],
)
def test_lammps_model(compute_virial: bool):
# generate a structure
structure = molecule("CH3CH2OH")
if compute_virial:
# ensure the structure has a cell
structure.center(vacuum=5.0)
graph = to_atomic_graph(structure, cutoff=1.5)

# create a normal model, and get normal predictions
model = LennardJones()
props: list[keys.LabelKey] = ["energy", "forces"]
if compute_virial:
props.append("stress")
outputs = get_predictions(model, graph, properties=props, training=False)

# create a LAMMPS model, and get LAMMPS predictions
lammps_model = LAMMPSModel(model)
lammps_graph: dict[str, torch.Tensor] = {
**graph,
"compute_virial": torch.tensor(compute_virial),
"debug": torch.tensor(False),
} # type: ignore
lammps_outputs = lammps_model(lammps_graph)

# check outputs
if compute_virial:
assert "virial" in lammps_outputs
assert (
outputs["stress"].shape == lammps_outputs["virial"].shape == (3, 3)
)

assert torch.allclose(
outputs["energy"].float(),
lammps_outputs["total_energy"].float(),
)
13 changes: 7 additions & 6 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
number_of_atoms,
number_of_edges,
)
from graph_pes.models import ALL_MODELS, LennardJones, Morse
from graph_pes.models import ALL_MODELS, LennardJones, Morse, NequIP

structures: list[Atoms] = read("tests/test.xyz", ":") # type: ignore
graphs = to_atomic_graphs(structures, cutoff=3)
Expand Down Expand Up @@ -65,17 +65,18 @@ def test_pre_fit():
ids=[m.__name__ for m in ALL_MODELS],
)
def test_model_serialisation(model: type[GraphPESModel], tmp_path):
m = model()
m.pre_fit(graphs)
kwargs = {} if model is not NequIP else {"n_elements": 1}
m1 = model(**kwargs)
m1.pre_fit(graphs)

torch.save(m.state_dict(), tmp_path / "model.pt")
torch.save(m1.state_dict(), tmp_path / "model.pt")

m2 = model()
m2 = model(**kwargs)
# check no errors occur
m2.load_state_dict(torch.load(tmp_path / "model.pt"))

# check predictions are the same
assert torch.allclose(m(graphs[0]), m2(graphs[0]))
assert torch.allclose(m1(graphs[0]), m2(graphs[0]))


def test_addition():
Expand Down
Loading

0 comments on commit da7f627

Please sign in to comment.