Update energy and force transformations
jla-gardner committed Jan 15, 2024
1 parent 8c22190 commit c171952
Showing 6 changed files with 54 additions and 29 deletions.
src/graph_pes/core.py (3 changes: 1 addition & 2 deletions)

@@ -9,7 +9,6 @@
     Chain,
     Identity,
     PerAtomScale,
-    PerAtomShift,
     Transform,
 )
 from graph_pes.util import Keys, differentiate, require_grad
@@ -148,7 +147,7 @@ def __init__(
         # if both None, default to a per-species, local energy offset
         if local_transform is None and total_transform is None:
             local_transform = Chain(
-                [PerAtomShift(), PerAtomScale()], trainable=True
+                [PerAtomScale(), PerAtomScale()], trainable=True
             )
         self.local_transform: Transform = local_transform or Identity()
         self.total_transform: Transform = total_transform or Identity()
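In effect, the default local transform changes from a per-species shift-plus-scale to two chained per-species scales. A minimal sketch of the explicit equivalent, assuming (from the callers in pairwise.py below) that this constructor belongs to EnergySummation; the import path for that class is an assumption, not confirmed by this diff:

    # Hypothetical explicit construction, equivalent to the new default.
    # EnergySummation's import path is assumed; Chain and PerAtomScale
    # are imported from graph_pes.transform as elsewhere in this commit.
    from graph_pes.core import EnergySummation
    from graph_pes.transform import Chain, PerAtomScale

    summation = EnergySummation(
        local_transform=Chain([PerAtomScale(), PerAtomScale()], trainable=True)
    )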
src/graph_pes/models/pairwise.py (6 changes: 3 additions & 3 deletions)

@@ -8,7 +8,7 @@
 from graph_pes.data import AtomicGraph
 from graph_pes.data.batching import AtomicGraphBatch
 from graph_pes.nn import MLP, PositiveParameter
-from graph_pes.transform import PerAtomShift
+from graph_pes.transform import PerAtomScale
 from jaxtyping import Float
 from torch import Tensor, nn
 from torch_geometric.utils import scatter
@@ -104,7 +104,7 @@ def __init__(self):

         # epsilon is a scaling term, so only need to learn a shift
         # parameter (rather than a shift and scale)
-        self._energy_summation = EnergySummation(local_transform=PerAtomShift())
+        self._energy_summation = EnergySummation(local_transform=PerAtomScale())
 
     def interaction(
         self, r: torch.Tensor, Z_i: torch.Tensor, Z_j: torch.Tensor
@@ -166,7 +166,7 @@ def __init__(self):

         # D is a scaling term, so only need to learn a shift
         # parameter (rather than a shift and scale)
-        self._energy_summation = EnergySummation(local_transform=PerAtomShift())
+        self._energy_summation = EnergySummation(local_transform=PerAtomScale())
 
     def interaction(
         self, r: torch.Tensor, Z_i: torch.Tensor, Z_j: torch.Tensor
src/graph_pes/training.py (4 changes: 2 additions & 2 deletions)

@@ -11,7 +11,7 @@
 from .data import AtomicGraph
 from .data.batching import AtomicDataLoader, AtomicGraphBatch
 from .loss import RMSE, Loss, WeightedLoss
-from .transform import Chain, PerAtomScale, PerAtomShift, Scale
+from .transform import Chain, PerAtomScale, Scale
 from .util import Keys


@@ -207,7 +207,7 @@ def get_loss(
 ) -> WeightedLoss:
     if loss is None:
         default_transforms = {
-            Keys.ENERGY: Chain([PerAtomScale(), PerAtomShift()]),
+            Keys.ENERGY: Chain([PerAtomScale(), PerAtomScale()]),
             Keys.FORCES: PerAtomScale(),
             Keys.STRESS: Scale(),
         }
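Energy losses are now normalised by two chained per-species scales rather than a scale plus a shift. A toy sketch of the assumed Chain semantics (apply transforms in order on the way forward, invert them in reverse order); this is not the library's actual implementation:

    # ToyChain illustrates the assumed composition semantics only;
    # graph_pes.transform.Chain may differ in details.
    class ToyChain:
        def __init__(self, transforms):
            self.transforms = transforms

        def forward(self, x, graph):
            for t in self.transforms:  # apply in order
                x = t.forward(x, graph)
            return x

        def inverse(self, x, graph):
            for t in reversed(self.transforms):  # undo in reverse order
                x = t.inverse(x, graph)
            return x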
src/graph_pes/transform.py (46 changes: 27 additions & 19 deletions)

@@ -261,10 +261,11 @@ def forward(
             The input data, shifted by the learned shift.
         """
         shifts = self.shift[graph.Z].squeeze()
-        if graph.is_local_property(x):
-            return x - shifts
-        else:
-            return x - sum_per_structure(shifts, graph)
+        if not graph.is_local_property(x):
+            shifts = sum_per_structure(shifts, graph)
+
+        ndims = len(x.shape)
+        return x - shifts.view(-1, *([1] * (ndims - 1)))
 
     def inverse(
         self,
@@ -291,10 +292,11 @@ def inverse(
             The batch of atomic graphs.
         """
         shifts = self.shift[graph.Z].squeeze()
-        if graph.is_local_property(x):
-            return x + shifts
-        else:
-            return x + sum_per_structure(shifts, graph)
+        if not graph.is_local_property(x):
+            shifts = sum_per_structure(shifts, graph)
+
+        ndims = len(x.shape)
+        return x + shifts.view(-1, *([1] * (ndims - 1)))
 
     def __repr__(self):
         return self.shift.__repr__().replace(
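The old branches returned x minus the raw shifts, which broadcasts correctly only for scalar per-atom properties of shape (n_atoms,). The new reshape appends singleton dimensions so the same shifts broadcast against vector properties such as forces of shape (n_atoms, 3). A self-contained sketch of the reshape using plain torch tensors:

    import torch

    shifts = torch.tensor([1.0, 2.0, 3.0])  # one value per atom
    energies = torch.zeros(3)               # shape (n_atoms,)
    forces = torch.zeros(3, 3)              # shape (n_atoms, 3)

    for x in (energies, forces):
        ndims = len(x.shape)
        # view(-1, 1, ..., 1) makes shifts broadcast against x
        shifted = x - shifts.view(-1, *([1] * (ndims - 1)))
        assert shifted.shape == x.shape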
@@ -325,12 +327,13 @@ class PerAtomScale(Transform):
     scale is fixed.
     """
 
-    def __init__(self, trainable: bool = True):
+    def __init__(self, trainable: bool = True, act_on_norms: bool = True):
         super().__init__(trainable=trainable)
         self.scales = PerSpeciesParameter.of_dim(
             dim=1, requires_grad=trainable, generator=1
         )
         """The fitted, per-species scales (variances)."""
+        self.act_on_norms = act_on_norms
 
     @torch.no_grad()
     def fit(self, x: LocalProperty | GlobalProperty, graphs: AtomicGraphBatch):
@@ -358,6 +361,9 @@ def fit(self, x: LocalProperty | GlobalProperty, graphs: AtomicGraphBatch):
         )
         zs = torch.unique(graphs.Z)
 
+        if self.act_on_norms:
+            x = x.norm(dim=-1)
+
         if graphs.is_local_property(x):
             # we have one data point per atom in the batch
             # we therefore fit the scale to be the variance of x
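With act_on_norms enabled, vector-valued properties are reduced to their per-atom magnitudes before the variance is fitted, so forces of shape (n_atoms, 3) contribute one scalar per atom. A minimal torch sketch of that reduction; the per-species bookkeeping of the real fit is omitted:

    import torch

    forces = torch.randn(5, 3)    # (n_atoms, 3)
    norms = forces.norm(dim=-1)   # (n_atoms,) force magnitudes
    scale = norms.var()           # variance of the magnitudes, fitted as the scale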
@@ -399,11 +405,12 @@ def forward(
         Shaped[Tensor, "shape ..."]
             The input data, scaled by the learned scale.
         """
-        if graph.is_local_property(x):
-            return x / self.scales[graph.Z].squeeze() ** 0.5
-        else:
-            var = sum_per_structure(self.scales[graph.Z].squeeze(), graph)
-            return x / var**0.5
+        scales = self.scales[graph.Z]
+        if not graph.is_local_property(x):
+            scales = sum_per_structure(scales, graph)
+
+        ndims = len(x.shape)
+        return x / scales.view(-1, *([1] * (ndims - 1))) ** 0.5
 
     def inverse(
         self,
@@ -433,11 +440,12 @@ def inverse(
         Shaped[Tensor, "shape ..."]
             The input data, scaled by the inverse of the learned scale.
         """
-        if graph.is_local_property(x):
-            return x * self.scales[graph.Z].squeeze() ** 0.5
-        else:
-            var = sum_per_structure(self.scales[graph.Z].squeeze(), graph)
-            return x * var**0.5
+        scales = self.scales[graph.Z]
+        if not graph.is_local_property(x):
+            scales = sum_per_structure(scales, graph)
+
+        ndims = len(x.shape)
+        return x * scales.view(-1, *([1] * (ndims - 1))) ** 0.5
 
     def __repr__(self):
         return self.scales.__repr__().replace(
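Forward and inverse now share one broadcasting-friendly code path: dividing and then multiplying by scales**0.5 should round-trip exactly. A toy check with plain tensors (hypothetical per-atom variances; the real class looks them up per species via self.scales[graph.Z]):

    import torch

    scales = torch.tensor([4.0, 9.0, 16.0])  # hypothetical per-atom variances
    forces = torch.randn(3, 3)

    scaled = forces / scales.view(-1, 1) ** 0.5
    restored = scaled * scales.view(-1, 1) ** 0.5
    assert torch.allclose(restored, forces)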
tests/data/test_batching.py (5 changes: 4 additions & 1 deletion)

@@ -38,18 +38,21 @@ def test_label_batching():

     # per-atom labels:
     structures[0].arrays["atom_label"] = [0, 1]
+    structures[0].arrays["force"] = np.zeros((2, 3))
     structures[1].arrays["atom_label"] = [2, 3, 4]
+    structures[1].arrays["force"] = np.zeros((3, 3))
 
     graphs = convert_to_atomic_graphs(structures, cutoff=1.5)
     batch = AtomicGraphBatch.from_graphs(graphs)
 
     # per-structure, array-type labels are concatenated along a new batch axis
     assert batch.structure_labels["stress"].shape == (2, 3, 3)
-    # per-structure, scalar-type labels are concatenated
+    # energy is a scalar, so the "new batch axis" is just concatenation:
     assert batch.structure_labels["label"].tolist() == [0, 1]
 
     # per-atom labels are concatenated along the first axis
     assert batch.atom_labels["atom_label"].tolist() == [0, 1, 2, 3, 4]
+    assert batch.atom_labels["force"].shape == (5, 3)
 
 
 def test_pbcs():
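The two batching behaviours this test pins down correspond to torch.stack versus torch.cat: per-structure array labels gain a new leading batch axis, while per-atom labels are joined along the existing atom axis. A standalone illustration:

    import torch

    # per-structure arrays (e.g. stress): stacked along a new batch axis
    stresses = [torch.zeros(3, 3), torch.zeros(3, 3)]
    assert torch.stack(stresses).shape == (2, 3, 3)

    # per-atom arrays (e.g. forces): concatenated along the atom axis
    forces = [torch.zeros(2, 3), torch.zeros(3, 3)]
    assert torch.cat(forces).shape == (5, 3)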
tests/test_transform.py (19 changes: 17 additions & 2 deletions)

@@ -1,8 +1,9 @@
+import numpy as np
 import torch
 from ase import Atoms
 from graph_pes.data import convert_to_atomic_graph, convert_to_atomic_graphs
 from graph_pes.data.batching import AtomicGraphBatch
-from graph_pes.transform import Identity, PerAtomShift
+from graph_pes.transform import Identity, PerAtomScale, PerAtomShift
 
 structure = Atoms("H2", positions=[(0, 0, 0), (0, 0, 1)])
 structure.info["energy"] = -1.0
@@ -40,6 +41,7 @@ def test_per_atom_transforms():
         atoms.info["energy"] = n_H * H_energy + n_C * C_energy
         local_prop = [H_energy] * n_H + [C_energy] * n_C
         atoms.arrays["local_prop"] = local_prop
+        atoms.arrays["forces"] = np.zeros((n_H + n_C, 3))
         structures.append(atoms)
 
     graphs = convert_to_atomic_graphs(structures, cutoff=1.5)
@@ -49,7 +51,11 @@
     shift = PerAtomShift(trainable=False)
     total_energies = batch["energy"]
     shift.fit(total_energies, batch)
+    shifted_total_energies = shift(total_energies, batch)
 
+    # shape preservation
+    assert shifted_total_energies.shape == total_energies.shape
+    # learn the correct shifts
     assert torch.allclose(
         shift.shift[torch.tensor([1, 6])].detach().squeeze(),
         torch.tensor([H_energy, C_energy]),
@@ -68,7 +74,8 @@
     local_energies = batch["local_prop"]
 
     shift.fit(local_energies, batch)
-
+    shifted_local_energies = shift(local_energies, batch)
+    assert shifted_local_energies.shape == local_energies.shape
     assert torch.allclose(
         shift.shift[torch.tensor([1, 6])].detach().squeeze(),
         torch.tensor([H_energy, C_energy]),
@@ -81,3 +88,11 @@
         atol=1e-5,
     )
     assert not centered_local_energy.requires_grad
+
+    # test scaling forces
+    scale = PerAtomScale(trainable=False)
+    forces = batch["forces"]
+    scale.fit(forces, batch)
+
+    scaled_forces = scale(forces, batch)
+    assert scaled_forces.shape == forces.shape
