add tests for deploy #7

Merged: 8 commits, Jul 11, 2024
16 changes: 14 additions & 2 deletions src/graph_pes/deploy.py
@@ -1,5 +1,8 @@
from contextlib import contextmanager
from __future__ import annotations

from pathlib import Path

import e3nn
import torch

from graph_pes.core import GraphPESModel
@@ -60,7 +63,7 @@ def forward(self, graph: AtomicGraph) -> dict[str, torch.Tensor]:
graph[keys._POSITIONS].requires_grad_(True)
change_to_cell.requires_grad_(True)

local_energies = self.model.predict_local_energies(graph)
local_energies = self.model.predict_local_energies(graph).squeeze()
props["local_energies"] = local_energies
total_energy = torch.sum(local_energies)
props["total_energy"] = total_energy
@@ -76,3 +79,12 @@ def forward(self, graph: AtomicGraph) -> dict[str, torch.Tensor]:
for key in props:
props[key] = props[key].double()
return props

def __call__(self, graph: AtomicGraph) -> dict[str, torch.Tensor]:
return super().__call__(graph)


def deploy_model(model: GraphPESModel, cutoff: float, path: str | Path):
lammps_model = LAMMPSModel(model, cutoff)
scripted_model = e3nn.util.jit.script(lammps_model)
torch.jit.save(scripted_model, path)
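
A minimal usage sketch of the new deploy_model helper, mirroring the round trip exercised in tests/test_deploy.py further down this diff; the LennardJones defaults and the cutoff value are illustrative assumptions, not part of the change itself.

from pathlib import Path

import torch

from graph_pes.deploy import deploy_model
from graph_pes.models import LennardJones

# illustrative values only: any GraphPESModel and cutoff would do
model = LennardJones()
save_path = Path("lj_model.pt")

# TorchScript-compile the LAMMPS wrapper and write it to disk
deploy_model(model, cutoff=1.5, path=save_path)

# the saved artefact loads back without the Python class definitions
loaded = torch.jit.load(save_path)
assert loaded.get_cutoff() == 1.5
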
2 changes: 2 additions & 0 deletions src/graph_pes/models/__init__.py
@@ -6,6 +6,7 @@

from graph_pes.core import GraphPESModel

from .e3nn.nequip import NequIP
from .offsets import FixedOffset, LearnableOffset
from .painn import PaiNN
from .pairwise import LennardJones, LennardJonesMixture, Morse
@@ -19,6 +20,7 @@
"TensorNet",
"Morse",
"LennardJonesMixture",
"NequIP",
"FixedOffset",
"LearnableOffset",
"load_model",
4 changes: 2 additions & 2 deletions src/graph_pes/models/distances.py
@@ -318,7 +318,7 @@ def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
super().__init__(n_features, cutoff, trainable)

c = torch.exp(-torch.tensor(cutoff))
self.β = nn.Parameter(
self.beta = nn.Parameter(
torch.ones(n_features) / (2 * (1 - c) / n_features) ** 2,
requires_grad=trainable,
)
@@ -329,7 +329,7 @@ def __init__(self, n_features: int, cutoff: float, trainable: bool = True):

def expand(self, r: torch.Tensor) -> torch.Tensor:
offsets = torch.exp(-r) - self.centers
return torch.exp(-self.β * offsets**2)
return torch.exp(-self.beta * offsets**2)


class Envelope(nn.Module):
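
For reference, the renamed beta parameter enters the expansion as exp(-beta * (exp(-r) - centers)**2). Below is a self-contained numerical sketch of that formula; the placement of centers and the sample distances are assumed for illustration, since they are defined outside the lines shown in this hunk.

import torch

# illustrative sketch of the expansion above; `centers` placement is assumed
n_features, cutoff = 8, 5.0
c = torch.exp(-torch.tensor(cutoff))
beta = torch.ones(n_features) / (2 * (1 - c) / n_features) ** 2
centers = torch.linspace(c.item(), 1.0, n_features)  # assumed placement

r = torch.linspace(0.0, cutoff, 20).unsqueeze(-1)  # (20, 1) pair distances
offsets = torch.exp(-r) - centers                  # (20, n_features)
expansion = torch.exp(-beta * offsets**2)          # (20, n_features)
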
38 changes: 21 additions & 17 deletions src/graph_pes/models/painn.py
@@ -58,7 +58,7 @@ def __init__(
PolynomialEnvelope(cutoff),
)

self.φ = MLP(
self.Phi = MLP(
[internal_dim, internal_dim, internal_dim * 3],
activation=nn.SiLU(),
)
@@ -74,23 +74,25 @@ def forward(
unit_vectors = neighbour_vectors(graph) / d # (E, 3)

# continuous filter message creation
x_ij = self.filter_generator(d) * self.φ(scalar_embeddings)[neighbours]
x_ij = (
self.filter_generator(d) * self.Phi(scalar_embeddings)[neighbours]
)
a, b, c = torch.split(x_ij, self.internal_dim, dim=-1) # (E, D)

# simple sum over neighbours to get scalar messages
Δs = torch.zeros_like(scalar_embeddings) # (N, D)
Δs.scatter_add_(0, neighbours.view(-1, 1).expand_as(a), a)
delta_s = torch.zeros_like(scalar_embeddings) # (N, D)
delta_s.scatter_add_(0, neighbours.view(-1, 1).expand_as(a), a)

# create vector messages
v_ij = b.unsqueeze(-1) * unit_vectors.unsqueeze(1) # (E, D, 3)
v_ij = v_ij + c.unsqueeze(-1) * vector_embeddings[neighbours]

Δv = torch.zeros_like(vector_embeddings) # (N, D, 3)
Δv.scatter_add_(
delta_v = torch.zeros_like(vector_embeddings) # (N, D, 3)
delta_v.scatter_add_(
0, neighbours.unsqueeze(-1).unsqueeze(-1).expand_as(v_ij), v_ij
)

return Δv, Δs # (N, D, 3), (N, D)
return delta_v, delta_s # (N, D, 3), (N, D)


class VectorLinear(nn.Module):
@@ -144,13 +146,13 @@ def forward(
a, b, c = torch.split(m, self.internal_dim, dim=-1) # (N, D)

# vector update:
Δv = u * a.unsqueeze(-1) # (N, D, 3)
delta_v = u * a.unsqueeze(-1) # (N, D, 3)

# scalar update:
dot = torch.sum(u * v, dim=-1) # (N, D)
Δs = b + c * dot # (N, D)
delta_s = b + c * dot # (N, D)

return Δv, Δs
return delta_v, delta_s


class PaiNN(GraphPESModel):
@@ -217,12 +219,14 @@ def predict_local_energies(self, graph: AtomicGraph) -> Tensor:
scalar_embeddings = self.z_embedding(graph["atomic_numbers"])

for interaction, update in zip(self.interactions, self.updates):
Δv, Δs = interaction(vector_embeddings, scalar_embeddings, graph)
vector_embeddings = vector_embeddings + Δv
scalar_embeddings = scalar_embeddings + Δs

Δv, Δs = update(vector_embeddings, scalar_embeddings)
vector_embeddings = vector_embeddings + Δv
scalar_embeddings = scalar_embeddings + Δs
delta_v, delta_s = interaction(
vector_embeddings, scalar_embeddings, graph
)
vector_embeddings = vector_embeddings + delta_v
scalar_embeddings = scalar_embeddings + delta_s

delta_v, delta_s = update(vector_embeddings, scalar_embeddings)
vector_embeddings = vector_embeddings + delta_v
scalar_embeddings = scalar_embeddings + delta_s

return self.read_out(scalar_embeddings)
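
The interaction block above accumulates per-edge messages onto atoms with scatter_add_. The toy example below illustrates that aggregation pattern on its own; all sizes and index values are invented, not taken from PaiNN.

import torch

# toy sizes, invented for illustration: 4 atoms, 6 edges, 3 scalar features
n_atoms, n_edges, dim = 4, 6, 3
neighbours = torch.tensor([0, 1, 1, 2, 3, 3])  # destination atom of each edge
a = torch.randn(n_edges, dim)                  # per-edge scalar messages

# same pattern as delta_s.scatter_add_ in the forward above:
# row i of delta_s receives the sum of messages routed to atom i
delta_s = torch.zeros(n_atoms, dim)
delta_s.scatter_add_(0, neighbours.view(-1, 1).expand_as(a), a)

# equivalent explicit loop, for clarity
expected = torch.zeros(n_atoms, dim)
for e in range(n_edges):
    expected[neighbours[e]] += a[e]
assert torch.allclose(delta_s, expected)
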
17 changes: 13 additions & 4 deletions src/graph_pes/models/pairwise.py
@@ -145,8 +145,8 @@ def sigma(self):
def interaction(
self,
r: torch.Tensor,
Z_i: Optional[torch.Tensor] = None,
Z_j: Optional[torch.Tensor] = None,
Z_i: Optional[torch.Tensor] = None, # noqa: UP007
Z_j: Optional[torch.Tensor] = None, # noqa: UP007
):
"""
Evaluate the pair potential.
@@ -226,7 +226,7 @@ def a(self):
def r0(self):
return self._log_r0.exp()

def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None):
def interaction(
self,
r: torch.Tensor,
Z_i: Optional[torch.Tensor] = None, # noqa: UP007
Z_j: Optional[torch.Tensor] = None, # noqa: UP007
):
"""
Evaluate the pair potential.

@@ -313,7 +318,11 @@ def interaction(self, r: Tensor, Z_i: Tensor, Z_j: Tensor) -> Tensor:

sigma_j = self.sigma[Z_j].squeeze() # (E)
sigma_i = self.sigma[Z_i].squeeze() # (E)
nu = self.nu[Z_i, Z_j].squeeze() if self.modulate_distances else 1
nu = (
self.nu[Z_i, Z_j].squeeze()
if self.modulate_distances
else torch.tensor(1)
)
sigma = torch.where(
cross_interaction,
nu * (sigma_i + sigma_j) / 2,
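
The switch from a bare 1 to torch.tensor(1) keeps both branches of the conditional expression as tensors, presumably so the module type-checks when it is TorchScript-compiled by deploy_model. The toy module below is an illustrative sketch of that pattern only; it is not code from the library.

import torch
from torch import nn


class ToyModulated(nn.Module):
    # illustrative only: mirrors the `nu = ... if ... else torch.tensor(1)` pattern
    def __init__(self, modulate: bool = False):
        super().__init__()
        self.modulate = modulate
        self.nu = nn.Parameter(torch.full((5, 5), 0.9))

    def forward(self, Z_i: torch.Tensor, Z_j: torch.Tensor) -> torch.Tensor:
        # both branches evaluate to a Tensor, so the scripted graph sees one type
        nu = self.nu[Z_i, Z_j] if self.modulate else torch.tensor(1.0)
        return nu * (Z_i + Z_j).float()


scripted = torch.jit.script(ToyModulated())  # scripts cleanly with tensor-valued branches
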
3 changes: 2 additions & 1 deletion src/graph_pes/nn.py
@@ -463,7 +463,8 @@ def register_Zs(self, Zs: list[int]):
unique_Zs = sorted(set(Zs))
if len(unique_Zs) != self.n_elements:
raise ValueError(
f"Expected {self.n_elements} elements, got {unique_Zs}"
f"Expected {self.n_elements} elements, got "
f"{len(unique_Zs)}: {unique_Zs}"
)

for i, Z in enumerate(unique_Zs):
62 changes: 62 additions & 0 deletions tests/test_deploy.py
@@ -0,0 +1,62 @@
from __future__ import annotations

from pathlib import Path

import pytest
import torch
from ase.build import molecule
from graph_pes.core import GraphPESModel
from graph_pes.data.io import to_atomic_graph
from graph_pes.deploy import deploy_model
from graph_pes.graphs.operations import number_of_atoms
from graph_pes.models import ALL_MODELS, NequIP

CUTOFF = 1.5
graph = to_atomic_graph(molecule("CH3CH2OH"), cutoff=CUTOFF)


@pytest.mark.parametrize(
"model_klass",
ALL_MODELS,
ids=[model.__name__ for model in ALL_MODELS],
)
def test_deploy(model_klass: type[GraphPESModel], tmp_path: Path):
# 1. instantiate the model
kwargs = {"n_elements": 3} if model_klass is NequIP else {}
model = model_klass(**kwargs)
model.pre_fit([graph]) # required by some models before making predictions

# 2. deploy the model
save_path = tmp_path / "model.pt"
deploy_model(model, cutoff=CUTOFF, path=save_path)

# 3. load the model back in
loaded_model = torch.jit.load(save_path)
assert isinstance(loaded_model, torch.jit.ScriptModule)
assert loaded_model.get_cutoff() == CUTOFF

# 4. test outputs
outputs = loaded_model(
# mock the graph that would be passed through from LAMMPS
{
**graph,
"compute_virial": torch.tensor(True),
"debug": torch.tensor(False),
}
)
assert isinstance(outputs, dict)
assert set(outputs.keys()) == {
"total_energy",
"local_energies",
"forces",
"virial",
}
assert outputs["total_energy"].shape == torch.Size([])
assert outputs["local_energies"].shape == (number_of_atoms(graph),)
assert outputs["forces"].shape == graph["_positions"].shape
assert outputs["virial"].shape == (3, 3)

# 5. test that the deployment process hasn't changed the model's predictions
with torch.no_grad():
original_energy = model(graph).double()
assert torch.allclose(original_energy, outputs["total_energy"])
49 changes: 49 additions & 0 deletions tests/test_lammps.py
@@ -0,0 +1,49 @@
import pytest
import torch
from ase.build import molecule
from graph_pes.core import get_predictions
from graph_pes.data.io import to_atomic_graph
from graph_pes.deploy import LAMMPSModel
from graph_pes.graphs import keys
from graph_pes.models import LennardJones


@pytest.mark.parametrize(
"compute_virial",
[True, False],
)
def test_lammps_model(compute_virial: bool):
# generate a structure
structure = molecule("CH3CH2OH")
if compute_virial:
# ensure the structure has a cell
structure.center(vacuum=5.0)
graph = to_atomic_graph(structure, cutoff=1.5)

# create a normal model, and get normal predictions
model = LennardJones()
props: list[keys.LabelKey] = ["energy", "forces"]
if compute_virial:
props.append("stress")
outputs = get_predictions(model, graph, properties=props, training=False)

# create a LAMMPS model, and get LAMMPS predictions
lammps_model = LAMMPSModel(model)
lammps_graph: dict[str, torch.Tensor] = {
**graph,
"compute_virial": torch.tensor(compute_virial),
"debug": torch.tensor(False),
} # type: ignore
lammps_outputs = lammps_model(lammps_graph)

# check outputs
if compute_virial:
assert "virial" in lammps_outputs
assert (
outputs["stress"].shape == lammps_outputs["virial"].shape == (3, 3)
)

assert torch.allclose(
outputs["energy"].float(),
lammps_outputs["total_energy"].float(),
)
13 changes: 7 additions & 6 deletions tests/test_models.py
@@ -13,7 +13,7 @@
number_of_atoms,
number_of_edges,
)
from graph_pes.models import ALL_MODELS, LennardJones, Morse
from graph_pes.models import ALL_MODELS, LennardJones, Morse, NequIP

structures: list[Atoms] = read("tests/test.xyz", ":") # type: ignore
graphs = to_atomic_graphs(structures, cutoff=3)
@@ -65,17 +65,18 @@ def test_pre_fit():
ids=[m.__name__ for m in ALL_MODELS],
)
def test_model_serialisation(model: type[GraphPESModel], tmp_path):
m = model()
m.pre_fit(graphs)
kwargs = {} if model is not NequIP else {"n_elements": 1}
m1 = model(**kwargs)
m1.pre_fit(graphs)

torch.save(m.state_dict(), tmp_path / "model.pt")
torch.save(m1.state_dict(), tmp_path / "model.pt")

m2 = model()
m2 = model(**kwargs)
# check no errors occur
m2.load_state_dict(torch.load(tmp_path / "model.pt"))

# check predictions are the same
assert torch.allclose(m(graphs[0]), m2(graphs[0]))
assert torch.allclose(m1(graphs[0]), m2(graphs[0]))


def test_addition():