Skip to content

Commit

Permalink
Refactor distance expansion models and update documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Jan 15, 2024
1 parent 73b17f2 commit 6765adf
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 65 deletions.
11 changes: 3 additions & 8 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@ All models implemented in ``graph_pes`` are subclasses of
:members: predict_local_energies, __add__
Pair Potentials
===============
.. toctree::

Pair potential models can be recast as local-energy models acting on graphs.

.. autoclass :: graph_pes.models.pairwise.PairPotential
:members: interaction
.. autoclass :: graph_pes.models.pairwise.LennardJones
models/pairwise
models/distances
19 changes: 19 additions & 0 deletions docs/source/models/distances.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Distance Expansions
===================

Available Expansions
--------------------

`graph-pes` exposes the :class:`DistanceExpansion <graph_pes.models.distances.DistanceExpansion>`
base class, which can be used to implement new distance expansions.
We also provide a few common expansions:

.. autoclass :: graph_pes.models.distances.Bessel
.. autoclass :: graph_pes.models.distances.GaussianSmearing
Implementing a new Expansion
----------------------------

.. autoclass :: graph_pes.models.distances.DistanceExpansion
:members:
7 changes: 7 additions & 0 deletions docs/source/models/pairwise.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Pair Potentials
===============

.. autoclass :: graph_pes.models.pairwise.PairPotential
:members: interaction
.. autoclass :: graph_pes.models.pairwise.LennardJones
8 changes: 4 additions & 4 deletions docs/source/training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ An Overview

Training a :class:`GraphPESModel <graph_pes.core.GraphPESModel>` involves three preparatory steps:

1. Define and initialize the model (see :ref:`models`)
2. Loading the training data (see :ref:`loading atomic graphs`)
3. Defining the loss function (see :ref:`loss functions`)
1. Define and initialize the model (see models)
2. Loading the training data (see loading atomic graphs)
3. Defining the loss function (see loss functions)

Within the graph-pes framework, this can be as simple as:

Expand Down Expand Up @@ -47,7 +47,7 @@ Roughly, the following steps are taken:

.. literalinclude:: ./training_setup.py
:language: python
:lines: 24-35
:lines: 24-30

---

Expand Down
118 changes: 66 additions & 52 deletions src/graph_pes/models/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,42 @@

from abc import ABC, abstractmethod
from math import pi as π
from typing import Any, Callable
from typing import Callable

import torch
from torch import nn
from jaxtyping import Float
from torch import Tensor, nn


class DistanceExpansion(nn.Module, ABC): # TODO- make protocol?
class DistanceExpansion(nn.Module, ABC):
r"""
Base class for an expansion function, :math:`\phi(r)` such that:
Abstract base class for an expansion function, :math:`\phi(r) :
[0, r_{\text{cutoff}}] \rightarrow \mathbb{R}^{n_\text{features}}`.
.. math::
r \in \mathbb{R}^1 \quad \rightarrow \quad \phi(r) \in
\mathbb{R}^{n_\text{features}}
or, for a batch of distances:
.. math::
r \in \mathbb{R}^{n_\text{batch} \times 1} \quad \rightarrow \quad
\phi(r) \in \mathbb{R}^{n_\text{batch} \times n_\text{features}}
Subclasses should implement :meth:`expand`, which must also work over
batches: :math:`\phi(r) : [0, r_{\text{cutoff}}]^{n_\text{batch} \times 1}
\rightarrow \mathbb{R}^{n_\text{batch} \times n_\text{features}}`.
Parameters
----------
n_features : int
n_features
The number of features to expand into.
cutoff : float
cutoff
The cutoff radius.
trainable
Whether the expansion parameters are trainable.
"""

def __init__(self, n_features: int, cutoff: float):
def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
super().__init__()
self.n_features = n_features
self.cutoff = cutoff
self.trainable = trainable

@abstractmethod
def expand(self, r: torch.Tensor) -> torch.Tensor:
def expand(
self, r: Float[Tensor, "... 1"]
) -> Float[Tensor, "... n_features"]:
r"""
Perform the expansion.
Expand All @@ -45,26 +46,18 @@ def expand(self, r: torch.Tensor) -> torch.Tensor:
r : torch.Tensor
The distances to expand. Guaranteed to have shape :math:`(..., 1)`.
"""
pass

def forward(self, r: torch.Tensor) -> torch.Tensor:
r"""
Ensure that the input has the correct shape, :math:`(..., 1)`,
and then perform the expansion.
"""
def forward(
self, r: Float[Tensor, "..."]
) -> Float[Tensor, "... n_features"]:
if r.shape[-1] != 1:
r = r.unsqueeze(-1)
return self.expand(r)

