Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clebsch gordan submodule - implementation of TorchScript interface #269

Merged
merged 24 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3f8369f
initalize clebsch_gordan submodule in rascaline.torch
agoscinski Dec 29, 2023
e67d917
checkpoint all-deps and all-deps-torch tests passing
agoscinski Feb 13, 2024
5caaec8
all-deps all-deps-torch pass
agoscinski Feb 14, 2024
155f36c
change ClebschGordanReal to TensorMap
agoscinski Feb 14, 2024
3c3961b
adding for torch backend
agoscinski Feb 15, 2024
358d170
fixing TorchScript
agoscinski Feb 15, 2024
f94b8c4
fix dispatch and refactor tests
agoscinski Feb 15, 2024
2d621b6
remove _dispatch.max_axis not needed
agoscinski Feb 15, 2024
25f7ad3
add tests for properties of DensityCorrelations
agoscinski Feb 15, 2024
39aee61
simplify _parse_selected_keys, now it does not need to be scritable
agoscinski Feb 15, 2024
f6d88ef
Make CG cache contiguous, fix some
jwa7 Feb 15, 2024
3513558
Remove comment block
jwa7 Feb 15, 2024
1504c19
Update python/rascaline/rascaline/utils/clebsch_gordan/_clebsch_gorda…
agoscinski Feb 16, 2024
d024c7f
Update python/rascaline/rascaline/utils/clebsch_gordan/correlate_dens…
agoscinski Feb 16, 2024
5a69107
Test save/load for checking contiguous. Clean up. Docstring arg.
jwa7 Feb 16, 2024
3fb4ef7
Merge branch 'master' into cg-torchscript
jwa7 Feb 16, 2024
a45b73c
Partial resolution of review comments
jwa7 Feb 17, 2024
7c14c22
Get rid of __all__
jwa7 Feb 19, 2024
be331f1
Fix Python import
Luthaf Feb 19, 2024
993f005
Add CG to API docs
Luthaf Feb 19, 2024
f4337f4
linter noqa
jwa7 Feb 19, 2024
96cee2a
Make the CG cache a function not a class. Fix the mops CG cache and i…
jwa7 Feb 19, 2024
7c77adc
Review round 2
jwa7 Feb 21, 2024
44ded8c
Final review comment
jwa7 Feb 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/rascaline-torch/rascaline/torch/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

from . import clebsch_gordan
from .power_spectrum import PowerSpectrum


