From 85bad7d7d6a5fce57167e59cacff23f5ad90d1b2 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Wed, 10 Jul 2024 13:39:22 +0200 Subject: [PATCH] finish tests --- README.rst | 5 +- docs/src/conf.py | 2 +- examples/neighborlist_example.py | 4 +- src/meshlode/lib/potentials.py | 3 +- src/meshlode/metatensor/directpotential.py | 6 +- src/meshlode/metatensor/ewaldpotential.py | 6 +- src/meshlode/metatensor/pmepotential.py | 6 +- tests/__init__.py | 5 - tests/metatensor/test_base_metatensor.py | 193 +++++++++++++++++++ tests/metatensor/test_workflow_metatensor.py | 99 ++++++++++ 10 files changed, 307 insertions(+), 22 deletions(-) delete mode 100644 tests/__init__.py create mode 100644 tests/metatensor/test_base_metatensor.py create mode 100644 tests/metatensor/test_workflow_metatensor.py diff --git a/README.rst b/README.rst index 3ad71d1e..035876dc 100644 --- a/README.rst +++ b/README.rst @@ -22,9 +22,8 @@ You can install *MeshLode* using pip with You can then ``import meshlode`` and use it in your projects! -We also provide bindings to `metatensor -`_ which can optionally be installed -together and used as ``meshlode.metatensor`` via +We also provide bindings to `metatensor `_ which can +optionally be installed together and used as ``meshlode.metatensor`` via .. code-block:: bash diff --git a/docs/src/conf.py b/docs/src/conf.py index 2e7071cf..1ebbdbcc 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -59,7 +59,7 @@ "python": ("https://docs.python.org/3", None), "numpy": ("https://numpy.org/doc/stable/", None), "torch": ("https://pytorch.org/docs/stable/", None), - "metatensor": ("https://lab-cosmo.github.io/metatensor/latest/", None), + "metatensor": ("https://docs.metatensor.org/latest/", None), } # -- Options for HTML output ------------------------------------------------- diff --git a/examples/neighborlist_example.py b/examples/neighborlist_example.py index e4992f01..5ca0647a 100644 --- a/examples/neighborlist_example.py +++ b/examples/neighborlist_example.py @@ -86,9 +86,9 @@ data = TensorBlock( values=charges, - samples=Labels.range("atom", len(system)), + samples=Labels.range("atom", charges.shape[0]), components=[], - properties=Labels("charge", torch.tensor([[0]])), + properties=Labels.range("charge", charges.shape[1]), ) system.add_data(name="charges", data=data) diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py index 63be3e05..c9bbcb15 100644 --- a/src/meshlode/lib/potentials.py +++ b/src/meshlode/lib/potentials.py @@ -1,6 +1,5 @@ -from typing import Union - import math +from typing import Union import torch from torch.special import gammainc, gammaincc, gammaln diff --git a/src/meshlode/metatensor/directpotential.py b/src/meshlode/metatensor/directpotential.py index 7d989062..232e4c13 100644 --- a/src/meshlode/metatensor/directpotential.py +++ b/src/meshlode/metatensor/directpotential.py @@ -3,7 +3,7 @@ class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl): - """Specie-wise long-range potential using a direct summation over all atoms. + r"""Specie-wise long-range potential using a direct summation over all atoms. Refer to :class:`meshlode.DirectPotential` for parameter documentation. @@ -28,9 +28,9 @@ class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl): >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> data = TensorBlock( ... values=charges, - ... samples=Labels.range("atom", len(system)), + ... samples=Labels.range("atom", charges.shape[0]), ... components=[], - ... properties=Labels("charge", torch.tensor([[0]])), + ... properties=Labels.range("charge", charges.shape[1]), ... ) >>> system.add_data(name="charges", data=data) diff --git a/src/meshlode/metatensor/ewaldpotential.py b/src/meshlode/metatensor/ewaldpotential.py index a831c954..0eac3b6b 100644 --- a/src/meshlode/metatensor/ewaldpotential.py +++ b/src/meshlode/metatensor/ewaldpotential.py @@ -7,7 +7,7 @@ class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl): - """Specie-wise long-range potential computed using the Ewald sum. + r"""Specie-wise long-range potential computed using the Ewald sum. Refer to :class:`meshlode.EwaldPotential` for parameter documentation. @@ -33,9 +33,9 @@ class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl): >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> data = TensorBlock( ... values=charges, - ... samples=Labels.range("atom", len(system)), + ... samples=Labels.range("atom", charges.shape[0]), ... components=[], - ... properties=Labels("charge", torch.tensor([[0]])), + ... properties=Labels.range("charge", charges.shape[1]), ... ) >>> system.add_data(name="charges", data=data) diff --git a/src/meshlode/metatensor/pmepotential.py b/src/meshlode/metatensor/pmepotential.py index f4a25b4d..d2e14983 100644 --- a/src/meshlode/metatensor/pmepotential.py +++ b/src/meshlode/metatensor/pmepotential.py @@ -7,7 +7,7 @@ class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): - """Specie-wise long-range potential using a particle mesh-based Ewald (PME). + r"""Specie-wise long-range potential using a particle mesh-based Ewald (PME). Refer to :class:`meshlode.PMEPotential` for parameter documentation. @@ -33,9 +33,9 @@ class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> data = TensorBlock( ... values=charges, - ... samples=Labels.range("atom", len(system)), + ... samples=Labels.range("atom", charges.shape[0]), ... components=[], - ... properties=Labels("charge", torch.tensor([[0]])), + ... properties=Labels.range("charge", charges.shape[1]), ... ) >>> system.add_data(name="charges", data=data) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 1c2cd789..00000000 --- a/tests/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -import meshlode - - -def test_version_exist(): - meshlode.__version__ diff --git a/tests/metatensor/test_base_metatensor.py b/tests/metatensor/test_base_metatensor.py new file mode 100644 index 00000000..768a8588 --- /dev/null +++ b/tests/metatensor/test_base_metatensor.py @@ -0,0 +1,193 @@ +import pytest +import torch +from metatensor.torch import Labels, TensorBlock +from metatensor.torch.atomistic import System +from packaging import version + +from meshlode.metatensor.base import CalculatorBaseMetatensor + + +class CalculatorTest(CalculatorBaseMetatensor): + def _compute_single_system( + self, positions, charges, cell, neighbor_indices, neighbor_shifts + ): + return charges + + +@pytest.mark.parametrize("method_name", ["compute", "forward"]) +def test_compute_output_shapes_single(method_name): + system = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + data = TensorBlock( + values=charges, + samples=Labels.range("atom", charges.shape[0]), + components=[], + properties=Labels.range("charge", charges.shape[1]), + ) + + system.add_data(name="charges", data=data) + + calculator = CalculatorTest(exponent=1.0) + method = getattr(calculator, method_name) + result = method(system) + + assert isinstance(result, torch.ScriptObject) + if version.parse(torch.__version__) >= version.parse("2.1"): + assert result._type().name() == "TensorMap" + + assert len(result) == 1 + assert result[0].samples.names == ["system", "atom"] + assert result[0].components == [] + assert result[0].properties.names == ["charges_channel"] + + assert tuple(result[0].values.shape) == (len(system), 1) + + +def test_compute_output_shapes_multiple(): + + system = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + data = TensorBlock( + values=charges, + samples=Labels.range("atom", charges.shape[0]), + components=[], + properties=Labels.range("charge", charges.shape[1]), + ) + + system.add_data(name="charges", data=data) + + calculator = CalculatorTest(exponent=1.0) + result = calculator.compute([system, system]) + + assert isinstance(result, torch.ScriptObject) + if version.parse(torch.__version__) >= version.parse("2.1"): + assert result._type().name() == "TensorMap" + + assert len(result) == 1 + assert result[0].samples.names == ["system", "atom"] + assert result[0].components == [] + assert result[0].properties.names == ["charges_channel"] + + assert tuple(result[0].values.shape) == (2 * len(system), 1) + + +def test_wrong_system_dtype(): + system1 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + system2 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], dtype=torch.float64), + cell=torch.zeros([3, 3], dtype=torch.float64), + ) + + calculator = CalculatorTest(exponent=1.0) + + match = r"`dtype` of all systems must be the same, got 7 and 6" + with pytest.raises(ValueError, match=match): + calculator.compute([system1, system2]) + + +def test_wrong_system_device(): + system1 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + system2 = System( + types=torch.tensor([1, 1], device="meta"), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], device="meta"), + cell=torch.zeros([3, 3], device="meta"), + ) + + calculator = CalculatorTest(exponent=1.0) + + match = r"`device` of all systems must be the same, got meta and cpu" + with pytest.raises(ValueError, match=match): + calculator.compute([system1, system2]) + + +def test_wrong_system_not_all_charges(): + system1 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + data = TensorBlock( + values=charges, + samples=Labels.range("atom", charges.shape[0]), + components=[], + properties=Labels.range("charge", charges.shape[1]), + ) + + system1.add_data(name="charges", data=data) + + system2 = System( + types=torch.tensor([1, 1],), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + calculator = CalculatorTest(exponent=1.0) + + match = r"`systems` do not consistently contain `charges` data" + with pytest.raises(ValueError, match=match): + calculator.compute([system1, system2]) + + +def test_different_number_charge_channles(): + system1 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges1 = torch.tensor([1.0, -1.0]).reshape(-1, 1) + data1 = TensorBlock( + values=charges1, + samples=Labels.range("atom", charges1.shape[0]), + components=[], + properties=Labels.range("charge", charges1.shape[1]), + ) + + system1.add_data(name="charges", data=data1) + + system2 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges2 = torch.tensor([[1.0, 2.0], [-1.0, -2.0]]) + data2 = TensorBlock( + values=charges2, + samples=Labels.range("atom", charges2.shape[0]), + components=[], + properties=Labels.range("charge", charges2.shape[1]), + ) + system2.add_data(name="charges", data=data2) + + calculator = CalculatorTest(exponent=1.0) + + match = ( + r"number of charges-channels in system index 1 \(2\) is inconsistent with " + r"first system \(1\)" + ) + with pytest.raises(ValueError, match=match): + calculator.compute([system1, system2]) diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py new file mode 100644 index 00000000..09d95fa2 --- /dev/null +++ b/tests/metatensor/test_workflow_metatensor.py @@ -0,0 +1,99 @@ +""" +Madelung tests +""" + +import pytest +import torch +from packaging import version + + +meshlode_metatensor = pytest.importorskip("meshlode.metatensor") +mts_torch = pytest.importorskip("metatensor.torch") +mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") + + +ATOMIC_SMEARING = 0.1 +LR_WAVELENGTH = ATOMIC_SMEARING / 4 +MESH_SPACING = ATOMIC_SMEARING / 4 +INTERPOLATION_ORDER = 2 +SUBTRACT_SELF = True + + +@pytest.mark.parametrize( + "CalculatorClass, params", + [ + (meshlode_metatensor.DirectPotential, {}), + ( + meshlode_metatensor.EwaldPotential, + { + "atomic_smearing": ATOMIC_SMEARING, + "lr_wavelength": LR_WAVELENGTH, + "subtract_self": SUBTRACT_SELF, + }, + ), + ( + meshlode_metatensor.PMEPotential, + { + "atomic_smearing": ATOMIC_SMEARING, + "mesh_spacing": MESH_SPACING, + "interpolation_order": INTERPOLATION_ORDER, + "subtract_self": SUBTRACT_SELF, + }, + ), + ], +) +class TestWorkflow: + def cscl_system(self): + """CsCl crystal. Same as in the madelung test""" + + system = mts_atomistic.System( + types=torch.tensor([17, 55]), + positions=torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]), + cell=torch.eye(3), + ) + + data = mts_torch.TensorBlock( + values=torch.tensor([-1.0, 1.0]).reshape(-1, 1), + samples=mts_torch.Labels.range("atom", len(system)), + components=[], + properties=mts_torch.Labels("charge", torch.tensor([[0]])), + ) + system.add_data(name="charges", data=data) + + return system + + def calculator(self, CalculatorClass, params): + return CalculatorClass(**params) + + def test_forward(self, CalculatorClass, params): + calculator = self.calculator(CalculatorClass, params) + descriptor_compute = calculator.compute(self.cscl_system()) + descriptor_forward = calculator.forward(self.cscl_system()) + + assert isinstance(descriptor_compute, torch.ScriptObject) + assert isinstance(descriptor_forward, torch.ScriptObject) + if version.parse(torch.__version__) >= version.parse("2.1"): + assert descriptor_compute._type().name() == "TensorMap" + assert descriptor_forward._type().name() == "TensorMap" + + assert mts_torch.equal(descriptor_forward, descriptor_compute) + + # Make sure that the calculators are computing the features without raising errors, + # and returns the correct output format (TensorMap) + def check_operation(self, CalculatorClass, params): + calculator = self.calculator(CalculatorClass, params) + descriptor = calculator.compute(self.cscl_system()) + + assert isinstance(descriptor, torch.ScriptObject) + if version.parse(torch.__version__) >= version.parse("2.1"): + assert descriptor._type().name() == "TensorMap" + + # Run the above test as a normal python script + def test_operation_as_python(self, CalculatorClass, params): + self.check_operation(CalculatorClass, params) + + # Similar to the above, but also testing that the code can be compiled as a torch + # script + # def test_operation_as_torch_script(self, CalculatorClass, params): + # scripted = torch.jit.script(CalculatorClass, params) + # self.check_operation(scripted)