Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Jul 11, 2024
1 parent 2819ac8 commit 9865eb8
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 5 deletions.
5 changes: 4 additions & 1 deletion src/graph_pes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def predict_local_energies(self, graph: AtomicGraph) -> Tensor:
The per-atom local energy predictions, with shape :code:`(N,)`.
"""

def pre_fit(self, graphs: LabelledGraphDataset | Sequence[LabelledGraph]):
def pre_fit(
self,
graphs: LabelledGraphDataset | Sequence[LabelledGraph] | LabelledBatch,
):
"""
Pre-fit the model to the training data.
Expand Down
18 changes: 14 additions & 4 deletions src/graph_pes/models/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@

from graph_pes.core import GraphPESModel
from graph_pes.data.dataset import LabelledGraphDataset
from graph_pes.graphs.graph_typing import AtomicGraph, LabelledGraph
from graph_pes.graphs.operations import to_batch
from graph_pes.graphs.graph_typing import (
AtomicGraph,
LabelledBatch,
LabelledGraph,
)
from graph_pes.graphs.operations import is_batch, to_batch
from graph_pes.models.pre_fit import guess_per_element_mean_and_var
from graph_pes.nn import PerElementParameter

Expand Down Expand Up @@ -53,7 +57,10 @@ def predict_unscaled_energies(self, graph: AtomicGraph) -> torch.Tensor:
The unscaled, local energies with shape ``(n_atoms,)``.
"""

def pre_fit(self, graphs: LabelledGraphDataset | Sequence[LabelledGraph]):
def pre_fit(
self,
graphs: LabelledGraphDataset | Sequence[LabelledGraph] | LabelledBatch,
):
_was_already_prefit = self._has_been_pre_fit.item()
super().pre_fit(graphs)

Expand All @@ -63,7 +70,10 @@ def pre_fit(self, graphs: LabelledGraphDataset | Sequence[LabelledGraph]):
if isinstance(graphs, LabelledGraphDataset):
graphs = list(graphs)

graph_batch = to_batch(graphs)
if isinstance(graphs, dict) and is_batch(graphs):
graph_batch = graphs
else:
graph_batch = to_batch(graphs) # type: ignore

# use Ridge regression to calculate standard deviations in the
# per-element contributions to the total energy
Expand Down
31 changes: 31 additions & 0 deletions tests/test_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
from ase.build import molecule
from graph_pes.data.io import to_atomic_graph
from graph_pes.graphs.graph_typing import AtomicGraph
from graph_pes.models.scaling import UnScaledPESModel


class StupidModel(UnScaledPESModel):
def predict_unscaled_energies(self, graph: AtomicGraph) -> torch.Tensor:
return torch.ones_like(graph["atomic_numbers"]).float()


def test_scaling():
model = StupidModel()

# set the scaling terms for H and C
with torch.no_grad():
model._per_element_scaling[1] = 0.5
model._per_element_scaling[6] = 2.0

graph = to_atomic_graph(molecule("CH4"), cutoff=3)
assert torch.equal(
graph["atomic_numbers"],
torch.tensor([6, 1, 1, 1, 1]),
)

unscaled = model.predict_unscaled_energies(graph)
assert torch.equal(unscaled, torch.ones(5))

local = model.predict_local_energies(graph)
assert torch.equal(local, torch.tensor([2.0, 0.5, 0.5, 0.5, 0.5]))

0 comments on commit 9865eb8

Please sign in to comment.