Skip to content

Commit

Permalink
add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Jul 17, 2024
1 parent 4304f42 commit 3310aaa
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Available Models
models/pairwise
models/schnet
models/painn
models/mace


Helper Classes and Functions
Expand Down
5 changes: 5 additions & 0 deletions docs/source/models/mace.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
MACE
####

.. autoclass:: graph_pes.models.MACE
.. autoclass:: graph_pes.models.ZEmbeddingMACE
130 changes: 126 additions & 4 deletions src/graph_pes/models/e3nn/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from typing import Callable, Union

import e3nn.util.jit
import graph_pes.models.distances
import torch
from e3nn import o3
from graph_pes.graphs.graph_typing import AtomicGraph
from graph_pes.graphs.operations import neighbour_distances, neighbour_vectors
from graph_pes.models.distances import (
Bessel,
DistanceExpansion,
PolynomialEnvelope,
)
Expand Down Expand Up @@ -50,14 +50,25 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
ReadOut = Union[LinearReadOut, NonLinearReadOut]


def _get_distance_expansion(name: str) -> type[DistanceExpansion]:
try:
return getattr(graph_pes.models.distances, name)
except AttributeError:
raise ValueError(f"Unknown distance expansion type: {name}") from None

Check warning on line 57 in src/graph_pes/models/e3nn/mace.py

View check run for this annotation

Codecov / codecov/patch

src/graph_pes/models/e3nn/mace.py#L56-L57

Added lines #L56 - L57 were not covered by tests


@e3nn.util.jit.compile_mode("script")
class _BaseMACE(AutoScaledPESModel):
"""
Base class for MACE models.
"""

def __init__(
self,
# radial things
cutoff: float,
n_radial: int,
radial_expansion_type: type[DistanceExpansion],
radial_expansion_type: type[DistanceExpansion] | str,
# node attributes
z_embed_dim: int,
z_embedding: Callable[[torch.Tensor], torch.Tensor],
Expand All @@ -71,6 +82,11 @@ def __init__(
):
super().__init__()

if isinstance(radial_expansion_type, str):
radial_expansion_type = _get_distance_expansion(
radial_expansion_type
)

self.radial_expansion = HaddamardProduct(
radial_expansion_type(
n_features=n_radial, cutoff=cutoff, trainable=True
Expand Down Expand Up @@ -127,13 +143,70 @@ def predict_unscaled_energies(self, graph: AtomicGraph) -> torch.Tensor:

@e3nn.util.jit.compile_mode("script")
class MACE(_BaseMACE):
r"""
Vanilla MACE model.
One-hot encodings of the atomic numbers are used to condition the
``TensorProduct`` update in the residual connection of the message passing
layers.
Parameters
----------
elements
list of elements that this MACE model will be able to handle.
cutoff
radial cutoff (in Å) for the radial expansion (and message passing)
n_radial
number of bases to expand the radial distances into
radial_expansion_type
type of radial expansion to use
layers
number of message passing layers
max_ell
:math:`l_\max` for the spherical harmonics
correlation
maximum correlation order of the messages
hidden_irreps
:class:`~e3nn.o3.Irreps` string for the node features at each
message passing layer
neighbour_scaling
scaling factor used to scale the neighbour message aggregation
use_self_connection
whether to use self-connections in the message passing layers
Examples
--------
Basic usage:
.. code-block:: python
>>> from graph_pes.models import MACE
>>> from graph_pes.models.distances import Bessel
>>> model = MACE(
... elements=["H", "C", "N", "O"],
... cutoff=5.0,
... radial_expansion_type=Bessel,
... )
Specification in a YAML file:
.. code-block:: yaml
model:
graph_pes.models.MACE:
elements: [H, C, N, O]
cutoff: 5.0
hidden_irreps: "128x0e + 128x1o"
radial_expansion_type: GaussianSmearing
"""

def __init__(
self,
elements: list[str],
# radial things
cutoff: float = 5.0,
n_radial: int = 8,
radial_expansion_type: type[DistanceExpansion] = Bessel,
radial_expansion_type: type[DistanceExpansion] | str = "Bessel",
# message passing
layers: int = 2,
max_ell: int = 3,
Expand Down Expand Up @@ -162,12 +235,61 @@ def __init__(

@e3nn.util.jit.compile_mode("script")
class ZEmbeddingMACE(_BaseMACE):
r"""
MACE model that uses a learnable embedding of atomic number to
condition the ``TensorProduct`` update in the residual connection of the
message passing layers.
Parameters
----------
cutoff
radial cutoff (in Å) for the radial expansion (and message passing)
n_radial
number of bases to expand the radial distances into
radial_expansion_type
type of radial expansion to use
z_embed_dim
dimension of the atomic number embedding
layers
number of message passing layers
max_ell
:math:`l_\max` for the spherical harmonics
correlation
maximum correlation order of the messages
hidden_irreps
:class:`~e3nn.o3.Irreps` string for the node features at each
message passing layer
neighbour_scaling
scaling factor used to scale the neighbour message aggregation
use_self_connection
whether to use self-connections in the message passing layers
Examples
--------
Basic usage:
.. code-block:: python
>>> model = ZEmbeddingMACE(
... cutoff=5.0,
... )
Specification in a YAML file:
.. code-block:: yaml
model:
graph_pes.models.ZEmbeddingMACE:
cutoff: 5.0
hidden_irreps: "128x0e + 128x1o"
"""

def __init__(
self,
# radial things
cutoff: float = 5.0,
n_radial: int = 8,
radial_expansion_type: type[DistanceExpansion] = Bessel,
radial_expansion_type: type[DistanceExpansion] | str = "Bessel",
# node attributes
z_embed_dim: int = 16,
# message passing
Expand Down

0 comments on commit 3310aaa

Please sign in to comment.