Skip to content

Commit

Permalink
overload model prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Feb 9, 2024
1 parent d09e6ae commit 7723cb3
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 58 deletions.
4 changes: 1 addition & 3 deletions src/graph_pes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ def parity_plot(

ground_truth = transform(graphs[property_label], graphs).detach()
predictions = transform(
# TODO: use overload
model.predict(graphs, [property])[property],
graphs,
model.predict(graphs, property=property), graphs
).detach()

# plot
Expand Down
172 changes: 122 additions & 50 deletions src/graph_pes/core.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Literal, Sequence
from typing import Literal, Sequence, overload

import torch
from graph_pes.data import AtomicGraph
from graph_pes.data.batching import AtomicGraphBatch, sum_per_structure
from graph_pes.transform import (
Chain,
Identity,
PerAtomScale,
PerAtomShift,
Transform,
)
from graph_pes.transform import Identity, PerAtomStandardScaler, Transform
from graph_pes.util import Property, PropertyKey, differentiate, require_grad
from jaxtyping import Float
from jaxtyping import Float # TODO: use this throughout
from torch import Tensor, nn


Expand All @@ -30,23 +24,16 @@ class GraphPESModel(nn.Module, ABC):
To create such a model, implement :meth:`predict_local_energies`,
which takes an :class:`AtomicGraph`, or an :class:`AtomicGraphBatch`,
and returns a per-atom prediction of the local energy. For a simple example,
see :class:`LennardJones <graph_pes.models.pairwise.LennardJones>`.
see the :class:`PairPotential <graph_pes.models.pairwise.PairPotential>`
`implementation <_modules/graph_pes/models/pairwise.html#PairPotential>`_.
Under the hood, :class:`GraphPESModel` contains an
:class:`EnergySummation` module, which is responsible for
summing over local energies to obtain the total energy/ies,
with optional transformations of the local and total energies.
By default, this learns a per-species, local energy offset and scale.
.. note::
All :class:`GraphPESModel` instances are also instances of
:class:`torch.nn.Module`. This allows for easy optimisation
of parameters, and automated save/load functionality.
"""

# TODO: fix this for the case of an isolated atom, either by itself
# or within a batch: perhaps that should go in sum_per_structure?
# or maybe default to a local scale followed by a global peratomshift?
@abstractmethod
def predict_local_energies(
self, graph: AtomicGraph | AtomicGraphBatch
Expand Down Expand Up @@ -81,14 +68,6 @@ def __init__(self):
self.energy_summation = EnergySummation()

def __add__(self, other: GraphPESModel) -> Ensemble:
"""
A convenient way to create a summation of two models.
Examples
--------
>>> TwoBody() + ThreeBody()
Ensemble([TwoBody(), ThreeBody()], aggregation=sum)
"""
return Ensemble([self, other], aggregation="sum")

def pre_fit(self, graphs: AtomicGraphBatch):
Expand All @@ -101,23 +80,59 @@ def pre_fit(self, graphs: AtomicGraphBatch):
output by the underlying model will result in energy predictions
that are distributed according to the training data.
For an example customisation of this method, see the
:class:`LennardJones <graph_pes.models.pairwise.LennardJones>`
`implementation
<_modules/graph_pes/models/pairwise.html#LennardJones>`_.
Parameters
----------
graphs
The training data.
"""
self.energy_summation.fit_to_graphs(graphs)

# TODO: overload to get single property if passed
# Typing overload: neither `properties` nor `property` is given, so the
# call returns a dictionary mapping each predicted property key to its tensor.
@overload
def predict(
self,
graph: AtomicGraph | AtomicGraphBatch | list[AtomicGraph],
*,
training: bool = False,
) -> dict[PropertyKey, Tensor]:
...

# Typing overload: an explicit sequence of `properties` is requested
# (keyword-only), so the call returns a dictionary keyed by property.
@overload
def predict(
self,
graph: AtomicGraph | AtomicGraphBatch | list[AtomicGraph],
*,
properties: Sequence[PropertyKey],
training: bool = False,
) -> dict[PropertyKey, Tensor]:
...

# Typing overload: a single `property` is requested (keyword-only), so the
# call returns that property's tensor directly rather than a dictionary.
@overload
def predict(
self,
graph: AtomicGraph | AtomicGraphBatch | list[AtomicGraph],
*,
property: PropertyKey,
training: bool = False,
) -> Tensor:
...

# TODO: implement max batch size
def predict(
self,
graph: AtomicGraph | AtomicGraphBatch | list[AtomicGraph],
properties: Sequence[PropertyKey] | None = None, # type: ignore
*,
properties: Sequence[PropertyKey] | None = None,
property: PropertyKey | None = None,
training: bool = False,
) -> dict[PropertyKey, torch.Tensor]:
) -> dict[PropertyKey, Tensor] | Tensor:
"""
Evaluate the model on the given structure to get the labels requested.
Evaluate the model on the given structure to get
the properties requested.
Parameters
----------
Expand All @@ -128,8 +143,12 @@ def predict(
:code:`[Property.ENERGY, Property.FORCES]` if the structure
has no cell, and :code:`[Property.ENERGY, Property.FORCES,
Property.STRESS]` if it does.
property
The property to predict. Can't be used when :code:`properties`
is also provided.
training
Whether the model is currently being trained.
Whether the model is currently being trained. If :code:`False`,
the returned predictions are detached from the autograd graph.
Returns
-------
Expand All @@ -138,24 +157,31 @@ def predict(
Examples
--------
>>> # TODO
>>> model.predict(graph_pbc)
{'energy': tensor(-12.3), 'forces': tensor(...), 'stress': tensor(...)}
>>> model.predict(graph_no_pbc)
{'energy': tensor(-12.3), 'forces': tensor(...)}
>>> model.predict(graph_pbc, property="energy")
tensor(-12.3)
"""

