From 7723cb3b8f5a8c8588014761df55cdb50e18de9f Mon Sep 17 00:00:00 2001 From: John Gardner Date: Fri, 9 Feb 2024 13:51:15 +0000 Subject: [PATCH] overload model prediction --- src/graph_pes/analysis.py | 4 +- src/graph_pes/core.py | 172 ++++++++++++++++++++--------- src/graph_pes/data/atomic_graph.py | 3 + src/graph_pes/loss.py | 1 + src/graph_pes/training.py | 5 +- tests/test_integration.py | 4 +- tests/test_models.py | 26 +++++ tests/test_predictions.py | 2 +- 8 files changed, 159 insertions(+), 58 deletions(-) create mode 100644 tests/test_models.py diff --git a/src/graph_pes/analysis.py b/src/graph_pes/analysis.py index 599037c5..65c500b2 100644 --- a/src/graph_pes/analysis.py +++ b/src/graph_pes/analysis.py @@ -159,9 +159,7 @@ def parity_plot( ground_truth = transform(graphs[property_label], graphs).detach() predictions = transform( - # TODO: use overload - model.predict(graphs, [property])[property], - graphs, + model.predict(graphs, property=property), graphs ).detach() # plot diff --git a/src/graph_pes/core.py b/src/graph_pes/core.py index 060f3113..36f270cb 100644 --- a/src/graph_pes/core.py +++ b/src/graph_pes/core.py @@ -1,20 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Literal, Sequence +from typing import Literal, Sequence, overload import torch from graph_pes.data import AtomicGraph from graph_pes.data.batching import AtomicGraphBatch, sum_per_structure -from graph_pes.transform import ( - Chain, - Identity, - PerAtomScale, - PerAtomShift, - Transform, -) +from graph_pes.transform import Identity, PerAtomStandardScaler, Transform from graph_pes.util import Property, PropertyKey, differentiate, require_grad -from jaxtyping import Float +from jaxtyping import Float # TODO: use this throughout from torch import Tensor, nn @@ -30,23 +24,16 @@ class GraphPESModel(nn.Module, ABC): To create such a model, implement :meth:`predict_local_energies`, which takes an :class:`AtomicGraph`, or an :class:`AtomicGraphBatch`, and returns a per-atom prediction of the local energy. For a simple example, - see :class:`LennardJones `. + see the :class:`PairPotential ` + `implementation <_modules/graph_pes/models/pairwise.html#PairPotential>`_. Under the hood, :class:`GraphPESModel` contains an :class:`EnergySummation` module, which is responsible for summing over local energies to obtain the total energy/ies, with optional transformations of the local and total energies. By default, this learns a per-species, local energy offset and scale. - - .. note:: - All :class:`GraphPESModel` instances are also instances of - :class:`torch.nn.Module`. This allows for easy optimisation - of parameters, and automated save/load functionality. """ - # TODO: fix this for the case of an isolated atom, either by itself - # or within a batch: perhaps that should go in sum_per_structure? - # or maybe default to a local scale followed by a global peratomshift? @abstractmethod def predict_local_energies( self, graph: AtomicGraph | AtomicGraphBatch @@ -81,14 +68,6 @@ def __init__(self): self.energy_summation = EnergySummation() def __add__(self, other: GraphPESModel) -> Ensemble: - """ - A convenient way to create a summation of two models. 
- - Examples - -------- - >>> TwoBody() + ThreeBody() - Ensemble([TwoBody(), ThreeBody()], aggregation=sum) - """ return Ensemble([self, other], aggregation="sum") def pre_fit(self, graphs: AtomicGraphBatch): @@ -101,6 +80,11 @@ def pre_fit(self, graphs: AtomicGraphBatch): output by the underlying model will result in energy predictions that are distributed according to the training data. + For an example customisation of this method, see the + :class:`LennardJones ` + `implementation + <_modules/graph_pes/models/pairwise.html#LennardJones>`_. + Parameters ---------- graphs @@ -108,16 +92,47 @@ def pre_fit(self, graphs: AtomicGraphBatch): """ self.energy_summation.fit_to_graphs(graphs) - # TODO: overload to get single property if passed + @overload + def predict( + self, + graph: AtomicGraph | AtomicGraphBatch | list[AtomicGraph], + *, + training: bool = False, + ) -> dict[PropertyKey, Tensor]: + ... + + @overload + def predict( + self, + graph: AtomicGraph | AtomicGraphBatch | list[AtomicGraph], + *, + properties: Sequence[PropertyKey], + training: bool = False, + ) -> dict[PropertyKey, Tensor]: + ... + + @overload + def predict( + self, + graph: AtomicGraph | AtomicGraphBatch | list[AtomicGraph], + *, + property: PropertyKey, + training: bool = False, + ) -> Tensor: + ... + # TODO: implement max batch size def predict( self, graph: AtomicGraph | AtomicGraphBatch | list[AtomicGraph], - properties: Sequence[PropertyKey] | None = None, # type: ignore + *, + properties: Sequence[PropertyKey] | None = None, + property: PropertyKey | None = None, training: bool = False, - ) -> dict[PropertyKey, torch.Tensor]: + ) -> dict[PropertyKey, Tensor] | Tensor: """ - Evaluate the model on the given structure to get the labels requested. + Evaluate the model on the given structure to get + the properties requested. Parameters ---------- @@ -128,8 +143,12 @@ def predict( :code:`[Property.ENERGY, Property.FORCES]` if the structure has no cell, and :code:`[Property.ENERGY, Property.FORCES, Property.STRESS]` if it does. + property + The property to predict. Can't be used when :code:`properties` + is also provided. training - Whether the model is currently being trained. + Whether the model is currently being trained. If :code:`False`, + the gradients of the predictions will be detached. 
Returns ------- @@ -138,24 +157,31 @@ def predict( Examples -------- - >>> # TODO - + >>> model.predict(graph_pbc) + {'energy': tensor(-12.3), 'forces': tensor(...), 'stress': tensor(...)} + >>> model.predict(graph_no_pbc) + {'energy': tensor(-12.3), 'forces': tensor(...)} + >>> model.predict(graph_pbc, property="energy") + tensor(-12.3) """ + # check correctly called + if property is not None and properties is not None: + raise ValueError("Can't specify both `property` and `properties`") + if isinstance(graph, list): graph = AtomicGraphBatch.from_graphs(graph) if properties is None: - properties: list[PropertyKey] = [Property.ENERGY, Property.FORCES] if graph.has_cell: - properties.append(Property.STRESS) - # elif isinstance(properties, str): - # properties = [properties] + properties = [Property.ENERGY, Property.FORCES, Property.STRESS] + else: + properties = [Property.ENERGY, Property.FORCES] if Property.STRESS in properties and not graph.has_cell: raise ValueError("Can't predict stress without cell information.") - predictions = {} + predictions: dict[PropertyKey, Tensor] = {} # setup for calculating stress: if Property.STRESS in properties: @@ -193,10 +219,31 @@ def predict( if not training: for key, value in predictions.items(): predictions[key] = value.detach() + + if property is not None: + return predictions[property] + return predictions class EnergySummation(nn.Module): + """ + A module for summing local energies to obtain the total energy. + + Before summation, :code:`local_transform` is applied to the local energies. + After summation, :code:`total_transform` is applied to the total energy. + + By default, :code:`EnergySummation()` learns a per-species, local energy + offset and scale. + + Parameters + ---------- + local_transform + A transformation of the local energies. + total_transform + A transformation of the total energy. + """ + def __init__( self, local_transform: Transform | None = None, @@ -206,19 +253,35 @@ def __init__( # if both None, default to a per-species, local energy offset if local_transform is None and total_transform is None: - local_transform = Chain( - [PerAtomShift(), PerAtomScale()], trainable=True - ) + local_transform = PerAtomStandardScaler() self.local_transform: Transform = local_transform or Identity() self.total_transform: Transform = total_transform or Identity() def forward(self, local_energies: torch.Tensor, graph: AtomicGraphBatch): + """ + Sum the local energies to obtain the total energy. + + Parameters + ---------- + local_energies + The local energies. + graph + The graph representation of the structure/s. + """ local_energies = self.local_transform.inverse(local_energies, graph) total_E = sum_per_structure(local_energies, graph) total_E = self.total_transform.inverse(total_E, graph) return total_E def fit_to_graphs(self, graphs: AtomicGraphBatch | list[AtomicGraph]): + """ + Fit the transforms to the training data. + + Parameters + ---------- + graphs + The training data. + """ if not isinstance(graphs, AtomicGraphBatch): graphs = AtomicGraphBatch.from_graphs(graphs) @@ -256,17 +319,26 @@ class Ensemble(GraphPESModel): Examples -------- + Create a model with explicit two-body and multi-body terms: - >>> from graph_pes.models.pairwise import LennardJones - >>> from graph_pes.models.schnet import SchNet - >>> from graph_pes.core import Ensemble - >>> # create an ensemble of two models - >>> # equivalent to Ensemble([LennardJones(), SchNet()], aggregation="sum") - >>> ensemble = LennardJones() + SchNet() + .. 
code-block:: python - See Also - -------- - :meth:`GraphPESModel.__add__` + from graph_pes.models.pairwise import LennardJones + from graph_pes.models.schnet import SchNet + from graph_pes.core import Ensemble + + # create an ensemble of two models + # equivalent to Ensemble([LennardJones(), SchNet()], aggregation="sum") + ensemble = LennardJones() + SchNet() + + Use several models to get an average prediction: + + .. code-block:: python + + models = ... # load/train your models + ensemble = Ensemble(models, aggregation="mean") + predictions = ensemble.predict(test_graphs) + ... """ def __init__( diff --git a/src/graph_pes/data/atomic_graph.py b/src/graph_pes/data/atomic_graph.py index 47acbb62..6516d2eb 100644 --- a/src/graph_pes/data/atomic_graph.py +++ b/src/graph_pes/data/atomic_graph.py @@ -375,6 +375,9 @@ def extract_tensors( return tensor_dict +# TODO: move to being class method? +# TODO: generalised edge creation? i.e. not just cutoff, but arbitrary method +# and can then have radius_cutoff class, k nearest neighbours etc. def convert_to_atomic_graphs( structures: Iterable[ase.Atoms] | ase.Atoms, cutoff: float, diff --git a/src/graph_pes/loss.py b/src/graph_pes/loss.py index 2e1b3b9e..7061e763 100644 --- a/src/graph_pes/loss.py +++ b/src/graph_pes/loss.py @@ -149,6 +149,7 @@ def __radd__(self, other: Loss | WeightedLoss) -> WeightedLoss: return self.__add__(other) +# TODO: callable weights class WeightedLoss(torch.nn.Module): r""" A lightweight wrapper around a collection of weighted losses. diff --git a/src/graph_pes/training.py b/src/graph_pes/training.py index 777c114e..0866b21a 100644 --- a/src/graph_pes/training.py +++ b/src/graph_pes/training.py @@ -75,7 +75,6 @@ def train_model( ) # deal with fitting transforms - # TODO: what if not training on energy? 
if pre_fit_model and Property.ENERGY in training_on: model.pre_fit(train_batch) total_loss.fit_transform(train_batch) @@ -147,7 +146,9 @@ def log(name, value, verbose=True): ) # generate prediction: - predictions = self.model.predict(graph, self.properties, training=True) + predictions = self.model.predict( + graph, properties=self.properties, training=True + ) # compute the losses total_loss = torch.scalar_tensor(0.0, device=self.device) diff --git a/tests/test_integration.py b/tests/test_integration.py index b0cf6520..1038566d 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -13,7 +13,7 @@ def test_integration(): model = LennardJones() loss = Loss("energy") - before = loss(model.predict(batch, ["energy"]), batch) + before = loss(model(batch), batch) train_model( model, @@ -25,6 +25,6 @@ def test_integration(): callbacks=[], ) - after = loss(model.predict(batch, ["energy"]), batch) + after = loss(model(batch), batch) assert after < before, "training did not improve the loss" diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 00000000..991583a2 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,26 @@ +from ase import Atoms +from ase.io import read +from graph_pes.data import convert_to_atomic_graph, convert_to_atomic_graphs +from graph_pes.models.pairwise import LennardJones + +structures = read("tests/test.xyz", ":") +graphs = convert_to_atomic_graphs(structures, cutoff=3) + + +def test_model(): + model = LennardJones() + predictions = model.predict(graphs) + assert "energy" in predictions + assert "forces" in predictions + assert "stress" in predictions and graphs[0].has_cell + assert predictions["energy"].shape == (len(graphs),) + assert predictions["stress"].shape == (len(graphs), 3, 3) + + +def test_isolated_atom(): + atom = Atoms("He", positions=[[0, 0, 0]]) + graph = convert_to_atomic_graph(atom, cutoff=3) + assert graph.n_atoms == 1 and graph.n_edges == 0 + + model = LennardJones() + assert model(graph) == 0 diff --git a/tests/test_predictions.py b/tests/test_predictions.py index 30f40b8d..a836fd06 100644 --- a/tests/test_predictions.py +++ b/tests/test_predictions.py @@ -34,7 +34,7 @@ def test_predictions(): # if we ask for stress, we get an error: with pytest.raises(ValueError): - model.predict(no_pbc, ["stress"]) + model.predict(no_pbc, property="stress") # with pbc structures, we should get all three predictions predictions = model.predict(pbc)
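
A minimal usage sketch of the overloaded `predict` API introduced above, based only on
the signatures and tests contained in this patch; the file path and variable names are
illustrative (they mirror `tests/test_models.py`) and are not part of the change itself.

.. code-block:: python

    from ase.io import read
    from graph_pes.data import convert_to_atomic_graphs
    from graph_pes.models.pairwise import LennardJones

    # build a list of graphs, exactly as in tests/test_models.py
    structures = read("tests/test.xyz", ":")
    graphs = convert_to_atomic_graphs(structures, cutoff=3)

    model = LennardJones()

    # 1. no keyword arguments: a dict containing energy and forces,
    #    plus stress when the structures have a unit cell
    everything = model.predict(graphs)

    # 2. properties=...: a dict with the requested properties
    energies_and_forces = model.predict(graphs, properties=["energy", "forces"])

    # 3. property=...: a bare tensor for a single property
    forces = model.predict(graphs, property="forces")

    # specifying both `properties` and `property` raises a ValueError,
    # as does requesting "stress" for structures without a cell

Note that `properties` and `property` are keyword-only, so the removed positional form
`model.predict(graphs, ["energy"])` (as previously used in `analysis.py` and the tests)
now raises a TypeError at call time.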