From 257961d6ecee7987c325b74d41bb030e91a0d275 Mon Sep 17 00:00:00 2001
From: John Gardner
Date: Mon, 1 Apr 2024 07:59:29 +0200
Subject: [PATCH] changes

---
 .github/workflows/tests.yaml      | 14 +++++++++++
 src/graph_pes/core.py             | 25 ++++++++++---------
 src/graph_pes/models/painn.py     |  6 ++++-
 src/graph_pes/models/pairwise.py  | 40 ++++++++++++++++++++-----------
 src/graph_pes/models/schnet.py    |  4 +++-
 src/graph_pes/models/tensornet.py |  4 +++-
 src/graph_pes/nn.py               | 16 -------------
 src/graph_pes/training.py         |  2 +-
 src/graph_pes/transform.py        | 31 +++++++++++++++++++++++-
 9 files changed, 96 insertions(+), 46 deletions(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 8582d32e..85fa219b 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -3,6 +3,20 @@ on: [push]
 permissions:
   contents: read
 jobs:
+  formatting:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v3
+        with:
+          python-version: 3.9
+      - name: update pip
+        run: pip install --upgrade pip
+      - name: Install ruff
+        run: pip install ruff
+      - name: Run ruff
+        run: ruff check
+
   tests:
     runs-on: ubuntu-latest
     steps:
diff --git a/src/graph_pes/core.py b/src/graph_pes/core.py
index 67eb3b7c..0b227156 100644
--- a/src/graph_pes/core.py
+++ b/src/graph_pes/core.py
@@ -53,6 +53,20 @@ class GraphPESModel(nn.Module, ABC):
     per-species scale and shift.
     """
 
+    def __init__(self, energy_transform: Transform | None = None):
+        super().__init__()
+
+        self.energy_transform: Transform = (
+            PerAtomStandardScaler()
+            if energy_transform is None
+            else energy_transform
+        )
+
+        # save as a buffer so that this is saved and loaded
+        # with the model
+        self._has_been_pre_fit: Tensor
+        self.register_buffer("_has_been_pre_fit", torch.tensor(False))
+
     def forward(self, graph: AtomicGraph) -> Tensor:
         """
         Calculate the total energy of the structure.
@@ -89,17 +103,6 @@ def predict_local_energies(self, graph: AtomicGraph) -> Tensor:
             The per-atom local energy predictions with shape :code:`(N,)`.
         """
 
-    def __init__(self, energy_transform: Transform | None = None):
-        super().__init__()
-        self.energy_transform: Transform = (
-            PerAtomStandardScaler()
-            if energy_transform is None
-            else energy_transform
-        )
-
-        self._has_been_pre_fit: Tensor
-        self.register_buffer("_has_been_pre_fit", torch.tensor(False))
-
     def pre_fit(
         self,
         graphs: LabelledBatch | Sequence[LabelledGraph],
diff --git a/src/graph_pes/models/painn.py b/src/graph_pes/models/painn.py
index 9764af73..f89de46b 100644
--- a/src/graph_pes/models/painn.py
+++ b/src/graph_pes/models/painn.py
@@ -9,6 +9,7 @@
     number_of_atoms,
 )
 from graph_pes.nn import MLP, HaddamardProduct, PerElementEmbedding
+from graph_pes.transform import Transform
 from torch import Tensor, nn
 
 from .distances import Bessel, PolynomialEnvelope
@@ -181,6 +182,8 @@ class PaiNN(GraphPESModel):
         The number of (interaction + update) layers to use.
     cutoff
         The cutoff distance for the radial features.
+    energy_transform
+        The energy transform to use (defaults to PerAtomStandardScaler)
     """  # noqa: E501
 
     def __init__(
@@ -189,8 +192,9 @@ def __init__(
         self,
         radial_features: int = 20,
         layers: int = 3,
         cutoff: float = 5.0,
+        energy_transform: Transform | None = None,
     ):
-        super().__init__()
+        super().__init__(energy_transform)
         self.internal_dim = internal_dim
         self.layers = layers
         self.interactions: list[Interaction] = nn.ModuleList(
diff --git a/src/graph_pes/models/pairwise.py b/src/graph_pes/models/pairwise.py
index 2cbd4de5..e01c966d 100644
--- a/src/graph_pes/models/pairwise.py
+++ b/src/graph_pes/models/pairwise.py
@@ -12,7 +12,7 @@
     sum_over_neighbours,
 )
 from graph_pes.nn import PerElementParameter
-from graph_pes.transform import PerAtomShift
+from graph_pes.transform import PerAtomShift, Transform
 from graph_pes.util import pytorch_repr, to_significant_figures
 from jaxtyping import Float
 from torch import Tensor
@@ -119,9 +119,16 @@ class LennardJones(PairPotential):
         :align: center
     """
 
-    def __init__(self, epsilon: float = 0.1, sigma: float = 1.0):
+    def __init__(
+        self,
+        epsilon: float = 0.1,
+        sigma: float = 1.0,
+        energy_transform: Transform | None = None,
+    ):
         # epsilon is a scaling term, so only need to learn a shift
-        super().__init__(energy_transform=PerAtomShift())
+        if energy_transform is None:
+            energy_transform = PerAtomShift()
+        super().__init__(energy_transform)
 
         self._log_epsilon = torch.nn.Parameter(torch.tensor(epsilon).log())
         self._log_sigma = torch.nn.Parameter(torch.tensor(sigma).log())
@@ -199,10 +206,17 @@ class Morse(PairPotential):
         :align: center
     """
 
-    def __init__(self, D: float = 0.1, a: float = 5.0, r0: float = 1.5):
+    def __init__(
+        self,
+        D: float = 0.1,
+        a: float = 5.0,
+        r0: float = 1.5,
+        energy_transform: Transform | None = None,
+    ):
         # D is a scaling term, so only need to learn a shift
-        # parameter (rather than a shift and scale)
-        super().__init__(energy_transform=PerAtomShift())
+        if energy_transform is None:
+            energy_transform = PerAtomShift()
+        super().__init__(energy_transform)
 
         self._log_D = torch.nn.Parameter(torch.tensor(D).log())
         self._log_a = torch.nn.Parameter(torch.tensor(a).log())
@@ -220,9 +234,7 @@ def a(self):
     def r0(self):
         return self._log_r0.exp()
 
-    def interaction(
-        self, r: torch.Tensor, Z_i: torch.Tensor, Z_j: torch.Tensor
-    ):
+    def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None):
         """
         Evaluate the pair potential.
@@ -304,16 +316,16 @@ def interaction(self, r: Tensor, Z_i: Tensor, Z_j: Tensor) -> Tensor:
         Z_j : torch.Tensor
             The atomic numbers of the neighbours.
""" - cross_interaction = Z_i != Z_j + cross_interaction = Z_i != Z_j # (E) - sigma_j = self.sigma[Z_j].squeeze() - sigma_i = self.sigma[Z_i].squeeze() + sigma_j = self.sigma[Z_j].squeeze() # (E) + sigma_i = self.sigma[Z_i].squeeze() # (E) nu = self.nu[Z_i, Z_j].squeeze() if self.modulate_distances else 1 sigma = torch.where( cross_interaction, nu * (sigma_i + sigma_j) / 2, sigma_i, - ).clamp(min=0.2) + ).clamp(min=0.2) # (E) epsilon_i = self.epsilon[Z_i].squeeze() epsilon_j = self.epsilon[Z_j].squeeze() @@ -322,7 +334,7 @@ def interaction(self, r: Tensor, Z_i: Tensor, Z_j: Tensor) -> Tensor: cross_interaction, zeta * (epsilon_i * epsilon_j).sqrt(), epsilon_i, - ).clamp(min=0.00) + ).clamp(min=0.00) # (E) x = sigma / r return 4 * epsilon * (x**12 - x**6) diff --git a/src/graph_pes/models/schnet.py b/src/graph_pes/models/schnet.py index 4a8b0d2d..15fc9c89 100644 --- a/src/graph_pes/models/schnet.py +++ b/src/graph_pes/models/schnet.py @@ -4,6 +4,7 @@ from graph_pes.core import GraphPESModel from graph_pes.data import AtomicGraph, neighbour_distances from graph_pes.nn import MLP, PerElementEmbedding, ShiftedSoftplus +from graph_pes.transform import Transform from torch import Tensor, nn from torch_geometric.nn import MessagePassing @@ -213,8 +214,9 @@ def __init__( cutoff: float = 5.0, layers: int = 3, expansion: type[DistanceExpansion] | None = None, + energy_transform: Transform | None = None, ): - super().__init__() + super().__init__(energy_transform) if expansion is None: expansion = GaussianSmearing diff --git a/src/graph_pes/models/tensornet.py b/src/graph_pes/models/tensornet.py index 4b025ddb..446993b8 100644 --- a/src/graph_pes/models/tensornet.py +++ b/src/graph_pes/models/tensornet.py @@ -1,6 +1,7 @@ from __future__ import annotations import torch +from graph_pes.transform import Transform from torch import Tensor, nn from ..core import GraphPESModel @@ -320,8 +321,9 @@ def __init__( embedding_size: int = 32, cutoff: float = 5.0, layers: int = 1, + energy_transform: Transform | None = None, ): - super().__init__() + super().__init__(energy_transform) self.embedding = Embedding(radial_features, embedding_size, cutoff) self.interactions: list[Interaction] = nn.ModuleList( [ diff --git a/src/graph_pes/nn.py b/src/graph_pes/nn.py index a6d52dd6..979ea71d 100644 --- a/src/graph_pes/nn.py +++ b/src/graph_pes/nn.py @@ -131,16 +131,6 @@ def prod(iterable): return reduce(lambda x, y: x * y, iterable, 1) -# # Metaclass to combine _TensorMeta and the instance check override -# # for PerElementParameter (see equivalent in torch.nn.Parameter) -# class _PerElementParameterMeta(torch._C._TensorMeta): -# def __instancecheck__(self, instance): -# return super().__instancecheck__(instance) or ( -# isinstance(instance, torch.Tensor) -# and getattr(instance, "_is_per_element_param", False) -# ) - - class PerElementParameter(torch.nn.Parameter): def __new__( cls, data: Tensor, requires_grad: bool = True @@ -224,12 +214,6 @@ def __instancecheck__(self, instance) -> bool: @torch.no_grad() def __repr__(self) -> str: - # TODO implement custom repr for different shapes - # 1 index dimension: - # table with header column for Z - # 2 index dimensions with singleton further shape: - # 2D table with Z as both header row and header column - # print numel otherwise if len(self._accessed_Zs) == 0: return ( f"PerElementParameter(index_dims={self._index_dims}, " diff --git a/src/graph_pes/training.py b/src/graph_pes/training.py index d4177aa5..910676b1 100644 --- a/src/graph_pes/training.py +++ 
@@ -320,7 +320,7 @@ def Adam(
     energy_transform_overrides: dict | None = None,
 ) -> Callable[[GraphPESModel], torch.optim.Optimizer]:
     if energy_transform_overrides is None:
-        energy_transform_overrides = {"weight": 1.0}
+        energy_transform_overrides = {"weight_decay": 0.0}
 
     def adam(model: GraphPESModel) -> torch.optim.Optimizer:
         model_params = [
diff --git a/src/graph_pes/transform.py b/src/graph_pes/transform.py
index 64367cee..ac477370 100644
--- a/src/graph_pes/transform.py
+++ b/src/graph_pes/transform.py
@@ -3,6 +3,7 @@
 from abc import ABC, abstractmethod
 
 import torch
+from ase.data import atomic_numbers as symbol_to_Z
 from torch import Tensor, nn
 
 from graph_pes.data import (
@@ -318,7 +319,7 @@ def inverse(self, y: Tensor, graph: AtomicGraph) -> Tensor:
         return left_aligned_add(y, shifts)
 
     def __repr__(self):
-        return f"PerAtomShift({self.shift})"
+        return f"{self.__class__.__name__}({self.shift})"
 
 
 class PerAtomScale(Transform):
@@ -458,6 +459,34 @@ def PerAtomStandardScaler(trainable: bool = True) -> Transform:
     return Chain([PerAtomScale(trainable), PerAtomShift(trainable)])
 
 
+class FixedEnergyOffsets(PerAtomShift):
+    r"""
+    A convenience :class:`PerAtomShift` transform that applies fixed,
+    species-dependent energy offsets.
+
+    Parameters
+    ----------
+    offsets
+        The fixed shifts to apply to each species.
+
+    Examples
+    --------
+    >>> FixedEnergyOffsets(H=-0.1, Pt=-13.14)
+    """
+
+    def __init__(self, **offsets: float):
+        super().__init__(trainable=False)
+        for symbol, offset in offsets.items():
+            assert symbol in symbol_to_Z, f"Unknown element: {symbol}"
+            self.shift[symbol_to_Z[symbol]] = offset
+
+    def fit_to_source(self, x: Tensor, graphs: AtomicGraphBatch):
+        pass
+
+    def fit_to_target(self, y: Tensor, graphs: AtomicGraphBatch):
+        pass
+
+
 class Scale(Transform):
     def __init__(self, trainable: bool = True, scale: float | int = 1.0):
         super().__init__(trainable=trainable)
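
Usage sketch of the API introduced above, assuming the import paths follow the
file layout in this patch; the offset values and the choice of PaiNN are purely
illustrative:

    from graph_pes.models.painn import PaiNN
    from graph_pes.transform import FixedEnergyOffsets

    # pin the per-species energy offsets up front (hypothetical values)
    offsets = FixedEnergyOffsets(H=-13.6, O=-2041.0)

    # GraphPESModel subclasses now accept an optional energy_transform;
    # leaving it as None falls back to the default PerAtomStandardScaler
    model = PaiNN(cutoff=5.0, energy_transform=offsets)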