Skip to content

Commit

Permalink
adopt for SchNet
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Jul 11, 2024
1 parent b0c59b9 commit 2819ac8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
15 changes: 9 additions & 6 deletions src/graph_pes/models/scaling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import Sequence
Expand All @@ -12,7 +14,7 @@
from graph_pes.nn import PerElementParameter


class ScaledPESModel(GraphPESModel, ABC):
class UnScaledPESModel(GraphPESModel, ABC):
"""
An abstract base class for all PES models implementations that are best
suited to making raw predictions that with ~unit variance. By inheriting
Expand Down Expand Up @@ -66,11 +68,12 @@ def pre_fit(self, graphs: LabelledGraphDataset | Sequence[LabelledGraph]):
# use Ridge regression to calculate standard deviations in the
# per-element contributions to the total energy
if "energy" in graph_batch:
_, variances = guess_per_element_mean_and_var(
graph_batch["energy"], graph_batch
)
for Z, var in variances.items():
self._per_element_scaling[Z] = var**0.5
with torch.no_grad():
_, variances = guess_per_element_mean_and_var(
graph_batch["energy"], graph_batch
)
for Z, var in variances.items():
self._per_element_scaling[Z] = var**0.5

else:
model_name = self.__class__.__name__
Expand Down
6 changes: 3 additions & 3 deletions src/graph_pes/models/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from torch import Tensor, nn
from torch_geometric.nn import MessagePassing

from graph_pes.core import GraphPESModel
from graph_pes.graphs import AtomicGraph
from graph_pes.graphs.operations import neighbour_distances
from graph_pes.models.scaling import UnScaledPESModel
from graph_pes.nn import MLP, PerElementEmbedding, ShiftedSoftplus

from .distances import DistanceExpansion, GaussianSmearing
Expand Down Expand Up @@ -167,7 +167,7 @@ def forward(
return self.mlp(h)


class SchNet(GraphPESModel):
class SchNet(UnScaledPESModel):
r"""
The `SchNet <https://arxiv.org/abs/1706.08566>`_ model: a pairwise, scalar,
message passing GNN.
Expand Down Expand Up @@ -237,7 +237,7 @@ def __init__(
activation=ShiftedSoftplus(),
)

def predict_local_energies(self, graph: AtomicGraph) -> torch.Tensor:
def predict_unscaled_energies(self, graph: AtomicGraph) -> torch.Tensor:
h = self.chemical_embedding(graph["atomic_numbers"])

for interaction in self.interactions:
Expand Down

0 comments on commit 2819ac8

Please sign in to comment.