Expand All @@ -10,4 +11,4 @@
Path containing the CMake configuration files for the underlying C library
"""

__all__ = ["PowerSpectrum"]
__all__ = ["PowerSpectrum", "clebsch_gordan"]
jwa7 marked this conversation as resolved.
Show resolved Hide resolved
73 changes: 73 additions & 0 deletions python/rascaline-torch/rascaline/torch/utils/clebsch_gordan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import importlib
import os
import sys
from typing import Any

import torch
from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap

import rascaline.utils.clebsch_gordan


# For details what is happening here take a look an `rascaline.torch.calculators`.

# Step 1: create te `_classes` module as an empty module
jwa7 marked this conversation as resolved.
Show resolved Hide resolved
spec = importlib.util.spec_from_loader(
"rascaline.torch.utils.clebsch_gordan._classes",
jwa7 marked this conversation as resolved.
Show resolved Hide resolved
loader=None,
)
module = importlib.util.module_from_spec(spec)
# This module only exposes a handful of things, defined here. Any changes here MUST also
# be made to the `metatensor/operations/_classes.py` file, which is used in non
# TorchScript mode.
module.__dict__["Labels"] = Labels
module.__dict__["TensorBlock"] = TensorBlock
module.__dict__["TensorMap"] = TensorMap
module.__dict__["LabelsEntry"] = LabelsEntry
module.__dict__["torch_jit_is_scripting"] = torch.jit.is_scripting
module.__dict__["torch_jit_annotate"] = torch.jit.annotate
module.__dict__["torch_jit_export"] = torch.jit.export
module.__dict__["TorchTensor"] = torch.Tensor
module.__dict__["TorchModule"] = torch.nn.Module
module.__dict__["TorchScriptClass"] = torch.ScriptClass
module.__dict__["Array"] = torch.Tensor


def is_labels(obj: Any):
return isinstance(obj, Labels)


if os.environ.get("RASCALINE_IMPORT_FOR_SPHINX") is None:
is_labels = torch.jit.script(is_labels)

module.__dict__["is_labels"] = is_labels


def check_isinstance(obj, ty):
if isinstance(ty, torch.ScriptClass):
# This branch is taken when `ty` is a custom class (TensorMap, …). since `ty` is
# an instance of `torch.ScriptClass` and not a class itself, there is no way to
# check if obj is an "instance" of this class, so we always return True and hope
# for the best. Most errors should be caught by the TorchScript compiler anyway.
return True
else:
assert isinstance(ty, type)
return isinstance(obj, ty)


module.__dict__["check_isinstance"] = check_isinstance

# register the module in sys.modules, so future import find it directly
sys.modules[spec.name] = module


# Step 2: create a module named `rascaline.torch.utils.clebsch_gordan` using code from
# `rascaline.utils.clebsch_gordan`
spec = importlib.util.spec_from_file_location(
"rascaline.torch.utils.clebsch_gordan",
rascaline.utils.clebsch_gordan.__file__,
)

module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
109 changes: 109 additions & 0 deletions python/rascaline-torch/tests/utils/correlate_density.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
import io
import os
from typing import Any, List

import ase.io
import metatensor.torch
import pytest
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap # noqa

import rascaline.torch
from rascaline.torch.utils.clebsch_gordan.correlate_density import DensityCorrelations


DATA_ROOT = os.path.join(os.path.dirname(__file__), "data")


def is_tensor_map(obj: Any):
return isinstance(obj, TensorMap)


is_tensor_map = torch.jit.script(is_tensor_map)
jwa7 marked this conversation as resolved.
Show resolved Hide resolved

SPHEX_HYPERS = {
"cutoff": 2.5, # Angstrom
"max_radial": 3, # Exclusive
"max_angular": 3, # Inclusive
jwa7 marked this conversation as resolved.
Show resolved Hide resolved
"atomic_gaussian_width": 0.2,
"radial_basis": {"Gto": {}},
"cutoff_function": {"ShiftedCosine": {"width": 0.5}},
"center_atom_weight": 1.0,
}


def h2o_isolated():
return ase.io.read(os.path.join(DATA_ROOT, "h2o_isolated.xyz"), ":")


def spherical_expansion(frames: List[ase.Atoms]):
"""Returns a rascaline SphericalExpansion"""
calculator = rascaline.torch.SphericalExpansion(**SPHEX_HYPERS)
return calculator.compute(rascaline.torch.systems_to_torch(frames))


# copy of def test_correlate_density_angular_selection(
@pytest.mark.parametrize(
"selected_keys",
[
None,
Labels(
jwa7 marked this conversation as resolved.
Show resolved Hide resolved
names=["spherical_harmonics_l"], values=torch.tensor([1, 3]).reshape(-1, 1)
),
],
)
@pytest.mark.parametrize("skip_redundant", [True, False])
def test_torch_script_correlate_density_angular_selection(
selected_keys: Labels,
skip_redundant: bool,
):
"""
Tests that the correct angular channels are output based on the specified
``selected_keys``.
"""
frames = h2o_isolated()
nu_1 = spherical_expansion(frames)
correlation_order = 2
corr_calculator = DensityCorrelations(
max_angular=SPHEX_HYPERS["max_angular"] * correlation_order,
correlation_order=correlation_order,
angular_cutoff=None,
selected_keys=selected_keys,
skip_redundant=skip_redundant,
)

ref_nu_2 = corr_calculator.compute(nu_1)
scripted_corr_calculator = torch.jit.script(corr_calculator)

# Test compute
scripted_nu_2 = scripted_corr_calculator.compute(nu_1)
assert metatensor.torch.equal_metadata(scripted_nu_2, ref_nu_2)
assert metatensor.torch.allclose(scripted_nu_2, ref_nu_2)

# Test compute_metadata
scripted_nu_2 = scripted_corr_calculator.compute_metadata(nu_1)
assert metatensor.torch.equal_metadata(scripted_nu_2, ref_nu_2)

# Test if properties are accesible
assert isinstance(corr_calculator.correlation_order, int)
assert isinstance(corr_calculator.selected_keys, list)
assert isinstance(corr_calculator.skip_redundant, list)
assert isinstance(corr_calculator.output_selection, list)
assert isinstance(corr_calculator.arrays_backend, str)
assert isinstance(corr_calculator.cg_backend, str)
assert is_tensor_map(corr_calculator.cg_coeffs)


def test_save_load():
corr_calculator = DensityCorrelations(
max_angular=2,
correlation_order=2,
angular_cutoff=1,
)
scripted_correlate_density = torch.jit.script(corr_calculator)
buffer = io.BytesIO()
torch.jit.save(scripted_correlate_density, buffer)
buffer.seek(0)
torch.jit.load(buffer)
buffer.close()
jwa7 marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 5 additions & 0 deletions python/rascaline-torch/tests/utils/data/h2o_isolated.xyz
jwa7 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
3
jwa7 marked this conversation as resolved.
Show resolved Hide resolved
pbc="F F F"
O 2.56633400 2.50000000 2.50370100
H 1.97361700 1.73067300 2.47063400
H 1.97361700 3.26932700 2.47063400
2 changes: 1 addition & 1 deletion python/rascaline/rascaline/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from .clebsch_gordan import correlate_density, correlate_density_metadata # noqa
from .clebsch_gordan import DensityCorrelations # noqa
from .power_spectrum import PowerSpectrum # noqa
from .splines import ( # noqa
AtomicDensityBase,
Expand Down
5 changes: 2 additions & 3 deletions python/rascaline/rascaline/utils/clebsch_gordan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .correlate_density import correlate_density, correlate_density_metadata # noqa
from .correlate_density import DensityCorrelations # noqa

jwa7 marked this conversation as resolved.
Show resolved Hide resolved

__all__ = [
"correlate_density",
"correlate_density_metadata",
"DensityCorrelations",
]
Loading
Loading