Skip to content

Commit

Permalink
PaiNN docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Feb 6, 2024
1 parent cdd0cf0 commit dc64835
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Available Models

models/pairwise
models/schnet
models/painn


Helper Classes and Functions
Expand Down
7 changes: 7 additions & 0 deletions docs/source/models/painn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
PaiNN
#####

.. autoclass:: graph_pes.models.painn.PaiNN
:show-inheritance:
.. autoclass:: graph_pes.models.painn.Interaction
.. autoclass:: graph_pes.models.painn.Update
80 changes: 74 additions & 6 deletions src/graph_pes/models/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,32 @@
from .distances import Bessel, PolynomialEnvelope


class InteractionBlock(nn.Module):
class Interaction(nn.Module):
r"""
The interaction block of the :class:`PaiNN` model.
Continuous filters generated from neighbour distances are convolved with
existing scalar embeddings to create messages :math:`x_{j \rightarrow i}`
for each neighbour :math:`j` of atom :math:`i`.
Scalar total messages, :math:`\Delta s_i`, are created by summing over
neighbours, while vector total messages, :math:`\Delta v_i`,
incorporate directional information from neighbour unit vectors and
existing vector embeddings.
The code aims to follow **Figure 2b** of the `PaiNN paper
<https://arxiv.org/abs/2102.03150>`_ as closely as possible.
Parameters
----------
radial_features
The number of radial features to expand bond distances into.
internal_dim
The dimension of the internal representations.
cutoff
The cutoff distance for the radial features.
"""

def __init__(
self,
radial_features: int,
Expand Down Expand Up @@ -80,7 +105,20 @@ def forward(
return self._linear(x.transpose(-1, -2)).transpose(-1, -2)


class UpdateBlock(nn.Module):
class Update(nn.Module):
r"""
The update block of the :class:`PaiNN` model.
Projections of vector embeddings are used to update the scalar embeddings,
and vice versa. The code aims to follow **Figure 2c** of the `PaiNN paper
<https://arxiv.org/abs/2102.03150>`_ as closely as possible.
Parameters
----------
internal_dim
The dimension of the internal representations.
"""

def __init__(self, internal_dim: int):
super().__init__()
self.internal_dim = internal_dim
Expand Down Expand Up @@ -120,6 +158,36 @@ def forward(


class PaiNN(GraphPESModel):
r"""
The `Polarizable Atom Interaction Neural Network (PaiNN)
<https://arxiv.org/abs/2102.03150>`_ model.
Alternating :class:`Interaction` and :class:`Update` blocks
are used to residually update both vector and scalar per-atom embeddings.
Citation:
.. code-block:: bibtex
@misc{Schutt-21-06,
title = {Equivariant Message Passing for the Prediction of Tensorial Properties and Molecular Spectra},
author = {Sch{\"u}tt, Kristof T. and Unke, Oliver T. and Gastegger, Michael},
year = {2021},
doi = {10.48550/arXiv.2102.03150},
}
Parameters
----------
internal_dim
The dimension of the internal representations.
radial_features
The number of radial features to expand bond distances into.
layers
The number of (interaction + update) layers to use.
cutoff
The cutoff distance for the radial features.
""" # noqa: E501

def __init__(
self,
internal_dim: int = 32,
Expand All @@ -130,14 +198,14 @@ def __init__(
super().__init__()
self.internal_dim = internal_dim
self.layers = layers
self.interactions: list[InteractionBlock] = nn.ModuleList(
self.interactions: list[Interaction] = nn.ModuleList(
[
InteractionBlock(radial_features, internal_dim, cutoff)
Interaction(radial_features, internal_dim, cutoff)
for _ in range(layers)
]
) # type: ignore
self.updates: list[UpdateBlock] = nn.ModuleList(
[UpdateBlock(internal_dim) for _ in range(layers)]
self.updates: list[Update] = nn.ModuleList(
[Update(internal_dim) for _ in range(layers)]
) # type: ignore
self.z_embedding = PerSpeciesEmbedding(internal_dim)
self.read_out = MLP(
Expand Down
2 changes: 1 addition & 1 deletion src/graph_pes/models/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class SchNet(GraphPESModel):
Citation:
.. code::
.. code:: bibtex
@article{Schutt-18-03,
title = {{{SchNet}} {\textendash} {{A}} Deep Learning Architecture for Molecules and Materials},
Expand Down

0 comments on commit dc64835

Please sign in to comment.