Skip to content

Commit

Permalink
make graph sum operations more general
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Mar 6, 2024
1 parent 9961b16 commit 18e9f98
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 21 deletions.
76 changes: 68 additions & 8 deletions src/graph_pes/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from ase.neighborlist import neighbor_list
from torch import Tensor
from torch.utils.data import DataLoader as TorchDataLoader
from torch_geometric.utils import scatter
from typing_extensions import TypeAlias

from graph_pes.nn import left_aligned_mul

from . import keys
from .graph_typing import AtomicGraph as AtomicGraphType
from .graph_typing import AtomicGraphBatch as AtomicGraphBatchType
Expand Down Expand Up @@ -235,6 +236,53 @@ def convert_to_atomic_graphs(
]


def sum_over_neighbours(p: Tensor, graph: AtomicGraph) -> Tensor:
    r"""
    Sum a per-edge property, :math:`p^e_{ij}`, over neighbours to get a
    per-atom property, :math:`p_i`:

    .. math::
        p_i = \sum_{j \in \mathcal{N}_i} p^e_{ij}

    where :math:`p_i \in \mathbb{R}^{a \times b \times \ldots}`, i.e. any
    trailing dimensions of ``p`` are carried through unchanged. In all cases,
    if :math:`|\mathcal{N}_i| = 0`, then
    :math:`p_i = 0^{a \times b \times \dots}`.

    Parameters
    ----------
    p
        The per-edge property to sum. Its leading dimension must index the
        edges of ``graph``; any further dimensions are summed independently.
    graph
        The graph to sum the property for.

    Returns
    -------
    Tensor
        A tensor of shape ``(number_of_atoms(graph),) + p.shape[1:]`` with
        the same dtype and device as ``p``.
    """
    N = number_of_atoms(graph)
    # first row of the neighbour index: the central atom of each edge,
    # expected shape (E,)
    central_atoms = graph[keys.NEIGHBOUR_INDEX][0]

    shape = (N,) + p.shape[1:]
    zeros = torch.zeros(shape, dtype=p.dtype, device=p.device)

    if p.shape[0] == 0:
        # return all zeros if there are no edges to sum over
        return zeros

    # create `index`, where index.shape == p.shape and
    # (index[e] == central_atoms[e]).all(). The index must match the shape
    # of the *source* `p`, i.e. (E, ...), not that of the (N, ...) output.
    # Building it by unsqueeze + expand keeps the integer dtype of
    # `central_atoms` (no round trip through p's floating point dtype) and
    # allocates no extra memory.
    index = central_atoms
    for _ in range(p.dim() - 1):
        index = index.unsqueeze(-1)
    index = index.expand_as(p)

    return zeros.scatter_add(0, index, p)


#### BATCHING ####


Expand Down Expand Up @@ -291,8 +339,21 @@ def batch_graphs(graphs: list[AtomicGraph]) -> AtomicGraphBatch:


