Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Mar 8, 2024
1 parent bf70ed6 commit 70cd3b5
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 6 deletions.
9 changes: 7 additions & 2 deletions src/graph_pes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,14 @@ def __init__(self, energy_transform: Transform | None = None):
else energy_transform
)

#
self._has_been_pre_fit: Tensor
self.register_buffer("_has_been_pre_fit", torch.tensor(False))

def pre_fit(self, graphs: AtomicGraphBatch, relative: bool = True):
def pre_fit(
self,
graphs: AtomicGraphBatch | Sequence[AtomicGraph],
relative: bool = True,
):
"""
Pre-fit the model to the training data.
Expand Down Expand Up @@ -189,6 +192,8 @@ def pre_fit(self, graphs: AtomicGraphBatch, relative: bool = True):
return

self._has_been_pre_fit.fill_(True)
if isinstance(graphs, Sequence):
graphs = to_batch(graphs)

if self._extra_pre_fit(graphs):
return
Expand Down
4 changes: 2 additions & 2 deletions src/graph_pes/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,10 @@ def __init__(self, trainable: bool = True, scale: float | int = 1.0):
)

def forward(self, x: Tensor, graph: AtomicGraph) -> Tensor:
return x * self.scale
return x / self.scale

def inverse(self, y: Tensor, graph: AtomicGraph) -> Tensor:
return y / self.scale
return y * self.scale

@torch.no_grad()
def fit_to_source(self, x: Tensor, graphs: AtomicGraphBatch):
Expand Down
49 changes: 47 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import pytest
import torch
from ase import Atoms
from ase.io import read
from graph_pes.core import Ensemble, get_predictions
from graph_pes.core import Ensemble, GraphPESModel, get_predictions
from graph_pes.data import (
AtomicGraph,
AtomicGraphBatch,
has_cell,
number_of_atoms,
number_of_edges,
Expand All @@ -13,14 +16,15 @@
to_batch,
)
from graph_pes.models.zoo import LennardJones, Morse
from graph_pes.transform import PerAtomShift

structures: list[Atoms] = read("tests/test.xyz", ":") # type: ignore
graphs = to_atomic_graphs(structures, cutoff=3)


def test_model():
model = LennardJones()
model.pre_fit(to_batch(graphs[:2])) # type: ignore
model.pre_fit(graphs[:2])

assert sum(p.numel() for p in model.parameters()) == 3

Expand Down Expand Up @@ -52,3 +56,44 @@ def test_ensembling():
mean_model(graphs[0]),
(1.2 * lj(graphs[0]) + 5.7 * morse(graphs[0])) / (1.2 + 5.7),
)


def test_pre_fit():
model = LennardJones()
model.pre_fit(graphs)

with pytest.warns(
UserWarning,
match="This model has already been pre-fitted",
):
model.pre_fit(graphs)

batch = to_batch(graphs)
batch.pop("energy") # type: ignore
with pytest.warns(
UserWarning,
match="The training data doesn't contain energies.",
):
LennardJones().pre_fit(batch)

for ret_value in True, False:
# make sure energy transform is not called if return from _extra_pre_fit
class DummyModel(GraphPESModel):
def __init__(self):
super().__init__(energy_transform=PerAtomShift())

def predict_local_energies(
self, graph: AtomicGraph
) -> torch.Tensor:
return torch.ones(number_of_atoms(graph))

def _extra_pre_fit(self, graphs: AtomicGraphBatch) -> bool | None:
return ret_value # noqa: B023

model = DummyModel()
assert model.energy_transform.shift[29] == 0
model.pre_fit(graphs)
if ret_value:
assert model.energy_transform.shift[29] == 0
else:
assert model.energy_transform.shift[29] != 0
18 changes: 18 additions & 0 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,21 @@ def test_inverse(transform: Transform):

inverse = transform.inverse(y, graph)
assert x.equal(inverse)


def test_scale():
x = torch.tensor([1.3, 2.1]) # per atom property
std = x.std()

transform = Scale()

# when fitting to source, output of x should have unit std
transform.fit_to_source(x, graph) # type: ignore
y = transform(x, graph)
assert torch.allclose(y, x / std)

# when fitting to target, output of ones should have std of x
transform.fit_to_target(x, graph) # type: ignore
ones = torch.ones_like(x)
y = transform(ones, graph)
assert torch.allclose(y, ones * std)

0 comments on commit 70cd3b5

Please sign in to comment.