def properties(self) -> dict[str, Any]:
return {}

def __repr__(self) -> str:
_kwargs_rep = ", ".join(
f"{k}={v}" for k, v in self.properties().items()
)
return (
f"{self.__class__.__name__}(1 → {self.n_features}, {_kwargs_rep})"
f"{self.__class__.__name__}(n_features={self.n_features}, "
f"cutoff={self.cutoff}, trainable={self.trainable})"
)


Expand All @@ -73,73 +66,94 @@ class Bessel(DistanceExpansion):
The Bessel expansion:
.. math::
\phi_{n}(r) = \frac{\sin(\pi n r / r_\text{cut})}{r} \quad n
\phi_{n}(r) = \sqrt{\frac{2}{r_{\text{cut}}}}
\frac{\sin(n \pi \frac{r}{r_\text{cut}})}{r} \quad n
\in [1, n_\text{features}]
where :math:`r_\text{cut}` is the cutoff radius and :math:`n` is the order
of the Bessel function.
of the Bessel function, as introduced in `Directional Message Passing for
Molecular Graphs <http://arxiv.org/abs/2003.03123>`_.
Parameters
----------
n_features : int
n_features
The number of features to expand into.
cutoff : float
cutoff
The cutoff radius.
trainable
Whether the expansion parameters are trainable.
Attributes
----------
frequencies
:math:`n`, the frequencies of the Bessel functions.
"""

def __init__(self, n_features: int, cutoff: float):
def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
super().__init__(n_features, cutoff)

self.register_buffer(
"frequencies", torch.arange(1, n_features + 1) * π / cutoff
self.frequencies = nn.Parameter(
torch.arange(1, n_features + 1) * π / cutoff,
requires_grad=trainable,
)
self.pre_factor = torch.sqrt(torch.tensor(2 / cutoff))

def expand(self, r: torch.Tensor) -> torch.Tensor:
numerator = torch.sin(r * self.frequencies)
numerator = self.pre_factor * torch.sin(r * self.frequencies)
# we avoid dividing by zero by replacing any zero elements with 1
denominator = torch.where(r != 0, torch.tensor(1.0), r)
return numerator / denominator

def properties(self) -> dict[str, Any]:
return {"cutoff": self.cutoff}


class GaussianSmearing(DistanceExpansion):
r"""
A Gaussian smearing expansion:
.. math::
\phi_{n}(r) = \exp\left(-\frac{(r - \mu_n)^2}{2\sigma^2}\right)
\quad n \in [1, n_\text{features}]
where :math:`\mu_n` is the center of the :math:`n`th Gaussian
and :math:`\sigma` is the width of the Gaussians.
where :math:`\mu_n` is the center of the :math:`n`'th Gaussian
and :math:`\sigma` is a width shared across all the Gaussians.
Parameters
----------
n_features : int
n_features
The number of features to expand into.
cutoff : float
cutoff
The cutoff radius.
trainable
Whether the expansion parameters are trainable.
Attributes
----------
centers
:math:`\mu_n`, the centers of the Gaussians.
coef
:math:`\frac{1}{2\sigma^2}`, the coefficient of the exponent.
"""

def __init__(
self,
n_features: int,
cutoff: float,
trainable: bool = True,
):
super().__init__(n_features, cutoff)
super().__init__(n_features, cutoff, trainable)

sigma = cutoff / n_features
self.coef = -1 / (2 * sigma**2)
self.register_buffer("centers", torch.linspace(0, cutoff, n_features))
self.coef = nn.Parameter(
torch.tensor(-1 / (2 * sigma**2)),
requires_grad=trainable,
)
self.centers = nn.Parameter(
torch.linspace(0, cutoff, n_features),
requires_grad=trainable,
)

def expand(self, r: torch.Tensor) -> torch.Tensor:
offsets = r - self.centers
return torch.exp(self.coef * offsets**2)

def properties(self) -> dict[str, Any]:
return {"cutoff": self.cutoff}


Envelope = Callable[[torch.Tensor], torch.Tensor]

Expand Down
3 changes: 2 additions & 1 deletion src/graph_pes/models/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ class PairPotential(GraphPESModel, ABC):
where :math:`r_{ij}` is the distance between atoms :math:`i` and :math:`j`,
and :math:`Z_i` and :math:`Z_j` are their atomic numbers.
This can be recast as a sum over local energy contributions,
:math:`E = \sum_i \varepsilon_i`, according to:
.. math::
\varepsilon_i = \frac{1}{2} \sum_j V(r_{ij}, Z_i, Z_j)
Subclasses should implement :meth:`interaction`.
"""

@abstractmethod
Expand Down

0 comments on commit 6765adf

Please sign in to comment.