From 6765adf92d13903e7ef80d8ae4063bc74d5ebd61 Mon Sep 17 00:00:00 2001
From: John Gardner
Date: Mon, 15 Jan 2024 15:37:03 +0000
Subject: [PATCH] Refactor distance expansion models and update documentation

---
 docs/source/models.rst            |  11 +--
 docs/source/models/distances.rst  |  19 +++++
 docs/source/models/pairwise.rst   |   7 ++
 docs/source/training.rst          |   8 +-
 src/graph_pes/models/distances.py | 118 +++++++++++++++++-------------
 src/graph_pes/models/pairwise.py  |   3 +-
 6 files changed, 101 insertions(+), 65 deletions(-)
 create mode 100644 docs/source/models/distances.rst
 create mode 100644 docs/source/models/pairwise.rst

diff --git a/docs/source/models.rst b/docs/source/models.rst
index 8bdc75e8..cdcddf6d 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -12,12 +12,7 @@ All models implemented in ``graph_pes`` are subclasses of
     :members: predict_local_energies, __add__
 
 
-Pair Potentials
-===============
+.. toctree::
 
-Pair potential models can be recast as local-energy models acting on graphs.
-
-.. autoclass :: graph_pes.models.pairwise.PairPotential
-    :members: interaction
-
-.. autoclass :: graph_pes.models.pairwise.LennardJones
+    models/pairwise
+    models/distances
diff --git a/docs/source/models/distances.rst b/docs/source/models/distances.rst
new file mode 100644
index 00000000..63cce6b4
--- /dev/null
+++ b/docs/source/models/distances.rst
@@ -0,0 +1,19 @@
+Distance Expansions
+===================
+
+Available Expansions
+--------------------
+
+``graph-pes`` exposes the :class:`DistanceExpansion <graph_pes.models.distances.DistanceExpansion>`
+base class, which can be used to implement new distance expansions.
+We also provide a few common expansions:
+
+.. autoclass :: graph_pes.models.distances.Bessel
+.. autoclass :: graph_pes.models.distances.GaussianSmearing
+
+
+Implementing a new Expansion
+----------------------------
+
+.. autoclass :: graph_pes.models.distances.DistanceExpansion
+    :members:
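+
+For example, here is a minimal sketch of a new expansion: a (hypothetical)
+exponential basis, :math:`\phi_n(r) = \exp(-\beta_n \, r)`, with one
+trainable decay rate per feature:
+
+.. code-block:: python
+
+    import torch
+    from torch import nn
+
+    from graph_pes.models.distances import DistanceExpansion
+
+
+    class ExponentialExpansion(DistanceExpansion):
+        def __init__(
+            self, n_features: int, cutoff: float, trainable: bool = True
+        ):
+            super().__init__(n_features, cutoff, trainable)
+            # one decay rate per feature, from 1/cutoff up to
+            # n_features/cutoff
+            self.betas = nn.Parameter(
+                torch.linspace(1, n_features, n_features) / cutoff,
+                requires_grad=trainable,
+            )
+
+        def expand(self, r: torch.Tensor) -> torch.Tensor:
+            # r is guaranteed to have shape (..., 1): broadcasting against
+            # betas, of shape (n_features,), gives shape (..., n_features)
+            return torch.exp(-self.betas * r)
\ No newline at end of file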
diff --git a/docs/source/models/pairwise.rst b/docs/source/models/pairwise.rst
new file mode 100644
index 00000000..f8893eab
--- /dev/null
+++ b/docs/source/models/pairwise.rst
@@ -0,0 +1,7 @@
+Pair Potentials
+===============
+
+.. autoclass :: graph_pes.models.pairwise.PairPotential
+    :members: interaction
+
+.. autoclass :: graph_pes.models.pairwise.LennardJones
\ No newline at end of file
diff --git a/docs/source/training.rst b/docs/source/training.rst
index fb49424e..7e68b13d 100644
--- a/docs/source/training.rst
+++ b/docs/source/training.rst
@@ -11,9 +11,9 @@ An Overview
 Training a :class:`GraphPESModel` involves
 three preparatory steps:
 
-1. Define and initialize the model (see :ref:`models`)
-2. Loading the training data (see :ref:`loading atomic graphs`)
-3. Defining the loss function (see :ref:`loss functions`)
+1. Defining and initializing the model (see models)
+2. Loading the training data (see loading atomic graphs)
+3. Defining the loss function (see loss functions)
 
 Within the graph-pes framework, this can be as simple as:
 
@@ -47,7 +47,7 @@ Roughly, the following steps are taken:
 
 .. literalinclude:: ./training_setup.py
     :language: python
-    :lines: 24-35
+    :lines: 24-30
 
 ---
diff --git a/src/graph_pes/models/distances.py b/src/graph_pes/models/distances.py
index 187aad3b..e45ece88 100644
--- a/src/graph_pes/models/distances.py
+++ b/src/graph_pes/models/distances.py
@@ -2,41 +2,42 @@
 
 from abc import ABC, abstractmethod
 from math import pi as π
-from typing import Any, Callable
+from typing import Callable
 
 import torch
-from torch import nn
+from jaxtyping import Float
+from torch import Tensor, nn
 
 
-class DistanceExpansion(nn.Module, ABC):  # TODO- make protocol?
+class DistanceExpansion(nn.Module, ABC):
     r"""
-    Base class for an expansion function, :math:`\phi(r)` such that:
+    Abstract base class for an expansion function, :math:`\phi(r) :
+    [0, r_{\text{cutoff}}] \rightarrow \mathbb{R}^{n_\text{features}}`.
 
-    .. math::
-        r \in \mathbb{R}^1 \quad \rightarrow \quad \phi(r) \in
-        \mathbb{R}^{n_\text{features}}
-
-    or, for a batch of distances:
-
-    .. math::
-        r \in \mathbb{R}^{n_\text{batch} \times 1} \quad \rightarrow \quad
-        \phi(r) \in \mathbb{R}^{n_\text{batch} \times n_\text{features}}
+    Subclasses should implement :meth:`expand`, which must also work over
+    batches: :math:`\phi(r) : [0, r_{\text{cutoff}}]^{n_\text{batch} \times 1}
+    \rightarrow \mathbb{R}^{n_\text{batch} \times n_\text{features}}`.
 
     Parameters
     ----------
-    n_features : int
+    n_features
         The number of features to expand into.
-    cutoff : float
+    cutoff
         The cutoff radius.
+    trainable
+        Whether the expansion parameters are trainable.
     """
 
-    def __init__(self, n_features: int, cutoff: float):
+    def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
         super().__init__()
         self.n_features = n_features
         self.cutoff = cutoff
+        self.trainable = trainable
 
     @abstractmethod
-    def expand(self, r: torch.Tensor) -> torch.Tensor:
+    def expand(
+        self, r: Float[Tensor, "... 1"]
+    ) -> Float[Tensor, "... n_features"]:
         r"""
         Perform the expansion.
 
@@ -45,26 +46,18 @@ def expand(self, r: torch.Tensor) -> torch.Tensor:
         Parameters
         ----------
         r : torch.Tensor
             The distances to expand. Guaranteed to have shape :math:`(..., 1)`.
         """
-        pass
 
-    def forward(self, r: torch.Tensor) -> torch.Tensor:
-        r"""
-        Ensure that the input has the correct shape, :math:`(..., 1)`,
-        and then perform the expansion.
-        """
+    def forward(
+        self, r: Float[Tensor, "..."]
+    ) -> Float[Tensor, "... n_features"]:
         if r.shape[-1] != 1:
             r = r.unsqueeze(-1)
         return self.expand(r)
 
-    def properties(self) -> dict[str, Any]:
-        return {}
-
     def __repr__(self) -> str:
-        _kwargs_rep = ", ".join(
-            f"{k}={v}" for k, v in self.properties().items()
-        )
         return (
-            f"{self.__class__.__name__}(1 → {self.n_features}, {_kwargs_rep})"
+            f"{self.__class__.__name__}(n_features={self.n_features}, "
+            f"cutoff={self.cutoff}, trainable={self.trainable})"
         )
 
 
@@ -73,36 +66,43 @@ class Bessel(DistanceExpansion):
     The Bessel expansion:
 
     .. math::
-        \phi_{n}(r) = \frac{\sin(\pi n r / r_\text{cut})}{r} \quad n
+        \phi_{n}(r) = \sqrt{\frac{2}{r_{\text{cut}}}}
+        \frac{\sin(n \pi \frac{r}{r_\text{cut}})}{r} \quad n
         \in [1, n_\text{features}]
 
     where :math:`r_\text{cut}` is the cutoff radius and :math:`n` is the order
-    of the Bessel function.
+    of the Bessel function, as introduced in `Directional Message Passing for
+    Molecular Graphs <https://arxiv.org/abs/2003.03123>`_.
 
     Parameters
     ----------
-    n_features : int
+    n_features
         The number of features to expand into.
-    cutoff : float
+    cutoff
         The cutoff radius.
+    trainable
+        Whether the expansion parameters are trainable.
+
+    Attributes
+    ----------
+    frequencies
+        :math:`n \pi / r_\text{cut}`, the frequencies of the Bessel functions.
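+
+    Examples
+    --------
+    Expanding a batch of distances into ``n_features`` features per distance:
+
+    >>> import torch
+    >>> from graph_pes.models.distances import Bessel
+    >>> expansion = Bessel(n_features=4, cutoff=5.0)
+    >>> r = torch.tensor([1.0, 2.0, 3.0])  # a batch of 3 distances
+    >>> expansion(r).shape  # forward() appends the feature dimension
+    torch.Size([3, 4])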
     """
 
-    def __init__(self, n_features: int, cutoff: float):
-        super().__init__(n_features, cutoff)
-
-        self.register_buffer(
-            "frequencies", torch.arange(1, n_features + 1) * π / cutoff
+    def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
+        super().__init__(n_features, cutoff, trainable)
+        self.frequencies = nn.Parameter(
+            torch.arange(1, n_features + 1) * π / cutoff,
+            requires_grad=trainable,
         )
+        self.pre_factor = torch.sqrt(torch.tensor(2 / cutoff))
 
     def expand(self, r: torch.Tensor) -> torch.Tensor:
-        numerator = torch.sin(r * self.frequencies)
+        numerator = self.pre_factor * torch.sin(r * self.frequencies)
 
         # we avoid dividing by zero by replacing any zero elements with 1
-        denominator = torch.where(r != 0, torch.tensor(1.0), r)
+        denominator = torch.where(r == 0, torch.tensor(1.0), r)
         return numerator / denominator
 
-    def properties(self) -> dict[str, Any]:
-        return {"cutoff": self.cutoff}
-
 
 class GaussianSmearing(DistanceExpansion):
     r"""
@@ -110,36 +110,50 @@ class GaussianSmearing(DistanceExpansion):
 
     .. math::
         \phi_{n}(r) = \exp\left(-\frac{(r - \mu_n)^2}{2\sigma^2}\right)
+        \quad n \in [1, n_\text{features}]
 
-    where :math:`\mu_n` is the center of the :math:`n`th Gaussian
-    and :math:`\sigma` is the width of the Gaussians.
+    where :math:`\mu_n` is the center of the :math:`n`-th Gaussian
+    and :math:`\sigma` is a width shared across all the Gaussians.
 
     Parameters
     ----------
-    n_features : int
+    n_features
         The number of features to expand into.
-    cutoff : float
+    cutoff
         The cutoff radius.
+    trainable
+        Whether the expansion parameters are trainable.
+
+    Attributes
+    ----------
+    centers
+        :math:`\mu_n`, the centers of the Gaussians.
+    coef
+        :math:`-\frac{1}{2\sigma^2}`, the coefficient in the exponent.
     """
 
     def __init__(
         self,
         n_features: int,
         cutoff: float,
+        trainable: bool = True,
     ):
-        super().__init__(n_features, cutoff)
+        super().__init__(n_features, cutoff, trainable)
         sigma = cutoff / n_features
-        self.coef = -1 / (2 * sigma**2)
-        self.register_buffer("centers", torch.linspace(0, cutoff, n_features))
+        self.coef = nn.Parameter(
+            torch.tensor(-1 / (2 * sigma**2)),
+            requires_grad=trainable,
+        )
+        self.centers = nn.Parameter(
+            torch.linspace(0, cutoff, n_features),
+            requires_grad=trainable,
+        )
 
     def expand(self, r: torch.Tensor) -> torch.Tensor:
         offsets = r - self.centers
         return torch.exp(self.coef * offsets**2)
 
-    def properties(self) -> dict[str, Any]:
-        return {"cutoff": self.cutoff}
-
 
 Envelope = Callable[[torch.Tensor], torch.Tensor]
diff --git a/src/graph_pes/models/pairwise.py b/src/graph_pes/models/pairwise.py
index 51b7cbf8..a39ed0f1 100644
--- a/src/graph_pes/models/pairwise.py
+++ b/src/graph_pes/models/pairwise.py
@@ -26,12 +26,13 @@ class PairPotential(GraphPESModel, ABC):
 
     where :math:`r_{ij}` is the distance between atoms :math:`i` and :math:`j`,
     and :math:`Z_i` and :math:`Z_j` are their atomic numbers.
-
     This can be recast as a sum over local energy contributions,
     :math:`E = \sum_i \varepsilon_i`, according to:
 
     .. math::
         \varepsilon_i = \frac{1}{2} \sum_j V(r_{ij}, Z_i, Z_j)
+
+    Subclasses should implement :meth:`interaction`.
     """
 
     @abstractmethod