Commit 257961d: changes

jla-gardner committed Apr 1, 2024
1 parent 20c6cb3 commit 257961d
Showing 9 changed files with 96 additions and 46 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/tests.yaml
@@ -3,6 +3,20 @@ on: [push]
 permissions:
   contents: read
 jobs:
+  formatting:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v3
+        with:
+          python-version: 3.9
+      - name: update pip
+        run: pip install --upgrade pip
+      - name: Install ruff
+        run: pip install ruff
+      - name: Run ruff
+        run: ruff check
+
   tests:
     runs-on: ubuntu-latest
     steps:
25 changes: 14 additions & 11 deletions src/graph_pes/core.py
@@ -53,6 +53,20 @@ class GraphPESModel(nn.Module, ABC):
     per-species scale and shift.
     """

+    def __init__(self, energy_transform: Transform | None = None):
+        super().__init__()
+
+        self.energy_transform: Transform = (
+            PerAtomStandardScaler()
+            if energy_transform is None
+            else energy_transform
+        )
+
+        # save as a buffer so that this is saved and loaded
+        # with the model
+        self._has_been_pre_fit: Tensor
+        self.register_buffer("_has_been_pre_fit", torch.tensor(False))
+
     def forward(self, graph: AtomicGraph) -> Tensor:
         """
         Calculate the total energy of the structure.
@@ -89,17 +103,6 @@ def predict_local_energies(self, graph: AtomicGraph) -> Tensor:
         The per-atom local energy predictions with shape :code:`(N,)`.
         """

-    def __init__(self, energy_transform: Transform | None = None):
-        super().__init__()
-        self.energy_transform: Transform = (
-            PerAtomStandardScaler()
-            if energy_transform is None
-            else energy_transform
-        )
-
-        self._has_been_pre_fit: Tensor
-        self.register_buffer("_has_been_pre_fit", torch.tensor(False))
-
     def pre_fit(
         self,
         graphs: LabelledBatch | Sequence[LabelledGraph],
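The constructor now sits above forward and is shared by all subclasses: a model accepts an optional energy_transform and hands it straight to GraphPESModel.__init__. A minimal sketch of that pattern (hypothetical subclass, assuming only the names visible in this diff):

    from __future__ import annotations

    from graph_pes.core import GraphPESModel
    from graph_pes.data import AtomicGraph
    from graph_pes.transform import Transform
    from torch import Tensor

    class MyModel(GraphPESModel):
        def __init__(self, energy_transform: Transform | None = None):
            # passing None keeps the PerAtomStandardScaler default
            super().__init__(energy_transform)

        def predict_local_energies(self, graph: AtomicGraph) -> Tensor:
            # return per-atom energies of shape (N,)
            ...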
6 changes: 5 additions & 1 deletion src/graph_pes/models/painn.py
@@ -9,6 +9,7 @@
     number_of_atoms,
 )
 from graph_pes.nn import MLP, HaddamardProduct, PerElementEmbedding
+from graph_pes.transform import Transform
 from torch import Tensor, nn

 from .distances import Bessel, PolynomialEnvelope
@@ -181,6 +182,8 @@ class PaiNN(GraphPESModel):
         The number of (interaction + update) layers to use.
     cutoff
         The cutoff distance for the radial features.
+    energy_transform
+        The energy transform to use (defaults to :class:`PerAtomStandardScaler`).
     """ # noqa: E501

     def __init__(
@@ -189,8 +192,9 @@ def __init__(
         radial_features: int = 20,
         layers: int = 3,
         cutoff: float = 5.0,
+        energy_transform: Transform | None = None,
     ):
-        super().__init__()
+        super().__init__(energy_transform)
         self.internal_dim = internal_dim
         self.layers = layers
         self.interactions: list[Interaction] = nn.ModuleList(
40 changes: 26 additions & 14 deletions src/graph_pes/models/pairwise.py
@@ -12,7 +12,7 @@
     sum_over_neighbours,
 )
 from graph_pes.nn import PerElementParameter
-from graph_pes.transform import PerAtomShift
+from graph_pes.transform import PerAtomShift, Transform
 from graph_pes.util import pytorch_repr, to_significant_figures
 from jaxtyping import Float
 from torch import Tensor
@@ -119,9 +119,16 @@ class LennardJones(PairPotential):
         :align: center
     """

-    def __init__(self, epsilon: float = 0.1, sigma: float = 1.0):
+    def __init__(
+        self,
+        epsilon: float = 0.1,
+        sigma: float = 1.0,
+        energy_transform: Transform | None = None,
+    ):
         # epsilon is a scaling term, so only need to learn a shift
-        super().__init__(energy_transform=PerAtomShift())
+        if energy_transform is None:
+            energy_transform = PerAtomShift()
+        super().__init__(energy_transform)

         self._log_epsilon = torch.nn.Parameter(torch.tensor(epsilon).log())
         self._log_sigma = torch.nn.Parameter(torch.tensor(sigma).log())
@@ -199,10 +206,17 @@ class Morse(PairPotential):
         :align: center
     """

-    def __init__(self, D: float = 0.1, a: float = 5.0, r0: float = 1.5):
+    def __init__(
+        self,
+        D: float = 0.1,
+        a: float = 5.0,
+        r0: float = 1.5,
+        energy_transform: Transform | None = None,
+    ):
         # D is a scaling term, so only need to learn a shift
         # parameter (rather than a shift and scale)
-        super().__init__(energy_transform=PerAtomShift())
+        if energy_transform is None:
+            energy_transform = PerAtomShift()
+        super().__init__(energy_transform)

         self._log_D = torch.nn.Parameter(torch.tensor(D).log())
         self._log_a = torch.nn.Parameter(torch.tensor(a).log())
@@ -220,9 +234,7 @@ def a(self):
     def r0(self):
         return self._log_r0.exp()

-    def interaction(
-        self, r: torch.Tensor, Z_i: torch.Tensor, Z_j: torch.Tensor
-    ):
+    def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None):
         """
         Evaluate the pair potential.
@@ -304,16 +316,16 @@ def interaction(self, r: Tensor, Z_i: Tensor, Z_j: Tensor) -> Tensor:
         Z_j : torch.Tensor
             The atomic numbers of the neighbours.
         """
-        cross_interaction = Z_i != Z_j
+        cross_interaction = Z_i != Z_j  # (E)

-        sigma_j = self.sigma[Z_j].squeeze()
-        sigma_i = self.sigma[Z_i].squeeze()
+        sigma_j = self.sigma[Z_j].squeeze()  # (E)
+        sigma_i = self.sigma[Z_i].squeeze()  # (E)
         nu = self.nu[Z_i, Z_j].squeeze() if self.modulate_distances else 1
         sigma = torch.where(
             cross_interaction,
             nu * (sigma_i + sigma_j) / 2,
             sigma_i,
-        ).clamp(min=0.2)
+        ).clamp(min=0.2)  # (E)

         epsilon_i = self.epsilon[Z_i].squeeze()
         epsilon_j = self.epsilon[Z_j].squeeze()
@@ -322,7 +334,7 @@ def interaction(self, r: Tensor, Z_i: Tensor, Z_j: Tensor) -> Tensor:
             cross_interaction,
             zeta * (epsilon_i * epsilon_j).sqrt(),
             epsilon_i,
-        ).clamp(min=0.00)
+        ).clamp(min=0.00)  # (E)

         x = sigma / r
         return 4 * epsilon * (x**12 - x**6)
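Both pair potentials keep PerAtomShift as their default (epsilon and D already act as scales) while now accepting an override. A usage sketch with the constructors shown above (illustrative argument values):

    from graph_pes.models.pairwise import LennardJones, Morse
    from graph_pes.transform import PerAtomShift

    lj = LennardJones(epsilon=0.1, sigma=1.0)  # default: PerAtomShift()
    morse = Morse(D=0.1, a=5.0, r0=1.5,
                  energy_transform=PerAtomShift(trainable=False))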
4 changes: 3 additions & 1 deletion src/graph_pes/models/schnet.py
@@ -4,6 +4,7 @@
 from graph_pes.core import GraphPESModel
 from graph_pes.data import AtomicGraph, neighbour_distances
 from graph_pes.nn import MLP, PerElementEmbedding, ShiftedSoftplus
+from graph_pes.transform import Transform
 from torch import Tensor, nn
 from torch_geometric.nn import MessagePassing

@@ -213,8 +214,9 @@ def __init__(
         cutoff: float = 5.0,
         layers: int = 3,
         expansion: type[DistanceExpansion] | None = None,
+        energy_transform: Transform | None = None,
     ):
-        super().__init__()
+        super().__init__(energy_transform)

         if expansion is None:
             expansion = GaussianSmearing
4 changes: 3 additions & 1 deletion src/graph_pes/models/tensornet.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import torch
+from graph_pes.transform import Transform
 from torch import Tensor, nn

 from ..core import GraphPESModel
@@ -320,8 +321,9 @@ def __init__(
         embedding_size: int = 32,
         cutoff: float = 5.0,
         layers: int = 1,
+        energy_transform: Transform | None = None,
     ):
-        super().__init__()
+        super().__init__(energy_transform)
         self.embedding = Embedding(radial_features, embedding_size, cutoff)
         self.interactions: list[Interaction] = nn.ModuleList(
             [
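PaiNN, SchNet and TensorNet all gain the same keyword and forward it to the base class. A usage sketch (assuming TensorNet is the class defined in this module; PerAtomScale chosen purely for illustration):

    from graph_pes.models.tensornet import TensorNet
    from graph_pes.transform import PerAtomScale

    # swap the default PerAtomStandardScaler for a scale-only transform
    model = TensorNet(cutoff=5.0, energy_transform=PerAtomScale())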
16 changes: 0 additions & 16 deletions src/graph_pes/nn.py
@@ -131,16 +131,6 @@ def prod(iterable):
     return reduce(lambda x, y: x * y, iterable, 1)


-# # Metaclass to combine _TensorMeta and the instance check override
-# # for PerElementParameter (see equivalent in torch.nn.Parameter)
-# class _PerElementParameterMeta(torch._C._TensorMeta):
-#     def __instancecheck__(self, instance):
-#         return super().__instancecheck__(instance) or (
-#             isinstance(instance, torch.Tensor)
-#             and getattr(instance, "_is_per_element_param", False)
-#         )
-
-
 class PerElementParameter(torch.nn.Parameter):
     def __new__(
         cls, data: Tensor, requires_grad: bool = True
@@ -224,12 +214,6 @@ def __instancecheck__(self, instance) -> bool:

     @torch.no_grad()
     def __repr__(self) -> str:
-        # TODO implement custom repr for different shapes
-        # 1 index dimension:
-        #     table with header column for Z
-        # 2 index dimensions with singleton further shape:
-        #     2D table with Z as both header row and header column
-        # print numel otherwise
         if len(self._accessed_Zs) == 0:
             return (
                 f"PerElementParameter(index_dims={self._index_dims}, "
2 changes: 1 addition & 1 deletion src/graph_pes/training.py
@@ -320,7 +320,7 @@ def Adam(
     energy_transform_overrides: dict | None = None,
 ) -> Callable[[GraphPESModel], torch.optim.Optimizer]:
     if energy_transform_overrides is None:
-        energy_transform_overrides = {"weight": 1.0}
+        energy_transform_overrides = {"weight_decay": 0.0}

     def adam(model: GraphPESModel) -> torch.optim.Optimizer:
         model_params = [
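The previous default, {"weight": 1.0}, is not a torch.optim.Adam option; {"weight_decay": 0.0} is, and it switches weight decay off for the energy-transform parameters. A sketch of how such overrides typically map onto optimizer parameter groups (hypothetical parameter split; the real factory is the Adam function above, and model stands for any GraphPESModel):

    import torch

    transform_params = [
        p for n, p in model.named_parameters() if "energy_transform" in n
    ]
    model_params = [
        p for n, p in model.named_parameters() if "energy_transform" not in n
    ]
    optimizer = torch.optim.Adam(
        [
            {"params": model_params},
            # overrides apply only to the energy transform's group
            {"params": transform_params, "weight_decay": 0.0},
        ],
        lr=1e-3,
    )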
31 changes: 30 additions & 1 deletion src/graph_pes/transform.py
@@ -3,6 +3,7 @@
 from abc import ABC, abstractmethod

 import torch
+from ase.data import atomic_numbers as symbol_to_Z
 from torch import Tensor, nn

 from graph_pes.data import (
@@ -318,7 +319,7 @@ def inverse(self, y: Tensor, graph: AtomicGraph) -> Tensor:
         return left_aligned_add(y, shifts)

     def __repr__(self):
-        return f"PerAtomShift({self.shift})"
+        return f"{self.__class__.__name__}({self.shift})"


 class PerAtomScale(Transform):
@@ -458,6 +459,34 @@ def PerAtomStandardScaler(trainable: bool = True) -> Transform:
     return Chain([PerAtomScale(trainable), PerAtomShift(trainable)])


+class FixedEnergyOffsets(PerAtomShift):
+    r"""
+    A convenience subclass of :class:`PerAtomShift` with fixed,
+    species-dependent shifts.
+
+    Parameters
+    ----------
+    offsets
+        The fixed shifts to apply to each species.
+
+    Examples
+    --------
+    >>> FixedEnergyOffsets(H=-0.1, Pt=-13.14)
+    """
+
+    def __init__(self, **offsets: float):
+        super().__init__(trainable=False)
+        for symbol, offset in offsets.items():
+            assert symbol in symbol_to_Z, f"Unknown element: {symbol}"
+            self.shift[symbol_to_Z[symbol]] = offset
+
+    def fit_to_source(self, x: Tensor, graphs: AtomicGraphBatch):
+        pass
+
+    def fit_to_target(self, y: Tensor, graphs: AtomicGraphBatch):
+        pass
+
+
 class Scale(Transform):
     def __init__(self, trainable: bool = True, scale: float | int = 1.0):
         super().__init__(trainable=trainable)
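A short usage sketch for the new transform, combined with the energy_transform keyword the models gained above (illustrative offset values):

    from graph_pes.models.pairwise import LennardJones
    from graph_pes.transform import FixedEnergyOffsets

    # pin the per-species shifts to known reference energies
    # instead of learning them during pre-fitting
    model = LennardJones(
        energy_transform=FixedEnergyOffsets(H=-0.1, Pt=-13.14),
    )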
