From 66009ed22983f3bce2584c200175954a61ea80c1 Mon Sep 17 00:00:00 2001 From: John Gardner Date: Tue, 4 Jun 2024 08:04:47 +0200 Subject: [PATCH] simplify --- src/graph_pes/core.py | 266 ++++++++---------------------- src/graph_pes/data/__init__.py | 3 + src/graph_pes/data/keys.py | 3 + src/graph_pes/data/utils.py | 4 +- src/graph_pes/loss.py | 192 ++++++++++++++------- src/graph_pes/models/distances.py | 4 +- src/graph_pes/models/offsets.py | 44 +++++ src/graph_pes/models/painn.py | 4 +- src/graph_pes/models/pairwise.py | 34 ++-- src/graph_pes/models/schnet.py | 4 +- src/graph_pes/models/tensornet.py | 4 +- src/graph_pes/models/zoo.py | 1 + src/graph_pes/training.py | 132 +++++++-------- src/graph_pes/transform.py | 14 +- src/graph_pes/util.py | 10 +- tests/test_loss.py | 4 +- tests/test_models.py | 4 +- 17 files changed, 353 insertions(+), 374 deletions(-) create mode 100644 src/graph_pes/models/offsets.py diff --git a/src/graph_pes/core.py b/src/graph_pes/core.py index 0b227156..ed84eed4 100644 --- a/src/graph_pes/core.py +++ b/src/graph_pes/core.py @@ -2,7 +2,7 @@ import warnings from abc import ABC, abstractmethod -from typing import Literal, Sequence, overload +from typing import Sequence, overload import torch from torch import Tensor, nn @@ -18,7 +18,6 @@ to_batch, ) from graph_pes.nn import PerElementParameter -from graph_pes.transform import PerAtomStandardScaler, Transform from graph_pes.util import differentiate, require_grad @@ -37,32 +36,12 @@ class GraphPESModel(nn.Module, ABC): and returns a per-atom prediction of the local energy. For a simple example, see the :class:`PairPotential ` `implementation <_modules/graph_pes/models/pairwise.html#PairPotential>`_. - - Under the hood, :class:`GraphPESModel`\ s pass the local energy predictions - through a :class:`graph_pes.transform.Transform` before summing them to - get the total energy. By default, this learns a per-species local-energy - scale and shift. This can be changed by directly altering passing a - different :class:`~graph_pes.transform.Transform` to this base class's - constructor. - - Parameters - ---------- - energy_transform - The transform to apply to the local energy predictions before summing - them to get the total energy. By default, this is a learnable - per-species scale and shift. """ - def __init__(self, energy_transform: Transform | None = None): + def __init__(self): super().__init__() - self.energy_transform: Transform = ( - PerAtomStandardScaler() - if energy_transform is None - else energy_transform - ) - - # save as a buffer so that this is saved and loaded + # save as a buffer so that this is de/serialized # with the model self._has_been_pre_fit: Tensor self.register_buffer("_has_been_pre_fit", torch.tensor(False)) @@ -81,16 +60,16 @@ def forward(self, graph: AtomicGraph) -> Tensor: Tensor The total energy of the structure/s. If the input is a batch of graphs, the result will be a tensor of shape :code:`(B,)`, - where :code:`B` is the batch size, else a scalar. + where :code:`B` is the batch size. Otherwise, a scalar tensor + will be returned. """ local_energies = self.predict_local_energies(graph).squeeze() - transformed = self.energy_transform(local_energies, graph) - return sum_per_structure(transformed, graph) + return sum_per_structure(local_energies, graph) @abstractmethod def predict_local_energies(self, graph: AtomicGraph) -> Tensor: """ - Predict the (non-transformed) local energy for each atom in the graph. + Predict the local energy for each atom in the graph. 
Parameters ---------- @@ -100,108 +79,49 @@ def predict_local_energies(self, graph: AtomicGraph) -> Tensor: Returns ------- Tensor - The per-atom local energy predictions with shape :code:`(N,)`. + The per-atom local energy predictions, with shape :code:`(N,)`. """ - def pre_fit( - self, - graphs: LabelledBatch | Sequence[LabelledGraph], - relative: bool = True, - ): + # TODO: move away from sequence approach to more general dataloader/dataset + def pre_fit(self, graphs: LabelledBatch | Sequence[LabelledGraph]): """ Pre-fit the model to the training data. - By default, this fits the :code:`energy_transform` to the energies - of the training data. To add additional pre-fitting steps, override - :meth:`_extra_pre_fit`. As an example of this, see the + This method detects the unique atomic numbers in the training data + and registers these with all of the model's per-element parameters + to ensure correct parameter counting. + + Additionally, this method performs any model-specific pre-fitting + steps, as implemented in :meth:`model_specific_pre_fit`. + + As an example of a model-specific pre-fitting process, see the :class:`~graph_pes.models.pairwise.LennardJones` `implementation <_modules/graph_pes/models/pairwise.html#LennardJones>`__. If the model has already been pre-fitted, subsequent calls to - :meth:`pre_fit` will be ignored. + :meth:`pre_fit` will be ignored (and a warning will be raised). Parameters ---------- graphs The training data. - relative - Whether to account for the current energy predictions when fitting - the energy transform. - - Example - ------- - Without any pre-fitting, models *tend* to predict energies that are - close to 0: - - >>> from graph_pes.models.zoo import LennardJones - >>> model = LennardJones() - >>> model - LennardJones( - (epsilon): 0.1 - (sigma): 1.0 - (energy_transform): PerAtomShift( - Cu : [0.], - ) - ) - >>> from graph_pes.analysis import parity_plot - >>> parity_plot(model, val, units="eV") - - .. image:: /_static/lj-parity-raw.svg - :align: center - - Pre-fitting a model's :code:`energy_transform` to the training data - (together with any other steps defined in :meth:`_extra_pre_fit`) - dramatically improves the predictions for free: - - >>> from graph_pes.data import to_batch - >>> model = LennardJones() - >>> model.pre_fit(to_batch(train_set), relative=False) - >>> model - LennardJones( - (epsilon): 0.1 - (sigma): 2.27 - (energy_transform): PerAtomShift( - Cu : [3.5229], - ) - ) - >>> parity_plot(model, val, units="eV") - - .. image:: /_static/lj-parity-prefit.svg - :align: center - - Accounting for the model's current predictions when fitting the - energy transforms (the default behaviour) leads to even better - pre-conditioned models: - - >>> model = LennardJones() - >>> model.pre_fit(to_batch(train_set), relative=True) - >>> model - LennardJones( - (epsilon): 0.1 - (sigma): 2.27 - (energy_transform): PerAtomShift( - Cu : [2.9238], - ) - ) - >>> parity_plot(model, val, units="eV") - - .. image:: /_static/lj-parity-relative.svg - :align: center """ + if self._has_been_pre_fit.item(): + model_name = self.__class__.__name__ warnings.warn( - "This model has already been pre-fitted. " - "Subsequent calls to pre_fit will be ignored.", + f"This model ({model_name}) has already been pre-fitted. 
" + "This, and any subsequent, call to pre_fit will be ignored.", stacklevel=2, ) return - self._has_been_pre_fit.fill_(True) if isinstance(graphs, Sequence): graphs = to_batch(graphs) - stop_here = self._extra_pre_fit(graphs) + self._has_been_pre_fit.fill_(True) + self.model_specific_pre_fit(graphs) # register all per-element parameters for param in self.parameters(): @@ -210,53 +130,40 @@ def pre_fit( torch.unique(graphs[keys.ATOMIC_NUMBERS]).tolist() ) - if stop_here: - return - - if "energy" not in graphs: - warnings.warn( - "The training data doesn't contain energies. " - "The energy transform will not be fitted.", - stacklevel=2, - ) - return - - target = graphs["energy"] - if relative: - with torch.no_grad(): - target = graphs["energy"] - self(graphs) - - self.energy_transform.fit_to_target(target, graphs) - - def _extra_pre_fit(self, graphs: LabelledBatch) -> bool | None: + def model_specific_pre_fit(self, graphs: LabelledBatch) -> None: """ Override this method to perform additional pre-fitting steps. - Return ``True`` to surpress the default pre-fitting of the energy - transform implemented on this base class. + + As an example, see the + :class:`~graph_pes.models.pairwise.LennardJones` + `implementation + <_modules/graph_pes/models/pairwise.html#LennardJones>`__. + + Parameters + ---------- + graphs + The training data. """ # add type hints to play nicely with mypy def __call__(self, graph: AtomicGraph) -> Tensor: return super().__call__(graph) - def __add__(self, other: GraphPESModel) -> Ensemble: - return Ensemble([self, other], aggregation="sum") + def __add__(self, other: GraphPESModel | AdditionModel) -> AdditionModel: + if isinstance(other, AdditionModel): + return AdditionModel([self, *other.models]) + return AdditionModel([self, other]) -class Ensemble(GraphPESModel): +class AdditionModel(GraphPESModel): """ - An ensemble of :class:`GraphPESModel` models. + A wrapper that makes predictions as the sum of the predictions + of its constituent models. Parameters ---------- models - the models to ensemble. - aggregation - the method of aggregating the predictions of the models. - weights - scalar weights for combining each model's prediction. - trainable_weights - whether the weights are trainable. + the models to sum. Examples -------- @@ -264,63 +171,31 @@ class Ensemble(GraphPESModel): .. code-block:: python - from graph_pes.models.pairwise import LennardJones - from graph_pes.models.schnet import SchNet - from graph_pes.core import Ensemble - - # create an ensemble of two models - # equivalent to Ensemble([LennardJones(), SchNet()], aggregation="sum") - ensemble = LennardJones() + SchNet() - - Use several models to get an average prediction: + from graph_pes.models.zoo import LennardJones, SchNet + from graph_pes.core import AdditionModel - .. code-block:: python - - models = ... # load/train your models - ensemble = Ensemble(models, aggregation="mean") - predictions = ensemble.predict(test_graphs) - ... 
+ # create a model that sums two models + # equivalent to LennardJones() + SchNet() + model = AdditionModel([LennardJones(), SchNet()]) """ - def __init__( - self, - models: list[GraphPESModel], - aggregation: Literal["mean", "sum"] = "mean", - weights: list[float] | None = None, - trainable_weights: bool = False, - ): + def __init__(self, models: Sequence[GraphPESModel]): super().__init__() self.models: list[GraphPESModel] = nn.ModuleList(models) # type: ignore - self.aggregation = aggregation - self.weights = nn.Parameter( - torch.tensor( - weights or [1.0] * len(models), requires_grad=trainable_weights - ) - ) - # use the energy summation of each model separately - self.energy_summation = None + def predict_local_energies(self, graph: AtomicGraph) -> Tensor: + predictions = torch.stack( + [model.predict_local_energies(graph) for model in self.models] + ) # (atoms, models) + return torch.sum(predictions, dim=0) # (atoms,) sum over models - def predict_local_energies(self, graph: AtomicGraph): - raise NotImplementedError( - "Ensemble models don't have a single local energy prediction." - ) - - def forward(self, graph: AtomicGraph): - predictions: Tensor = torch.stack( - [w * model(graph) for w, model in zip(self.weights, self.models)] - ).sum(dim=0) - if self.aggregation == "mean": - return predictions / self.weights.sum() - else: - return predictions + def model_specific_pre_fit(self, graphs: LabelledBatch) -> None: + for model in self.models: + model.model_specific_pre_fit(graphs) def __repr__(self): - info = [str(self.models), f"aggregation={self.aggregation}"] - if self.weights.requires_grad: - info.append(f"weights={self.weights.tolist()}") - info = "\n ".join(info) - return f"Ensemble(\n {info}\n)" + model_info = "\n ".join(map(str, self.models)) + return f"{self.__class__.__name__}(\n {model_info}\n)" @overload @@ -329,10 +204,7 @@ def get_predictions( graph: AtomicGraph | AtomicGraphBatch | Sequence[AtomicGraph], *, training: bool = False, -) -> dict[keys.LabelKey, Tensor]: - """test""" - - +) -> dict[keys.LabelKey, Tensor]: ... @overload def get_predictions( model: GraphPESModel, @@ -340,13 +212,7 @@ def get_predictions( *, properties: Sequence[keys.LabelKey], training: bool = False, -) -> dict[keys.LabelKey, Tensor]: - """ - test - """ - ... - - +) -> dict[keys.LabelKey, Tensor]: ... @overload def get_predictions( model: GraphPESModel, @@ -354,10 +220,7 @@ def get_predictions( *, property: keys.LabelKey, training: bool = False, -) -> Tensor: - """test""" - - +) -> Tensor: ... def get_predictions( model: GraphPESModel, graph: AtomicGraph | AtomicGraphBatch | Sequence[AtomicGraph], @@ -434,7 +297,7 @@ def get_predictions( # use the autograd machinery to auto-magically # calculate forces and stress from the energy - with require_grad(graph[keys._POSITIONS]), require_grad(change_to_cell): + with require_grad(graph[keys._POSITIONS], change_to_cell): energy = model(graph) if keys.ENERGY in properties: @@ -444,6 +307,7 @@ def get_predictions( dE_dR = differentiate(energy, graph[keys._POSITIONS]) predictions[keys.FORCES] = -dE_dR + # TODO: check stress vs virial common definition if keys.STRESS in properties: stress = differentiate(energy, change_to_cell) predictions[keys.STRESS] = stress diff --git a/src/graph_pes/data/__init__.py b/src/graph_pes/data/__init__.py index efd88852..200fc350 100644 --- a/src/graph_pes/data/__init__.py +++ b/src/graph_pes/data/__init__.py @@ -1,3 +1,5 @@ +# TODO: split into batching, io etc. 
+# in readiness for moving to readable dataset approach from __future__ import annotations import warnings @@ -81,6 +83,7 @@ class _AtomicGraph_Impl(dict): + # TODO: remove? def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/graph_pes/data/keys.py b/src/graph_pes/data/keys.py index ab6c24e9..e4a7c2f0 100644 --- a/src/graph_pes/data/keys.py +++ b/src/graph_pes/data/keys.py @@ -1,3 +1,6 @@ +# deliberately not using future imports here to appease torchscript? +# TODO: check if this is necessary + from typing import TYPE_CHECKING, Literal # graph properties diff --git a/src/graph_pes/data/utils.py b/src/graph_pes/data/utils.py index 38708eb3..3f2fce5c 100644 --- a/src/graph_pes/data/utils.py +++ b/src/graph_pes/data/utils.py @@ -10,7 +10,9 @@ def random_split( - sequence: Sequence[E], lengths: Sequence[int], seed: int | None = None + sequence: Sequence[E], + lengths: Sequence[int], + seed: int | None = None, ) -> list[list[E]]: """ Randomly split `sequence` into sub-sequences according to `lengths`. diff --git a/src/graph_pes/loss.py b/src/graph_pes/loss.py index ffe0bd89..7a0977ab 100644 --- a/src/graph_pes/loss.py +++ b/src/graph_pes/loss.py @@ -1,19 +1,21 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, NamedTuple import torch from torch import Tensor, nn from .data import LabelledBatch, keys -from .transform import Identity, Transform +from .transform import divide_per_atom class Loss(nn.Module): r""" - Measure the discrepancy between predictions and labels. + Measure the discrepancy between predictions and labels for a given property. - Often, it is convenient to apply some well known loss function, + + + Often, it is convenient to apply some well known loss metric, e.g. `MSELoss`, to a transformed version of the predictions and labels, e.g. normalisation, such that the loss value takes on "nice" values, and that the resulting gradients and parameter updates are well-behaved. @@ -54,13 +56,10 @@ def __init__( self, label: keys.LabelKey, metric: Callable[[Tensor, Tensor], Tensor] | None = None, - transform: Transform | None = None, ): super().__init__() self.property_key: keys.LabelKey = label self.metric = MAE() if metric is None else metric - self.transform = transform or Identity() - self.transform.trainable = False # add type hints to play nicely with mypy def __call__( @@ -87,56 +86,58 @@ def forward( """ return self.metric( - self.transform(predictions[self.property_key], graphs), - self.transform(graphs[self.property_key], graphs), + predictions[self.property_key], + graphs[self.property_key], ) - def fit_transform(self, graphs: LabelledBatch): - """ - Fit the transform to the target labels. + @property + def name(self) -> str: + """Get the name of this loss for logging purposes.""" + return f"{self.property_key}_{get_metric_name(self.metric)}" - Parameters - ---------- - graphs - The graphs containing the labels. 
- """ + ## Methods for creating weighted losses ## - self.transform.fit_to_target(graphs[self.property_key], graphs) + def __mul__(self, weight: float | int) -> TotalLoss: + if not isinstance(weight, (int, float)): + raise TypeError(f"Cannot multiply Loss and {type(weight)}") - @property - def name(self) -> str: - # if metric is a class, we want the class name otherwise we want - # the function name, all without the word "loss" in it - return ( - getattr( - self.metric, - "__name__", - self.metric.__class__.__name__, - ) - .lower() - .replace("loss", "") - ) + return TotalLoss([self], [weight]) + + def __rmul__(self, weight: float) -> TotalLoss: + if not isinstance(weight, (int, float)): + raise TypeError(f"Cannot multiply Loss and {type(weight)}") + + return TotalLoss([self], [weight]) - def __mul__(self, other: float) -> WeightedLoss: - return WeightedLoss([self], [other]) + def __truediv__(self, weight: float | int) -> TotalLoss: + if not isinstance(weight, (int, float)): + raise TypeError(f"Cannot divide Loss and {type(weight)}") - def __rmul__(self, other: float) -> WeightedLoss: - return WeightedLoss([self], [other]) + return TotalLoss([self], [1 / weight]) - def __add__(self, other: Loss | WeightedLoss) -> WeightedLoss: - if isinstance(other, Loss): - return WeightedLoss([self, other], [1, 1]) - elif isinstance(other, WeightedLoss): - return WeightedLoss([self] + other.losses, [1] + other.weights) + def __add__(self, loss: Loss | TotalLoss) -> TotalLoss: + if isinstance(loss, Loss): + return TotalLoss([self, loss], [1, 1]) + elif isinstance(loss, TotalLoss): + return TotalLoss([self] + loss.losses, [1] + loss.weights) else: - raise TypeError(f"Cannot add Loss and {type(other)}") + raise TypeError(f"Cannot add Loss and {type(loss)}") - def __radd__(self, other: Loss | WeightedLoss) -> WeightedLoss: + def __radd__(self, other: Loss | TotalLoss) -> TotalLoss: return self.__add__(other) -# TODO: callable weights -class WeightedLoss(torch.nn.Module): +class SubLossPair(NamedTuple): + loss_value: torch.Tensor + weighted_loss_value: torch.Tensor + + +class TotalLossResult(NamedTuple): + loss_value: torch.Tensor + components: dict[str, SubLossPair] + + +class TotalLoss(torch.nn.Module): r""" A lightweight wrapper around a collection of weighted losses. @@ -149,7 +150,7 @@ class WeightedLoss(torch.nn.Module): .. 
code-block:: python

-        WeightedLoss([Loss("energy"), Loss("forces")], [10, 1])
+        TotalLoss([Loss("energy"), Loss("forces")], weights=[10, 1])

         # is equivalent to
         10 * Loss("energy") + 1 * Loss("forces")

@@ -164,39 +165,108 @@ def __init__(
         self,
         losses: list[Loss],
-        weights: list[float] | None = None,
+        weights: list[float | int] | None = None,
     ):
         super().__init__()
         self.losses: list[Loss] = nn.ModuleList(losses)  # type: ignore
         self.weights = weights or [1.0] * len(losses)

-    def __add__(self, other: WeightedLoss) -> WeightedLoss:
-        return WeightedLoss(
+    def __add__(self, other: TotalLoss) -> TotalLoss:
+        return TotalLoss(
             self.losses + other.losses, self.weights + other.weights
         )

-    def __mul__(self, other: float) -> WeightedLoss:
-        return WeightedLoss(self.losses, [w * other for w in self.weights])
+    def __mul__(self, other: float | int) -> TotalLoss:
+        if not isinstance(other, (int, float)):
+            raise TypeError(f"Cannot multiply TotalLoss and {type(other)}")

-    def __rmul__(self, other: float) -> WeightedLoss:
-        return WeightedLoss(self.losses, [w * other for w in self.weights])
+        return TotalLoss(self.losses, [w * other for w in self.weights])

-    def __true_div__(self, other: float) -> WeightedLoss:
-        return WeightedLoss(self.losses, [w / other for w in self.weights])
+    def __rmul__(self, other: float | int) -> TotalLoss:
+        if not isinstance(other, (int, float)):
+            raise TypeError(f"Cannot multiply TotalLoss and {type(other)}")

-    def fit_transform(self, graphs: LabelledBatch):
-        for loss in self.losses:
-            loss.fit_transform(graphs)
+        return TotalLoss(self.losses, [w * other for w in self.weights])
+
+    def __truediv__(self, other: float | int) -> TotalLoss:
+        if not isinstance(other, (int, float)):
+            raise TypeError(f"Cannot divide TotalLoss and {type(other)}")
+
+        return TotalLoss(self.losses, [w / other for w in self.weights])
+
+    def forward(
+        self,
+        predictions: dict[keys.LabelKey, torch.Tensor],
+        graphs: LabelledBatch,
+    ) -> TotalLossResult:
+        """
+        Computes the total loss value.
+
+        Parameters
+        ----------
+        predictions
+            The predictions from the model.
+        graphs
+            The graphs containing the labels.
+ """ + + total_loss = torch.scalar_tensor(0.0, device=self.device) + components: dict[str, SubLossPair] = {} + + for loss, weight in zip(self.losses, self.weights): + loss_value = loss(predictions, graphs) + weighted_loss_value = loss_value * weight + + total_loss += weighted_loss_value + components[loss.name] = SubLossPair(loss_value, weighted_loss_value) + + return TotalLossResult(total_loss, components) + + # add type hints to appease mypy + def __call__( + self, + predictions: dict[keys.LabelKey, torch.Tensor], + graphs: LabelledBatch, + ) -> TotalLossResult: + return super().__call__(predictions, graphs) + + +class PerAtomEnergyLoss(Loss): + def __init__( + self, + metric: Callable[[Tensor, Tensor], Tensor] | None = None, + ): + super().__init__(keys.ENERGY, metric) def forward( self, predictions: dict[keys.LabelKey, torch.Tensor], graphs: LabelledBatch, ) -> torch.Tensor: - return sum( - w * loss(predictions, graphs) - for w, loss in zip(self.weights, self.losses) - ) # type: ignore + return divide_per_atom(super().forward(predictions, graphs), graphs) + + @property + def name(self) -> str: + return f"per_atom_energy_{get_metric_name(self.metric)}" + + +def get_metric_name(metric: Callable[[Tensor, Tensor], Tensor]) -> str: + # if metric is a function, we want the function's name, otherwise + # we want the metric's class name, all lowercased + # and without the word "loss" in it + + return ( + getattr( + metric, + "__name__", + metric.__class__.__name__, + ) + .lower() + .replace("loss", "") + ) + + +## METRICS ## class RMSE(torch.nn.MSELoss): diff --git a/src/graph_pes/models/distances.py b/src/graph_pes/models/distances.py index 81dbe227..0e13eef9 100644 --- a/src/graph_pes/models/distances.py +++ b/src/graph_pes/models/distances.py @@ -33,6 +33,8 @@ class DistanceExpansion(nn.Module, ABC): def __init__(self, n_features: int, cutoff: float, trainable: bool = True): super().__init__() self.n_features = n_features + # TODO: check serialization - need to register as buffer to be included + # in state_dict? self.cutoff = cutoff self.trainable = trainable @@ -117,7 +119,7 @@ class Bessel(DistanceExpansion): """ def __init__(self, n_features: int, cutoff: float, trainable: bool = True): - super().__init__(n_features, cutoff) + super().__init__(n_features, cutoff, trainable) self.frequencies = nn.Parameter( torch.arange(1, n_features + 1) * math.pi / cutoff, requires_grad=trainable, diff --git a/src/graph_pes/models/offsets.py b/src/graph_pes/models/offsets.py new file mode 100644 index 00000000..ac99179d --- /dev/null +++ b/src/graph_pes/models/offsets.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from graph_pes.core import GraphPESModel +from graph_pes.data import AtomicGraph +from graph_pes.nn import PerElementParameter +from torch import Tensor + + +class EnergyOffset(GraphPESModel): + r""" + A model that predicts energy offsets: + + .. math:: + E(\mathcal{G}) = \sum_i \varepsilon_{Z_i} + + where :math:`\varepsilon_{Z_i}` is the energy offset for atomic species + :math:`Z_i`. + + Parameters + ---------- + fixed_values + A dictionary of fixed energy offsets for each atomic species. + trainable + Whether the energy offsets are trainable parameters. 
+ """ + + def __init__( + self, + values: dict[str, float] | None = None, + trainable: bool = False, + ): + super().__init__() + + if values is None and trainable is False: + raise ValueError("Must provide values or set trainable to True") + + self.offsets = PerElementParameter.of_length( + 1, + default_value=0.0, + requires_grad=trainable, + ) + + def predict_local_energies(self, graph: AtomicGraph) -> Tensor: + return self.shift[graph["atomic_numbers"]] diff --git a/src/graph_pes/models/painn.py b/src/graph_pes/models/painn.py index f89de46b..4bf94009 100644 --- a/src/graph_pes/models/painn.py +++ b/src/graph_pes/models/painn.py @@ -9,7 +9,6 @@ number_of_atoms, ) from graph_pes.nn import MLP, HaddamardProduct, PerElementEmbedding -from graph_pes.transform import Transform from torch import Tensor, nn from .distances import Bessel, PolynomialEnvelope @@ -192,9 +191,8 @@ def __init__( radial_features: int = 20, layers: int = 3, cutoff: float = 5.0, - energy_transform: Transform | None = None, ): - super().__init__(energy_transform) + super().__init__() self.internal_dim = internal_dim self.layers = layers self.interactions: list[Interaction] = nn.ModuleList( diff --git a/src/graph_pes/models/pairwise.py b/src/graph_pes/models/pairwise.py index e01c966d..6726fba5 100644 --- a/src/graph_pes/models/pairwise.py +++ b/src/graph_pes/models/pairwise.py @@ -12,7 +12,6 @@ sum_over_neighbours, ) from graph_pes.nn import PerElementParameter -from graph_pes.transform import PerAtomShift, Transform from graph_pes.util import pytorch_repr, to_significant_figures from jaxtyping import Float from torch import Tensor @@ -123,12 +122,8 @@ def __init__( self, epsilon: float = 0.1, sigma: float = 1.0, - energy_transform: Transform | None = None, ): - # epsilon is a scaling term, so only need to learn a shift - if energy_transform is None: - energy_transform = PerAtomShift() - super().__init__(energy_transform) + super().__init__() self._log_epsilon = torch.nn.Parameter(torch.tensor(epsilon).log()) self._log_sigma = torch.nn.Parameter(torch.tensor(sigma).log()) @@ -156,7 +151,7 @@ def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None): x = self.sigma / r return 4 * self.epsilon * (x**12 - x**6) - def _extra_pre_fit(self, graph: AtomicGraphBatch): + def model_specific_pre_fit(self, graph: AtomicGraphBatch): # set the distance at which the potential is zero to be # close to the minimum pair-wise distance d = torch.quantile(neighbour_distances(graph), 0.01) @@ -164,7 +159,7 @@ def _extra_pre_fit(self, graph: AtomicGraphBatch): def __repr__(self): return pytorch_repr( - "LennardJones", + self.__class__.__name__, _modules={ "epsilon": to_significant_figures(self.epsilon.item(), 3), "sigma": to_significant_figures(self.sigma.item(), 3), @@ -206,17 +201,8 @@ class Morse(PairPotential): :align: center """ - def __init__( - self, - D: float = 0.1, - a: float = 5.0, - r0: float = 1.5, - energy_transform: Transform | None = None, - ): - # D is a scaling term, so only need to learn a shift - if energy_transform is None: - energy_transform = PerAtomShift() - super().__init__(energy_transform) + def __init__(self, D: float = 0.1, a: float = 5.0, r0: float = 1.5): + super().__init__() self._log_D = torch.nn.Parameter(torch.tensor(D).log()) self._log_a = torch.nn.Parameter(torch.tensor(a).log()) @@ -249,7 +235,7 @@ def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None): """ return self.D * (1 - torch.exp(-self.a * (r - self.r0))) ** 2 - def _extra_pre_fit(self, graph: AtomicGraphBatch): + def 
model_specific_pre_fit(self, graph: AtomicGraphBatch): # set the center of the well to be close to the minimum pair-wise # distance: the 10th percentile plus a small offset d = torch.quantile(neighbour_distances(graph), 0.1) + 0.1 @@ -267,6 +253,7 @@ def __repr__(self): ) +# TODO: improve this class class LennardJonesMixture(PairPotential): r""" An extension of the simple :class:`LennardJones` potential to @@ -290,10 +277,13 @@ class LennardJonesMixture(PairPotential): """ def __init__(self, modulate_distances: bool = True): - super().__init__(energy_transform=PerAtomShift()) + super().__init__() + + self.modulate_distances: Tensor self.register_buffer( "modulate_distances", torch.tensor(modulate_distances) ) + self.epsilon = PerElementParameter.of_length(1, default_value=0.1) self.sigma = PerElementParameter.covalent_radii(scaling_factor=0.9) self.nu = PerElementParameter.of_length( @@ -349,4 +339,4 @@ def __repr__(self): if self.modulate_distances: modules["nu"] = self.nu - return pytorch_repr("LennardJonesMixture", _modules=modules) + return pytorch_repr(self.__class__.__name__, _modules=modules) diff --git a/src/graph_pes/models/schnet.py b/src/graph_pes/models/schnet.py index 15fc9c89..4a8b0d2d 100644 --- a/src/graph_pes/models/schnet.py +++ b/src/graph_pes/models/schnet.py @@ -4,7 +4,6 @@ from graph_pes.core import GraphPESModel from graph_pes.data import AtomicGraph, neighbour_distances from graph_pes.nn import MLP, PerElementEmbedding, ShiftedSoftplus -from graph_pes.transform import Transform from torch import Tensor, nn from torch_geometric.nn import MessagePassing @@ -214,9 +213,8 @@ def __init__( cutoff: float = 5.0, layers: int = 3, expansion: type[DistanceExpansion] | None = None, - energy_transform: Transform | None = None, ): - super().__init__(energy_transform) + super().__init__() if expansion is None: expansion = GaussianSmearing diff --git a/src/graph_pes/models/tensornet.py b/src/graph_pes/models/tensornet.py index 446993b8..4b025ddb 100644 --- a/src/graph_pes/models/tensornet.py +++ b/src/graph_pes/models/tensornet.py @@ -1,7 +1,6 @@ from __future__ import annotations import torch -from graph_pes.transform import Transform from torch import Tensor, nn from ..core import GraphPESModel @@ -321,9 +320,8 @@ def __init__( embedding_size: int = 32, cutoff: float = 5.0, layers: int = 1, - energy_transform: Transform | None = None, ): - super().__init__(energy_transform) + super().__init__() self.embedding = Embedding(radial_features, embedding_size, cutoff) self.interactions: list[Interaction] = nn.ModuleList( [ diff --git a/src/graph_pes/models/zoo.py b/src/graph_pes/models/zoo.py index f79e7c2d..7309f80a 100644 --- a/src/graph_pes/models/zoo.py +++ b/src/graph_pes/models/zoo.py @@ -16,4 +16,5 @@ "LennardJonesMixture", ] +# TODO: nicer way to do this? 
ALL_MODELS: list[type[GraphPESModel]] = [globals()[model] for model in __all__] diff --git a/src/graph_pes/training.py b/src/graph_pes/training.py index 910676b1..690995fe 100644 --- a/src/graph_pes/training.py +++ b/src/graph_pes/training.py @@ -2,7 +2,7 @@ import logging from pathlib import Path -from typing import Callable, TypeVar +from typing import Callable, Literal, TypeVar import pytorch_lightning as pl import torch @@ -26,8 +26,7 @@ number_of_structures, to_batch, ) -from .loss import RMSE, Loss, WeightedLoss -from .transform import DividePerAtom, PerAtomScale, PerAtomStandardScaler, Scale +from .loss import RMSE, Loss, PerAtomEnergyLoss, TotalLoss T = TypeVar("T", bound=GraphPESModel) @@ -38,7 +37,7 @@ def train_model( val_data: list[LabelledGraph] | None = None, optimizer: Callable[[T], torch.optim.Optimizer | OptimizerLRSchedulerConfig] | None = None, - loss: WeightedLoss | Loss | None = None, + loss: TotalLoss | Loss | None = None, *, batch_size: int = 32, pre_fit_model: bool = True, @@ -56,10 +55,11 @@ def train_model( train_batch = to_batch(train_data) # process and validate the loss - total_loss = process_loss(loss, train_data[0]) + available_properties = get_existing_properties(train_batch) + total_loss = process_loss(loss, available_properties) training_on = [component.property_key for component in total_loss.losses] for prop in training_on: - if prop not in get_existing_properties(train_data[0]): + if prop not in available_properties: raise ValueError( f"Can't train on {prop} without the corresponding data" ) @@ -88,12 +88,11 @@ def train_model( else None ) - # deal with fitting transforms - if pre_fit_model and keys.ENERGY in training_on: + # pre-fit first, since users might naively change the model's parameters + if pre_fit_model: model.pre_fit(train_batch) - total_loss.fit_transform(train_batch) - # deal with the optimizer + # then create the optimizer if optimizer is None: opt = torch.optim.Adam(model.parameters(), lr=3e-4) else: @@ -112,7 +111,7 @@ def train_model( device = trainer.accelerator.__class__.__name__.replace("Accelerator", "") print(f"Training on : {training_on}") print(f"# of params : {params}") - print(f"Device : {device}") + print(f" Device : {device}") print() # train @@ -133,8 +132,8 @@ def __init__( self, model: GraphPESModel, optimizer: torch.optim.Optimizer | OptimizerLRSchedulerConfig, - total_loss: WeightedLoss, - validation_metrics: dict[str, Loss] | None = None, + total_loss: TotalLoss, + validation_metrics: list[Loss] | None = None, ): super().__init__() self.model = model @@ -145,34 +144,36 @@ def __init__( ] if validation_metrics is None: - validation_metrics = {} + validation_metrics = [] + + existing_loss_names = [l.name for l in total_loss.losses] + if keys.ENERGY in self.properties: - validation_metrics["per_energy_rmse"] = Loss( - "energy", RMSE(), DividePerAtom() - ) - if keys.FORCES in self.properties: - validation_metrics["force_rmse"] = Loss( - "forces", RMSE(), PerAtomScale() - ) + pael = PerAtomEnergyLoss() + if pael.name not in existing_loss_names: + validation_metrics.append(PerAtomEnergyLoss()) - # don't double log the force RMSE if its already in the total loss - for loss in total_loss.losses: - if loss.property_key == "forces" and isinstance( - loss.metric, RMSE - ): - validation_metrics.pop("force_rmse") + if keys.FORCES in self.properties: + fr = Loss("forces", RMSE()) + if fr.name not in existing_loss_names: + validation_metrics.append(fr) self.validation_metrics = validation_metrics def forward(self, graphs: 
AtomicGraphBatch) -> torch.Tensor: return self.model(graphs) - def _step(self, graph: LabelledBatch, prefix: str): + def _step(self, graph: LabelledBatch, prefix: Literal["train", "val"]): """ Get (and log) the losses for a training/validation step. """ - def log(name, value): + def log(name: str, value: torch.Tensor | float): + # TODO: revamp logging options + + if isinstance(value, torch.Tensor): + value = value.item() + return self.log( f"{prefix}_{name}", value, @@ -187,27 +188,24 @@ def log(name, value): self.model, graph, properties=self.properties, training=True ) - # compute the losses - total_loss = torch.scalar_tensor(0.0, device=self.device) + # compute the loss and its sub-components + total_loss_result = self.total_loss(predictions, graph) - for loss, weight in zip( - self.total_loss.losses, self.total_loss.weights - ): - value = loss(predictions, graph) - # log the unweighted components of the loss - log(f"{loss.property_key}_{loss.name}", value) - # but weight them when computing the total loss - total_loss = total_loss + weight * value + # log + log("total_loss", total_loss_result.loss_value) + + for name, loss_pair in total_loss_result.components.items(): + log(name, loss_pair.loss_value) + log(f"{name}_weighted", loss_pair.weighted_loss_value) # log additional values during validation if prefix == "val": with torch.no_grad(): - for name, loss in self.validation_metrics.items(): - value = loss(predictions, graph) - log(name, value) + for val_loss in self.validation_metrics: + value = val_loss(predictions, graph) + log(val_loss.name, value) - log("total_loss", total_loss) - return total_loss + return total_loss_result.loss_value def training_step(self, structure: LabelledBatch, _): return self._step(structure, "train") @@ -215,6 +213,7 @@ def training_step(self, structure: LabelledBatch, _): def validation_step(self, structure: LabelledBatch, _): return self._step(structure, "val") + # TODO move this to be a factory def configure_optimizers(self): return self.optimizer @@ -249,36 +248,24 @@ def load_best_weights( def process_loss( - loss: WeightedLoss | Loss | None, graph: LabelledGraph -) -> WeightedLoss: - if isinstance(loss, WeightedLoss): + loss: TotalLoss | Loss | None, + available_properties: list[keys.LabelKey], +) -> TotalLoss: + if isinstance(loss, TotalLoss): return loss - elif isinstance(loss, Loss): - return WeightedLoss([loss], [1.0]) - default_transforms = { - keys.ENERGY: PerAtomStandardScaler(), - keys.FORCES: PerAtomScale(), - keys.STRESS: Scale(), - } - default_weights = { - keys.ENERGY: 1.0, - keys.FORCES: 1.0, - keys.STRESS: 1.0, - } + elif isinstance(loss, Loss): + return TotalLoss([loss], [1.0]) - available_properties = get_existing_properties(graph) + if loss is not None: + raise ValueError( + "Invalid loss: must be a TotalLoss, a Loss, or None. " + f"Got {type(loss)}" + ) - return WeightedLoss( - [ - Loss( - key, - metric=RMSE(), - transform=default_transforms[key], - ) - for key in available_properties - ], - [default_weights[key] for key in available_properties], + return TotalLoss( + [Loss(key, metric=RMSE()) for key in available_properties], + [1.0 for key in available_properties], ) @@ -296,6 +283,7 @@ def default_trainer_kwargs() -> dict: save_top_k=1, save_weights_only=True, ), + # TODO: this is gimicky: only use in notebooks? 
RichProgressBar(), ], } @@ -314,9 +302,13 @@ def device_info_filter(record): ) +## Optimizers ## + + def Adam( lr: float = 3e-4, weight_decay: float = 0.0, + # TODO: deal with this energy_transform_overrides: dict | None = None, ) -> Callable[[GraphPESModel], torch.optim.Optimizer]: if energy_transform_overrides is None: diff --git a/src/graph_pes/transform.py b/src/graph_pes/transform.py index ac477370..27b5e6d2 100644 --- a/src/graph_pes/transform.py +++ b/src/graph_pes/transform.py @@ -1,3 +1,5 @@ +# TODO: be more concrete throughout with this + from __future__ import annotations from abc import ABC, abstractmethod @@ -514,7 +516,15 @@ def __init__(self): super().__init__(trainable=False) def forward(self, x: Tensor, graph: AtomicGraph) -> Tensor: - return left_aligned_div(x, structure_sizes(graph)) + return divide_per_atom(x, graph) def inverse(self, y: Tensor, graph: AtomicGraph) -> Tensor: - return left_aligned_mul(y, structure_sizes(graph)) + return times_by_n_atoms(y, graph) + + +def divide_per_atom(x: Tensor, graph: AtomicGraph) -> Tensor: + return left_aligned_div(x, structure_sizes(graph)) + + +def times_by_n_atoms(x: Tensor, graph: AtomicGraph) -> Tensor: + return left_aligned_mul(x, structure_sizes(graph)) diff --git a/src/graph_pes/util.py b/src/graph_pes/util.py index 5a7f538d..4aa83cec 100644 --- a/src/graph_pes/util.py +++ b/src/graph_pes/util.py @@ -82,7 +82,7 @@ def differentiate(y: torch.Tensor, x: torch.Tensor): @contextmanager -def require_grad(tensor: torch.Tensor): +def require_grad(*tensors: torch.Tensor): # check if in a torch.no_grad() context: if so, # raise an error if not torch.is_grad_enabled(): @@ -92,10 +92,12 @@ def require_grad(tensor: torch.Tensor): "a torch.enable_grad() context." ) - req_grad = tensor.requires_grad - tensor.requires_grad_(True) + original = [tensor.requires_grad for tensor in tensors] + for tensor in tensors: + tensor.requires_grad_(True) yield - tensor.requires_grad_(req_grad) + for tensor, req_grad in zip(tensors, original): + tensor.requires_grad_(req_grad) def to_significant_figures(x: float | int, sf: int = 3) -> float: diff --git a/tests/test_loss.py b/tests/test_loss.py index a7423e7a..50573234 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -1,7 +1,7 @@ from __future__ import annotations import torch -from graph_pes.loss import MAE, RMSE, Loss, WeightedLoss +from graph_pes.loss import MAE, RMSE, Loss, TotalLoss def test_metrics(): @@ -21,6 +21,6 @@ def test_loss_ops(): l2 = Loss("forces") weighted = 10 * l1 + l2 * 1 - assert isinstance(weighted, WeightedLoss) + assert isinstance(weighted, TotalLoss) assert set(weighted.losses) == {l1, l2} assert set(weighted.weights) == {10, 1} diff --git a/tests/test_models.py b/tests/test_models.py index cd32e306..dc55b2e4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -93,7 +93,9 @@ def predict_local_energies( ) -> torch.Tensor: return torch.ones(number_of_atoms(graph)) - def _extra_pre_fit(self, graphs: AtomicGraphBatch) -> bool | None: + def model_specific_pre_fit( + self, graphs: AtomicGraphBatch + ) -> bool | None: return ret_value # noqa: B023 model = DummyModel()
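
Reviewer note (not part of the patch): the snippet below sketches how the pieces introduced in this diff are intended to compose. `train_graphs` and `val_graphs` are placeholders for lists of `LabelledGraph`s produced by the existing data pipeline; everything else uses only APIs visible in this change set.

.. code-block:: python

    from graph_pes.core import AdditionModel
    from graph_pes.loss import RMSE, Loss, PerAtomEnergyLoss, TotalLoss
    from graph_pes.models.zoo import LennardJones, SchNet
    from graph_pes.training import train_model

    # "+" on models now builds an AdditionModel (previously an Ensemble)
    model = LennardJones() + SchNet()
    assert isinstance(model, AdditionModel)

    # losses are combined into a TotalLoss with explicit weights
    loss = TotalLoss(
        [PerAtomEnergyLoss(), Loss("forces", metric=RMSE())],
        weights=[10.0, 1.0],
    )

    # train_graphs / val_graphs: placeholder lists of LabelledGraph
    train_model(
        model,
        train_graphs,
        val_data=val_graphs,
        loss=loss,
        batch_size=32,
        pre_fit_model=True,  # calls model.pre_fit(...) on the training batch
    )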