def sum_per_structure(x: Tensor, graph: AtomicGraph) -> Tensor:
"""
Sum a per-atom property to get a per-structure property.
r"""
Sum a per-atom property, :math:`p` to get a per-structure property,
:math:`P`:
If a single structure is present, then:
..math::
P = \sum_i p_i
If a batch of structures is present, then:
..math::
P_s = \sum_{i \in S} p_i
where :math:`S` is the collection of all atoms in structure :math:`s`.
Parameters
----------
Expand All @@ -303,13 +364,12 @@ def sum_per_structure(x: Tensor, graph: AtomicGraph) -> Tensor:
"""

if is_batch(graph):
# we have more than one structure: sum over local energies to
# get a total energy for each structure
batch = graph[keys.BATCH] # type: ignore
return scatter(x, batch, dim=0, reduce="sum")
shape = (number_of_structures(graph),) + x.shape[1:]
zeros = torch.zeros(shape, dtype=x.dtype, device=x.device)
return zeros.scatter_add(0, batch, x)
else:
# we only have one structure: sum over all the atoms
return x.sum()
return x.sum(dim=0)


def number_of_structures(batch: AtomicGraph) -> int:
Expand Down
8 changes: 3 additions & 5 deletions src/graph_pes/models/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AtomicGraphBatch,
keys,
neighbour_distances,
sum_over_neighbours,
)
from graph_pes.transform import PerAtomShift
from jaxtyping import Float
Expand Down Expand Up @@ -76,13 +77,10 @@ def predict_local_energies(self, graph: AtomicGraph) -> Tensor:

V = self.interaction(
distances.view(-1, 1), Z_i.view(-1, 1), Z_j.view(-1, 1)
)
) # (E, 1)

# sum over the neighbours
energies = torch.zeros_like(
graph[keys.ATOMIC_NUMBERS], dtype=torch.float
)
energies.scatter_add_(0, central_atoms, V.squeeze())
energies = sum_over_neighbours(V.squeeze(), graph)

# divide by 2 to avoid double counting
return energies / 2
Expand Down
29 changes: 21 additions & 8 deletions tests/data/test_atomic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
import torch
from ase import Atoms
from graph_pes.data import (
AtomicGraph,
Expand All @@ -11,17 +12,19 @@
neighbour_vectors,
number_of_atoms,
number_of_edges,
sum_over_neighbours,
)

# Reference structures shared by the tests in this module.
# A single hydrogen atom with no periodic boundary conditions.
ISOLATED_ATOM = Atoms("H", positions=[(0, 0, 0)], pbc=False)
# A single hydrogen atom, periodic, with cell=(1, 1, 1).
PERIODIC_ATOM = Atoms("H", positions=[(0, 0, 0)], pbc=True, cell=(1, 1, 1))
# Two hydrogen atoms separated by 1 unit along z, non-periodic.
DIMER = Atoms("H2", positions=[(0, 0, 0), (0, 0, 1)], pbc=False)
# Eight hydrogen atoms at positions drawn from a fixed-seed (42) RandomState,
# so the structure is reproducible; periodic, with an identity cell.
RANDOM_STRUCTURE = Atoms(
"H8",
positions=np.random.RandomState(42).rand(8, 3),
pbc=True,
cell=np.eye(3),
)
# NOTE(review): STRUCTURES is assigned twice in this view; the second
# assignment (which includes DIMER) is the one in effect. The first looks
# like the pre-change line of the diff - confirm and drop it.
STRUCTURES = [ISOLATED_ATOM, PERIODIC_ATOM, RANDOM_STRUCTURE]
STRUCTURES = [ISOLATED_ATOM, PERIODIC_ATOM, DIMER, RANDOM_STRUCTURE]
# Graph representations of STRUCTURES, built with cutoff=1.0.
GRAPHS = convert_to_atomic_graphs(STRUCTURES, cutoff=1.0)


Expand Down Expand Up @@ -57,13 +60,6 @@ def test_random_structure(cutoff: int):
assert neighbour_distances(graph).max() <= cutoff


# def test_warning_on_position():
# # check that a warning is raised if the user tries to access the positions
# # directly for a structure with a unit cell
# with pytest.warns(UserWarning):
# _ = GRAPHS[1].positions


def test_get_labels():
atoms = Atoms("H2", positions=[(0, 0, 0), (0, 0, 1)], pbc=False)
atoms.info["energy"] = -1.0
Expand All @@ -78,3 +74,20 @@ def test_get_labels():

with pytest.raises(KeyError):
graph["missing"]


# summing a per-edge property of all ones over neighbours must give, for
# every atom, that atom's own number of neighbours. The expected value is
# computed independently per atom, so the check does not rely on all atoms
# having the same number of neighbours.
# NOTE(review): every structure listed here has N == E, N == 1 or E == 0;
# consider adding one with N != E and N > 1 (e.g. RANDOM_STRUCTURE) so that
# the index construction for 3+ dimensional inputs is exercised when the
# per-edge and per-atom shapes differ.
@pytest.mark.parametrize("structure", [ISOLATED_ATOM, DIMER, PERIODIC_ATOM])
def test_sum_over_neighbours(structure):
    graph = convert_to_atomic_graph(structure, cutoff=1.1)
    N = number_of_atoms(graph)
    E = number_of_edges(graph)

    # independent reference: how many edges have each atom as central atom
    central_atoms = graph["neighbour_index"][0]
    n_neighbours = torch.bincount(central_atoms, minlength=N)

    for shape in [(E,), (E, 2), (E, 2, 3), (E, 2, 2, 2)]:
        edge_property = torch.ones(shape)
        result = sum_over_neighbours(edge_property, graph)
        assert result.shape == (N, *shape[1:])

        # broadcast the per-atom counts over any trailing dimensions
        expected = n_neighbours.to(result.dtype).reshape(
            (N,) + (1,) * (len(shape) - 1)
        )
        assert (result == expected).all()
6 changes: 6 additions & 0 deletions tests/data/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ def test_sum_per_structure():
x = torch.tensor([1, 2, 3, 4, 5])
assert sum_per_structure(x, batch).tolist() == [3, 12]

# and also for general sizes

x = torch.ones(2, 3, 4)
result = sum_per_structure(x, graphs[0])
assert result.shape == (3, 4)


def test_data_loader():
loader = AtomicDataLoader(GRAPHS, batch_size=2)
Expand Down

0 comments on commit 18e9f98

Please sign in to comment.