# check correctly called
if property is not None and properties is not None:
raise ValueError("Can't specify both `property` and `properties`")

if isinstance(graph, list):
graph = AtomicGraphBatch.from_graphs(graph)

if properties is None:
properties: list[PropertyKey] = [Property.ENERGY, Property.FORCES]
if graph.has_cell:
properties.append(Property.STRESS)
# elif isinstance(properties, str):
# properties = [properties]
properties = [Property.ENERGY, Property.FORCES, Property.STRESS]
else:
properties = [Property.ENERGY, Property.FORCES]

if Property.STRESS in properties and not graph.has_cell:
raise ValueError("Can't predict stress without cell information.")

predictions = {}
predictions: dict[PropertyKey, Tensor] = {}

# setup for calculating stress:
if Property.STRESS in properties:
Expand Down Expand Up @@ -193,10 +219,31 @@ def predict(
if not training:
for key, value in predictions.items():
predictions[key] = value.detach()

if property is not None:
return predictions[property]

return predictions


class EnergySummation(nn.Module):
"""
A module for summing local energies to obtain the total energy.
Before summation, :code:`local_transform` is applied to the local energies.
After summation, :code:`total_transform` is applied to the total energy.
By default, :code:`EnergySummation()` learns a per-species, local energy
offset and scale.
Parameters
----------
local_transform
A transformation of the local energies.
total_transform
A transformation of the total energy.
"""

def __init__(
self,
local_transform: Transform | None = None,
Expand All @@ -206,19 +253,35 @@ def __init__(

# if both None, default to a per-species, local energy offset and scale
if local_transform is None and total_transform is None:
local_transform = Chain(
[PerAtomShift(), PerAtomScale()], trainable=True
)
local_transform = PerAtomStandardScaler()
self.local_transform: Transform = local_transform or Identity()
self.total_transform: Transform = total_transform or Identity()

def forward(self, local_energies: torch.Tensor, graph: AtomicGraphBatch):
    """
    Sum the local energies to obtain the total energy.

    The inverse of :code:`local_transform` is applied to the local
    energies, the result is summed per structure, and the inverse of
    :code:`total_transform` is applied to that sum.

    Parameters
    ----------
    local_energies
        The local energies.
    graph
        The graph representation of the structure/s.

    Returns
    -------
    The total energy, after both transforms have been applied.
    """
    per_atom = self.local_transform.inverse(local_energies, graph)
    summed = sum_per_structure(per_atom, graph)
    return self.total_transform.inverse(summed, graph)

def fit_to_graphs(self, graphs: AtomicGraphBatch | list[AtomicGraph]):
"""
Fit the transforms to the training data.
Parameters
----------
graphs
The training data.
"""
if not isinstance(graphs, AtomicGraphBatch):
graphs = AtomicGraphBatch.from_graphs(graphs)

Expand Down Expand Up @@ -256,17 +319,26 @@ class Ensemble(GraphPESModel):
Examples
--------
Create a model with explicit two-body and multi-body terms:
>>> from graph_pes.models.pairwise import LennardJones
>>> from graph_pes.models.schnet import SchNet
>>> from graph_pes.core import Ensemble
>>> # create an ensemble of two models
>>> # equivalent to Ensemble([LennardJones(), SchNet()], aggregation="sum")
>>> ensemble = LennardJones() + SchNet()
.. code-block:: python
See Also
--------
:meth:`GraphPESModel.__add__`
from graph_pes.models.pairwise import LennardJones
from graph_pes.models.schnet import SchNet
from graph_pes.core import Ensemble
# create an ensemble of two models
# equivalent to Ensemble([LennardJones(), SchNet()], aggregation="sum")
ensemble = LennardJones() + SchNet()
Use several models to get an average prediction:
.. code-block:: python
models = ... # load/train your models
ensemble = Ensemble(models, aggregation="mean")
predictions = ensemble.predict(test_graphs)
...
"""

def __init__(
Expand Down
3 changes: 3 additions & 0 deletions src/graph_pes/data/atomic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ def extract_tensors(
return tensor_dict


# TODO: move to being class method?
# TODO: generalised edge creation? i.e. not just cutoff, but arbitrary method
# and can then have radius_cutoff class, k nearest neighbours etc.
def convert_to_atomic_graphs(
structures: Iterable[ase.Atoms] | ase.Atoms,
cutoff: float,
Expand Down
1 change: 1 addition & 0 deletions src/graph_pes/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def __radd__(self, other: Loss | WeightedLoss) -> WeightedLoss:
return self.__add__(other)


# TODO: callable weights
class WeightedLoss(torch.nn.Module):
r"""
A lightweight wrapper around a collection of weighted losses.
Expand Down
5 changes: 3 additions & 2 deletions src/graph_pes/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def train_model(
)

# deal with fitting transforms
# TODO: what if not training on energy?
if pre_fit_model and Property.ENERGY in training_on:
model.pre_fit(train_batch)
total_loss.fit_transform(train_batch)
Expand Down Expand Up @@ -147,7 +146,9 @@ def log(name, value, verbose=True):
)

# generate prediction:
predictions = self.model.predict(graph, self.properties, training=True)
predictions = self.model.predict(
graph, properties=self.properties, training=True
)

# compute the losses
total_loss = torch.scalar_tensor(0.0, device=self.device)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_integration():
model = LennardJones()

loss = Loss("energy")
before = loss(model.predict(batch, ["energy"]), batch)
before = loss(model(batch), batch)

train_model(
model,
Expand All @@ -25,6 +25,6 @@ def test_integration():
callbacks=[],
)

after = loss(model.predict(batch, ["energy"]), batch)
after = loss(model(batch), batch)

assert after < before, "training did not improve the loss"
26 changes: 26 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from ase import Atoms
from ase.io import read
from graph_pes.data import convert_to_atomic_graph, convert_to_atomic_graphs
from graph_pes.models.pairwise import LennardJones

structures = read("tests/test.xyz", ":")
graphs = convert_to_atomic_graphs(structures, cutoff=3)


def test_model():
    """Default predictions on the test structures have the expected keys and shapes."""
    n_structures = len(graphs)
    predictions = LennardJones().predict(graphs)

    # energy and forces must both be present in the default prediction
    for key in ("energy", "forces"):
        assert key in predictions

    # stress must be present too, and the first test structure must have a cell
    assert "stress" in predictions
    assert graphs[0].has_cell

    # one energy and one 3x3 stress tensor per structure
    assert predictions["energy"].shape == (n_structures,)
    assert predictions["stress"].shape == (n_structures, 3, 3)


def test_isolated_atom():
    """A single atom with no neighbours is assigned zero energy by Lennard-Jones."""
    helium = Atoms("He", positions=[[0, 0, 0]])
    graph = convert_to_atomic_graph(helium, cutoff=3)

    # a lone atom yields a graph with exactly one node and no edges
    assert graph.n_atoms == 1
    assert graph.n_edges == 0

    energy = LennardJones()(graph)
    assert energy == 0
2 changes: 1 addition & 1 deletion tests/test_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_predictions():

# if we ask for stress, we get an error:
with pytest.raises(ValueError):
model.predict(no_pbc, ["stress"])
model.predict(no_pbc, property="stress")

# with pbc structures, we should get all three predictions
predictions = model.predict(pbc)
Expand Down

0 comments on commit 7723cb3

Please sign in to comment.