
Commit

simplify
jla-gardner committed Jun 4, 2024
1 parent 257961d commit 66009ed
Showing 17 changed files with 353 additions and 374 deletions.
266 changes: 65 additions & 201 deletions src/graph_pes/core.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/graph_pes/data/__init__.py
@@ -1,3 +1,5 @@
# TODO: split into batching, io etc.
# in readiness for moving to readable dataset approach
from __future__ import annotations

import warnings
@@ -81,6 +83,7 @@


class _AtomicGraph_Impl(dict):
# TODO: remove?
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

3 changes: 3 additions & 0 deletions src/graph_pes/data/keys.py
@@ -1,3 +1,6 @@
# deliberately not using future imports here to appease torchscript?
# TODO: check if this is necessary

from typing import TYPE_CHECKING, Literal

# graph properties
4 changes: 3 additions & 1 deletion src/graph_pes/data/utils.py
@@ -10,7 +10,9 @@


def random_split(
sequence: Sequence[E], lengths: Sequence[int], seed: int | None = None
sequence: Sequence[E],
lengths: Sequence[int],
seed: int | None = None,
) -> list[list[E]]:
"""
Randomly split `sequence` into sub-sequences according to `lengths`.
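For orientation, a minimal usage sketch of random_split, based only on the signature and docstring shown above; the import path is inferred from the file location, and the values are illustrative:

# Assumed import path, inferred from src/graph_pes/data/utils.py
from graph_pes.data.utils import random_split

structures = list(range(100))          # any Sequence[E]
train, val, test = random_split(
    structures,
    lengths=[80, 10, 10],              # sizes of the three sub-sequences
    seed=42,                           # optional, for reproducible splits
)
assert len(train) == 80 and len(val) == 10 and len(test) == 10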
192 changes: 131 additions & 61 deletions src/graph_pes/loss.py
@@ -1,19 +1,21 @@
from __future__ import annotations

from typing import Callable
from typing import Callable, NamedTuple

import torch
from torch import Tensor, nn

from .data import LabelledBatch, keys
from .transform import Identity, Transform
from .transform import divide_per_atom


class Loss(nn.Module):
r"""
Measure the discrepancy between predictions and labels.
Measure the discrepancy between predictions and labels for a given property.
Often, it is convenient to apply some well known loss function,
Often, it is convenient to apply some well known loss metric,
e.g. `MSELoss`, to a transformed version of the predictions and labels,
e.g. normalisation, such that the loss value takes on "nice" values,
and that the resulting gradients and parameter updates are well-behaved.
@@ -54,13 +56,10 @@ def __init__(
self,
label: keys.LabelKey,
metric: Callable[[Tensor, Tensor], Tensor] | None = None,
transform: Transform | None = None,
):
super().__init__()
self.property_key: keys.LabelKey = label
self.metric = MAE() if metric is None else metric
self.transform = transform or Identity()
self.transform.trainable = False

# add type hints to play nicely with mypy
def __call__(
@@ -87,56 +86,58 @@ def forward(
"""

return self.metric(
self.transform(predictions[self.property_key], graphs),
self.transform(graphs[self.property_key], graphs),
predictions[self.property_key],
graphs[self.property_key],
)

def fit_transform(self, graphs: LabelledBatch):
"""
Fit the transform to the target labels.
@property
def name(self) -> str:
"""Get the name of this loss for logging purposes."""
return f"{self.property_key}_{get_metric_name(self.metric)}"

Parameters
----------
graphs
The graphs containing the labels.
"""
## Methods for creating weighted losses ##

self.transform.fit_to_target(graphs[self.property_key], graphs)
def __mul__(self, weight: float | int) -> TotalLoss:
if not isinstance(weight, (int, float)):
raise TypeError(f"Cannot multiply Loss and {type(weight)}")

@property
def name(self) -> str:
# if metric is a class, we want the class name otherwise we want
# the function name, all without the word "loss" in it
return (
getattr(
self.metric,
"__name__",
self.metric.__class__.__name__,
)
.lower()
.replace("loss", "")
)
return TotalLoss([self], [weight])

def __rmul__(self, weight: float) -> TotalLoss:
if not isinstance(weight, (int, float)):
raise TypeError(f"Cannot multiply Loss and {type(weight)}")

return TotalLoss([self], [weight])

def __mul__(self, other: float) -> WeightedLoss:
return WeightedLoss([self], [other])
def __truediv__(self, weight: float | int) -> TotalLoss:
if not isinstance(weight, (int, float)):
raise TypeError(f"Cannot divide Loss and {type(weight)}")

def __rmul__(self, other: float) -> WeightedLoss:
return WeightedLoss([self], [other])
return TotalLoss([self], [1 / weight])

def __add__(self, other: Loss | WeightedLoss) -> WeightedLoss:
if isinstance(other, Loss):
return WeightedLoss([self, other], [1, 1])
elif isinstance(other, WeightedLoss):
return WeightedLoss([self] + other.losses, [1] + other.weights)
def __add__(self, loss: Loss | TotalLoss) -> TotalLoss:
if isinstance(loss, Loss):
return TotalLoss([self, loss], [1, 1])
elif isinstance(loss, TotalLoss):
return TotalLoss([self] + loss.losses, [1] + loss.weights)
else:
raise TypeError(f"Cannot add Loss and {type(other)}")
raise TypeError(f"Cannot add Loss and {type(loss)}")

def __radd__(self, other: Loss | WeightedLoss) -> WeightedLoss:
def __radd__(self, other: Loss | TotalLoss) -> TotalLoss:
return self.__add__(other)
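The operator overloads above let weighted combinations be written arithmetically; a minimal sketch, assuming "energy" and "forces" are valid label keys (as in the TotalLoss docstring example further down):

# Each expression yields a TotalLoss wrapping the underlying Loss objects.
total = 10 * Loss("energy") + 1 * Loss("forces")   # weights [10, 1]
halved = Loss("energy") / 2                        # weight  [0.5]
rescaled = total * 0.5                             # weights [5.0, 0.5]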


# TODO: callable weights
class WeightedLoss(torch.nn.Module):
class SubLossPair(NamedTuple):
loss_value: torch.Tensor
weighted_loss_value: torch.Tensor


class TotalLossResult(NamedTuple):
loss_value: torch.Tensor
components: dict[str, SubLossPair]


class TotalLoss(torch.nn.Module):
r"""
A lightweight wrapper around a collection of weighted losses.
@@ -149,7 +150,7 @@ class WeightedLoss(torch.nn.Module):
.. code-block:: python
WeightedLoss([Loss("energy"), Loss("forces")], [10, 1])
WeightedLoss([Loss("energy"), Loss("forces")], weights=[10, 1])
# is equivalent to
10 * Loss("energy") + 1 * Loss("forces")
@@ -164,39 +165,108 @@ class WeightedLoss(torch.nn.Module):
def __init__(
self,
losses: list[Loss],
weights: list[float] | None = None,
weights: list[float | int] | None = None,
):
super().__init__()
self.losses: list[Loss] = nn.ModuleList(losses) # type: ignore
self.weights = weights or [1.0] * len(losses)

def __add__(self, other: WeightedLoss) -> WeightedLoss:
return WeightedLoss(
def __add__(self, other: TotalLoss) -> TotalLoss:
return TotalLoss(
self.losses + other.losses, self.weights + other.weights
)

def __mul__(self, other: float) -> WeightedLoss:
return WeightedLoss(self.losses, [w * other for w in self.weights])
def __mul__(self, other: float | int) -> TotalLoss:
if not isinstance(other, (int, float)):
raise TypeError(f"Cannot multiply TotalLoss and {type(other)}")

def __rmul__(self, other: float) -> WeightedLoss:
return WeightedLoss(self.losses, [w * other for w in self.weights])
return TotalLoss(self.losses, [w * other for w in self.weights])

def __true_div__(self, other: float) -> WeightedLoss:
return WeightedLoss(self.losses, [w / other for w in self.weights])
def __rmul__(self, other: float | int) -> TotalLoss:
if not isinstance(other, (int, float)):
raise TypeError(f"Cannot multiply TotalLoss and {type(other)}")

def fit_transform(self, graphs: LabelledBatch):
for loss in self.losses:
loss.fit_transform(graphs)
return TotalLoss(self.losses, [w * other for w in self.weights])

def __truediv__(self, other: float | int) -> TotalLoss:
if not isinstance(other, (int, float)):
raise TypeError(f"Cannot divide TotalLoss and {type(other)}")

return TotalLoss(self.losses, [w / other for w in self.weights])

def forward(
self,
predictions: dict[keys.LabelKey, torch.Tensor],
graphs: LabelledBatch,
) -> TotalLossResult:
"""
Computes the total loss value.
Parameters
----------
predictions
The predictions from the model.
graphs
The graphs containing the labels.
"""

total_loss = torch.scalar_tensor(0.0, device=self.device)
components: dict[str, SubLossPair] = {}

for loss, weight in zip(self.losses, self.weights):
loss_value = loss(predictions, graphs)
weighted_loss_value = loss_value * weight

total_loss += weighted_loss_value
components[loss.name] = SubLossPair(loss_value, weighted_loss_value)

return TotalLossResult(total_loss, components)

# add type hints to appease mypy
def __call__(
self,
predictions: dict[keys.LabelKey, torch.Tensor],
graphs: LabelledBatch,
) -> TotalLossResult:
return super().__call__(predictions, graphs)
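A sketch of how a training step might consume the TotalLossResult returned above; total_loss, predictions and batch are assumed to already exist (a TotalLoss instance, a dict of model outputs, and a LabelledBatch respectively):

result = total_loss(predictions, batch)
result.loss_value.backward()          # single weighted scalar to optimise

# per-component values, e.g. for logging
for name, pair in result.components.items():
    print(name, pair.loss_value.item(), pair.weighted_loss_value.item())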


class PerAtomEnergyLoss(Loss):
def __init__(
self,
metric: Callable[[Tensor, Tensor], Tensor] | None = None,
):
super().__init__(keys.ENERGY, metric)

def forward(
self,
predictions: dict[keys.LabelKey, torch.Tensor],
graphs: LabelledBatch,
) -> torch.Tensor:
return sum(
w * loss(predictions, graphs)
for w, loss in zip(self.weights, self.losses)
) # type: ignore
return divide_per_atom(super().forward(predictions, graphs), graphs)

@property
def name(self) -> str:
return f"per_atom_energy_{get_metric_name(self.metric)}"


def get_metric_name(metric: Callable[[Tensor, Tensor], Tensor]) -> str:
# if metric is a function, we want the function's name, otherwise
# we want the metric's class name, all lowercased
# and without the word "loss" in it

return (
getattr(
metric,
"__name__",
metric.__class__.__name__,
)
.lower()
.replace("loss", "")
)
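For reference, a few examples of the names this logic produces (class name or function __name__, lowercased, with "loss" stripped), derived from the code shown above:

# get_metric_name(RMSE())             -> "rmse"
# get_metric_name(torch.nn.MSELoss()) -> "mse"
# Loss("energy", metric=RMSE()).name  -> "energy_rmse"
# PerAtomEnergyLoss().name            -> "per_atom_energy_mae"   (default MAE metric)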


## METRICS ##


class RMSE(torch.nn.MSELoss):
4 changes: 3 additions & 1 deletion src/graph_pes/models/distances.py
@@ -33,6 +33,8 @@ class DistanceExpansion(nn.Module, ABC):
def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
super().__init__()
self.n_features = n_features
# TODO: check serialization - need to register as buffer to be included
# in state_dict?
self.cutoff = cutoff
self.trainable = trainable

@@ -117,7 +119,7 @@ class Bessel(DistanceExpansion):
"""

def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
super().__init__(n_features, cutoff)
super().__init__(n_features, cutoff, trainable)
self.frequencies = nn.Parameter(
torch.arange(1, n_features + 1) * math.pi / cutoff,
requires_grad=trainable,
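The one-line Bessel change above now forwards trainable to DistanceExpansion.__init__; a small sketch of the behaviour this fixes, using only attributes visible in the constructors shown:

# Previously the base class always received the default trainable=True, so
# self.trainable could disagree with the frequencies parameter. After the fix:
expansion = Bessel(n_features=8, cutoff=5.0, trainable=False)
assert expansion.trainable is False
assert expansion.frequencies.requires_grad is False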
44 changes: 44 additions & 0 deletions src/graph_pes/models/offsets.py
@@ -0,0 +1,44 @@
from __future__ import annotations

from graph_pes.core import GraphPESModel
from graph_pes.data import AtomicGraph
from graph_pes.nn import PerElementParameter
from torch import Tensor


class EnergyOffset(GraphPESModel):
r"""
A model that predicts energy offsets:
.. math::
E(\mathcal{G}) = \sum_i \varepsilon_{Z_i}
where :math:`\varepsilon_{Z_i}` is the energy offset for atomic species
:math:`Z_i`.
Parameters
----------
values
A dictionary of fixed energy offsets for each atomic species.
trainable
Whether the energy offsets are trainable parameters.
"""

def __init__(
self,
values: dict[str, float] | None = None,
trainable: bool = False,
):
super().__init__()

if values is None and trainable is False:
raise ValueError("Must provide values or set trainable to True")

self.offsets = PerElementParameter.of_length(
1,
default_value=0.0,
requires_grad=trainable,
)

def predict_local_energies(self, graph: AtomicGraph) -> Tensor:
return self.offsets[graph["atomic_numbers"]]
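A hedged sketch of the intended lookup in predict_local_energies: per-element offsets are indexed by each atom's atomic number, giving one local energy per atom. Plain tensors stand in for PerElementParameter, whose indexing behaviour is assumed rather than shown:

import torch

offsets = torch.zeros(119, 1)               # one offset per atomic number (Z <= 118)
offsets[1] = -13.6                          # hypothetical value for H
offsets[8] = -75.0                          # hypothetical value for O

atomic_numbers = torch.tensor([8, 1, 1])    # e.g. a water molecule
local_energies = offsets[atomic_numbers]    # shape (3, 1): one row per atom
total_energy = local_energies.sum()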
4 changes: 1 addition & 3 deletions src/graph_pes/models/painn.py
@@ -9,7 +9,6 @@
number_of_atoms,
)
from graph_pes.nn import MLP, HaddamardProduct, PerElementEmbedding
from graph_pes.transform import Transform
from torch import Tensor, nn

from .distances import Bessel, PolynomialEnvelope
@@ -192,9 +191,8 @@ def __init__(
radial_features: int = 20,
layers: int = 3,
cutoff: float = 5.0,
energy_transform: Transform | None = None,
):
super().__init__(energy_transform)
super().__init__()
self.internal_dim = internal_dim
self.layers = layers
self.interactions: list[Interaction] = nn.ModuleList(
