migrate to vesin for neighbour lists #120

Merged · 3 commits · Feb 15, 2025
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

Use [`vesin`](https://luthaf.fr/vesin/latest/index.html#) for accelerated neighbour list construction.

Add `ase_calculator` method to `GraphPESModel` for easy access to an ASE calculator wrapping the model.

Update the `mace` interfaces to use the default torch dtype if none is specified.

Add `ruff` check to CI.

## [0.0.22] - 2025-02-05
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
@@ -21,7 +21,7 @@ Alternatively, you can use [`uv`](https://docs.astral.sh/uv/):
```bash
git clone https://github.com/<your-username-here>/graph-pes.git
cd graph-pes
uv sync --extra test
uv sync --all-extras
```

---
@@ -34,7 +34,7 @@ Next verify the tests all pass:

```bash
pip install pytest
pytest src/ # or uv run pytest src/
pytest tests/ # or uv run pytest tests/
```

Then push your changes back to your fork of the repository:
1 change: 0 additions & 1 deletion docs/source/models/root.rst
@@ -10,7 +10,6 @@ Models
:show-inheritance:



Loading Models
==============

5 changes: 2 additions & 3 deletions docs/source/quickstart/implement-a-model.ipynb
@@ -507,10 +507,9 @@
"from ase import units\n",
"from ase.build import molecule\n",
"from ase.md.langevin import Langevin\n",
"from graph_pes.utils.calculator import GraphPESCalculator\n",
"\n",
"# set up calculator and structure\n",
"calculator = GraphPESCalculator(model)\n",
"calculator = model.ase_calculator()\n",
"structure = molecule(\"CH4\")\n",
"structure.center(vacuum=3.0) # place in a large unit cell\n",
"structure.pbc = True\n",
@@ -830,7 +829,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
"version": "3.9.21"
}
},
"nbformat": 4,
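For context, the notebook cell above now obtains its calculator from the new `model.ase_calculator()` helper rather than constructing `GraphPESCalculator` by hand. A condensed, hedged sketch of that MD setup, with `LennardJones` standing in for the model built in the notebook (the `graph_pes.models` import path is an assumption here):

```python
from ase import units
from ase.build import molecule
from ase.md.langevin import Langevin
from graph_pes.models import LennardJones  # stand-in for the notebook's custom model

model = LennardJones()

# set up calculator and structure, as in the notebook
structure = molecule("CH4")
structure.center(vacuum=3.0)  # place in a large unit cell
structure.pbc = True
structure.calc = model.ase_calculator()

# a very short Langevin run, just to exercise the calculator
dyn = Langevin(
    structure,
    timestep=0.5 * units.fs,
    temperature_K=300,
    friction=0.01,
)
dyn.run(steps=10)
```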
4,958 changes: 2,560 additions & 2,398 deletions docs/source/tools/ase.ipynb

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
@@ -29,13 +29,13 @@ dependencies = [
"wandb",
"data2objects>=0.1.0",
"pyright>=1.1.394",
"vesin>=0.3.2",
]
requires-python = ">=3.9"


[project.optional-dependencies]
test = ["pytest", "pytest-cov"]
dev = ["ruff", "sphinx-autobuild"]
docs = [
"sphinx",
"furo",
@@ -129,3 +129,6 @@ filterwarnings = [
"ignore:.*The TorchScript type system doesn't support instance-level annotations on empty non-base types.*",
]
norecursedirs = "tests/helpers"

[dependency-groups]
dev = ["notebook>=7.3.2", "ruff", "sphinx-autobuild"]
6 changes: 4 additions & 2 deletions src/graph_pes/atomic_graph.py
@@ -19,7 +19,7 @@
import torch
import torch.multiprocessing
import torch.utils.data
from ase.neighborlist import neighbor_list
import vesin
from ase.stress import voigt_6_to_full_3x3_stress
from load_atoms.utils import remove_calculator
from typing_extensions import TypeAlias
@@ -376,7 +376,9 @@ def from_ase(
cell = torch.tensor(structure.cell.array, dtype=_float)

# neighbour list
i, j, offsets = neighbor_list("ijS", structure, cutoff)
i, j, offsets = vesin.ase_neighbor_list("ijS", structure, float(cutoff))
i = i.astype(np.int64)
j = j.astype(np.int64)
neighbour_list = torch.tensor(np.vstack([i, j]), dtype=torch.long)
neighbour_cell_offsets = torch.tensor(offsets, dtype=_float)

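This hunk is the heart of the migration: `ase.neighborlist.neighbor_list` is swapped for `vesin.ase_neighbor_list`, which accepts the same `"ijS"` query string and returns the same quantities, just faster. A minimal sketch of the drop-in equivalence (the structure and cutoff are arbitrary; this assumes `vesin` and `ase` are installed):

```python
import vesin
from ase.build import bulk
from ase.neighborlist import neighbor_list

atoms = bulk("Cu", cubic=True)
cutoff = 3.5

# reference result from ASE's own neighbour list
i_ase, j_ase, S_ase = neighbor_list("ijS", atoms, cutoff)

# vesin's drop-in replacement, as now used in AtomicGraph.from_ase
i_v, j_v, S_v = vesin.ase_neighbor_list("ijS", atoms, float(cutoff))


def as_pair_set(i, j, S):
    """Collect (i, j, cell shift) triples, ignoring ordering."""
    return {(int(a), int(b), *map(int, s)) for a, b, s in zip(i, j, S)}


# both backends should describe the same set of neighbour pairs
assert as_pair_set(i_ase, j_ase, S_ase) == as_pair_set(i_v, j_v, S_v)
```

Because the two functions agree on the `"ijS"` interface, `from_ase` only needed the one-line swap plus the explicit `int64` casts shown above.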
29 changes: 28 additions & 1 deletion src/graph_pes/graph_pes_model.py
@@ -2,7 +2,7 @@

import warnings
from abc import ABC, abstractmethod
from typing import Any, Final, Sequence, final
from typing import TYPE_CHECKING, Any, Final, Sequence, final

import torch
from ase.data import chemical_symbols
@@ -23,6 +23,9 @@
from .utils.misc import differentiate, differentiate_all
from .utils.nn import PerElementParameter

if TYPE_CHECKING:
from graph_pes.utils.calculator import GraphPESCalculator


class GraphPESModel(nn.Module, ABC):
r"""
@@ -569,3 +572,27 @@ def extra_state(self, state: Any) -> None:
:meth:`~graph_pes.GraphPESModel.extra_state` property.
"""
pass

@torch.jit.unused
def ase_calculator(
self, device: torch.device | str | None = None, skin: float = 1.0
) -> "GraphPESCalculator":
"""
Return an ASE calculator wrapping this model. See
:class:`~graph_pes.utils.calculator.GraphPESCalculator` for more
information.

Parameters
----------
device
The device to use for the calculator. If ``None``, the device of the
model will be used.
skin
The skin to use for the neighbour list. If all atoms have moved less
than half of this distance between calls to `calculate`, the
neighbour list will be reused, saving (in some cases) significant
computation time.
"""
from graph_pes.utils.calculator import GraphPESCalculator

return GraphPESCalculator(self, device=device, skin=skin)
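The new `ase_calculator` method is a thin convenience wrapper around `GraphPESCalculator`. A hedged usage sketch, assuming `LennardJones` is importable from `graph_pes.models`:

```python
from ase.build import molecule
from graph_pes.models import LennardJones  # any GraphPESModel would do here

model = LennardJones()

# equivalent to GraphPESCalculator(model, device="cpu", skin=1.0)
calc = model.ase_calculator(device="cpu", skin=1.0)

atoms = molecule("H2O")
atoms.calc = calc
energy = atoms.get_potential_energy()
forces = atoms.get_forces()  # shape (3, 3) for water
```

Because the returned calculator over-builds the neighbour list by `skin`, small position updates between successive calls can reuse the cached graph.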
43 changes: 28 additions & 15 deletions src/graph_pes/interfaces/_mace.py
@@ -132,8 +132,7 @@ def predict(
return {k: v for k, v in predictions.items() if k in properties}


def _fix_precision(model: torch.nn.Module, precision: str) -> None:
dtype = {"float32": torch.float32, "float64": torch.float64}[precision]
def _fix_dtype(model: torch.nn.Module, dtype: torch.dtype) -> None:
for tensor in chain(
model.parameters(),
model.buffers(),
@@ -142,9 +141,17 @@ def _fix_precision(model: torch.nn.Module, precision: str) -> None:
tensor.data = tensor.data.to(dtype)


def _get_dtype(
precision: Literal["float32", "float64"] | None,
) -> torch.dtype:
if precision is None:
return torch.get_default_dtype()
return {"float32": torch.float32, "float64": torch.float64}[precision]


def mace_mp(
model: Literal["small", "medium", "large"],
precision: Literal["float32", "float64"] = "float64",
precision: Literal["float32", "float64"] | None = None,
) -> MACEWrapper:
"""
Download a MACE-MP model and convert it for use with ``graph-pes``.
@@ -167,24 +174,29 @@
model
The size of the MACE-MP model to download.
precision
The precision of the model.
The precision of the model. If ``None``, the default precision
of torch will be used (you can set this when using ``graph-pes-train``
via ``general/torch/dtype``)
""" # noqa: E501
from mace.calculators.foundations_models import mace_mp

dtype = _get_dtype(precision)
precision_str = {torch.float32: "float32", torch.float64: "float64"}[dtype]

mace_torch_model = mace_mp(
model,
device="cpu",
default_dtype=precision,
default_dtype=precision_str,
return_raw_model=True,
)
assert isinstance(mace_torch_model, torch.nn.Module)
_fix_precision(mace_torch_model, precision)
_fix_dtype(mace_torch_model, dtype)
return MACEWrapper(mace_torch_model)


def mace_off(
model: Literal["small", "medium", "large"],
precision: Literal["float32", "float64"] = "float64",
precision: Literal["float32", "float64"] | None = None,
) -> MACEWrapper:
"""
Download a MACE-OFF model and convert it for use with ``graph-pes``.
@@ -200,19 +212,22 @@
""" # noqa: E501
from mace.calculators.foundations_models import mace_off

dtype = _get_dtype(precision)
precision_str = {torch.float32: "float32", torch.float64: "float64"}[dtype]

mace_torch_model = mace_off(
model,
device="cpu",
default_dtype=precision,
default_dtype=precision_str,
return_raw_model=True,
)
assert isinstance(mace_torch_model, torch.nn.Module)
_fix_precision(mace_torch_model, precision)
_fix_dtype(mace_torch_model, dtype)
return MACEWrapper(mace_torch_model)


def go_mace_23(
precision: Literal["float32", "float64"] = "float32",
precision: Literal["float32", "float64"] | None = None,
) -> MACEWrapper:
"""
Download the `GO-MACE-23 model <https://doi.org/10.1002/anie.202410088>`__
@@ -241,6 +256,8 @@ def go_mace_23(

""" # noqa: E501

dtype = _get_dtype(precision)

url = "https://github.com/zakmachachi/GO-MACE-23/raw/refs/heads/main/models/fitting/potential/iter-12-final-model/go-mace-23.pt"
save_path = Path.home() / ".graph-pes" / "go-mace-23.pt"
save_path.parent.mkdir(parents=True, exist_ok=True)
@@ -258,11 +275,7 @@
save_path, weights_only=False, map_location=torch.device("cpu")
)
for p in mace_torch_model.parameters():
p.data = p.data.to(
dtype={"float32": torch.float32, "float64": torch.float64}[
precision
]
)
p.data = p.data.to(dtype)
model = MACEWrapper(mace_torch_model)

return model
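All three `mace_*` loaders now resolve their dtype through `_get_dtype`, so `precision=None` defers to torch's global default rather than silently forcing `float64` (or `float32` for GO-MACE-23). A standalone sketch of that resolution logic, with the helper re-implemented locally so no model download is needed:

```python
import torch


def resolve_dtype(precision=None):
    # mirrors _get_dtype above: None means "follow torch's default dtype"
    if precision is None:
        return torch.get_default_dtype()
    return {"float32": torch.float32, "float64": torch.float64}[precision]


torch.set_default_dtype(torch.float32)
assert resolve_dtype(None) is torch.float32        # follows the global default
assert resolve_dtype("float64") is torch.float64   # an explicit choice still wins

# graph-pes-train users can therefore control the interface dtype through the
# general/torch/dtype option instead of passing `precision` explicitly
```

This matches the updated docstrings above: an explicit `precision` always wins, while `None` follows whatever dtype the rest of the run was configured with.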
28 changes: 26 additions & 2 deletions src/graph_pes/utils/calculator.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import time
import warnings
from typing import Iterable, TypeVar, overload

@@ -11,7 +12,7 @@

from graph_pes.atomic_graph import AtomicGraph, PropertyKey, has_cell, to_batch
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.utils.misc import groups_of, pairs
from graph_pes.utils.misc import groups_of, pairs, uniform_repr


class GraphPESCalculator(Calculator):
@@ -67,9 +68,10 @@
self._cached_cell: numpy.ndarray | None = None
self.skin = skin

# cache stats
# stats
self.cache_hits = 0
self.total_calls = 0
self.nl_timings = []

def calculate(
self,
@@ -125,9 +127,12 @@

# cache miss
if graph is None:
tick = time.perf_counter()
graph = AtomicGraph.from_ase(
self.atoms, self.model.cutoff.item() + self.skin
).to(self.model.device)
tock = time.perf_counter()
self.nl_timings.append(tock - tick)
self._cached_graph = graph
self._cached_R = graph.R.detach().cpu().numpy()
self._cached_cell = graph.cell.detach().cpu().numpy()
@@ -161,10 +166,21 @@
return 0.0
return self.cache_hits / self.total_calls

@property
def average_nl_timing(self) -> float:
"""The average time taken to calculate the neighbour list in seconds."""
return numpy.mean(self.nl_timings).item()

(Codecov annotation: added line #L172 in src/graph_pes/utils/calculator.py was not covered by tests)

@property
def total_nl_timing(self) -> float:
"""The total time taken to calculate the neighbour list in seconds."""
return sum(self.nl_timings)

(Codecov annotation: added line #L177 in src/graph_pes/utils/calculator.py was not covered by tests)

def reset_cache_stats(self):
"""Reset the :attr:`cache_hit_rate` statistic."""
self.cache_hits = 0
self.total_calls = 0
self.nl_timings = []

(Codecov annotation: added line #L183 in src/graph_pes/utils/calculator.py was not covered by tests)

def calculate_all(
self,
@@ -224,6 +240,14 @@

return results

def __repr__(self):
return uniform_repr(
self.__class__.__name__,
model=self.model,
device=self.model.device,
skin=self.skin,
)


## utils ##

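Besides the new `__repr__`, the calculator now times every neighbour-list rebuild, which makes it easy to check whether the skin-based caching is paying off. A hedged usage sketch (again with `LennardJones` as a stand-in and an assumed import path; whether the second call actually hits the cache depends on the calculator's internal position check):

```python
from ase.build import molecule
from graph_pes.models import LennardJones  # stand-in for any GraphPESModel

calc = LennardJones().ase_calculator(skin=1.0)

atoms = molecule("CH4")
atoms.calc = calc

atoms.get_potential_energy()  # cache miss: the neighbour list is built here
atoms.positions += 0.01       # move everything well within half the skin
atoms.get_potential_energy()  # cache hit: the cached graph can be reused

print(calc)                   # uniform_repr summary: model, device, skin
print(f"hit rate : {calc.cache_hit_rate:.0%}")
print(f"NL time  : {calc.total_nl_timing:.4f} s")
calc.reset_cache_stats()
```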
2 changes: 1 addition & 1 deletion tests/utils/test_calculator.py
@@ -22,7 +22,7 @@ def test_calc():


def test_calc_all():
calc = GraphPESCalculator(LennardJones())
calc = LennardJones().ase_calculator()
molecules = [molecule(s) for s in "CH4 H2O CH3CH2OH C2H6".split()]

# add cell info so we can test stresses