Skip to content

Commit

Permalink
finish tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jul 10, 2024
1 parent ef5be85 commit 85bad7d
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 22 deletions.
5 changes: 2 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ You can install *MeshLode* using pip with
You can then ``import meshlode`` and use it in your projects!

We also provide bindings to `metatensor
<https://lab-cosmo.github.io/metatensor/latest/>`_ which can optionally be installed
together and used as ``meshlode.metatensor`` via
We also provide bindings to `metatensor <https://docs.metatensor.org/latest/>`_ which can
optionally be installed together and used as ``meshlode.metatensor`` via

.. code-block:: bash
Expand Down
2 changes: 1 addition & 1 deletion docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"metatensor": ("https://lab-cosmo.github.io/metatensor/latest/", None),
"metatensor": ("https://docs.metatensor.org/latest/", None),
}

# -- Options for HTML output -------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions examples/neighborlist_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@

data = TensorBlock(
values=charges,
samples=Labels.range("atom", len(system)),
samples=Labels.range("atom", charges.shape[0]),
components=[],
properties=Labels("charge", torch.tensor([[0]])),
properties=Labels.range("charge", charges.shape[1]),
)
system.add_data(name="charges", data=data)

Expand Down
3 changes: 1 addition & 2 deletions src/meshlode/lib/potentials.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Union

import math
from typing import Union

import torch
from torch.special import gammainc, gammaincc, gammaln
Expand Down
6 changes: 3 additions & 3 deletions src/meshlode/metatensor/directpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl):
"""Specie-wise long-range potential using a direct summation over all atoms.
r"""Specie-wise long-range potential using a direct summation over all atoms.
Refer to :class:`meshlode.DirectPotential` for parameter documentation.
Expand All @@ -28,9 +28,9 @@ class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl):
>>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1)
>>> data = TensorBlock(
... values=charges,
... samples=Labels.range("atom", len(system)),
... samples=Labels.range("atom", charges.shape[0]),
... components=[],
... properties=Labels("charge", torch.tensor([[0]])),
... properties=Labels.range("charge", charges.shape[1]),
... )
>>> system.add_data(name="charges", data=data)
Expand Down
6 changes: 3 additions & 3 deletions src/meshlode/metatensor/ewaldpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl):
"""Specie-wise long-range potential computed using the Ewald sum.
r"""Specie-wise long-range potential computed using the Ewald sum.
Refer to :class:`meshlode.EwaldPotential` for parameter documentation.
Expand All @@ -33,9 +33,9 @@ class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl):
>>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1)
>>> data = TensorBlock(
... values=charges,
... samples=Labels.range("atom", len(system)),
... samples=Labels.range("atom", charges.shape[0]),
... components=[],
... properties=Labels("charge", torch.tensor([[0]])),
... properties=Labels.range("charge", charges.shape[1]),
... )
>>> system.add_data(name="charges", data=data)
Expand Down
6 changes: 3 additions & 3 deletions src/meshlode/metatensor/pmepotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl):
"""Specie-wise long-range potential using a particle mesh-based Ewald (PME).
r"""Specie-wise long-range potential using a particle mesh-based Ewald (PME).
Refer to :class:`meshlode.PMEPotential` for parameter documentation.
Expand All @@ -33,9 +33,9 @@ class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl):
>>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1)
>>> data = TensorBlock(
... values=charges,
... samples=Labels.range("atom", len(system)),
... samples=Labels.range("atom", charges.shape[0]),
... components=[],
... properties=Labels("charge", torch.tensor([[0]])),
... properties=Labels.range("charge", charges.shape[1]),
... )
>>> system.add_data(name="charges", data=data)
Expand Down
5 changes: 0 additions & 5 deletions tests/__init__.py

This file was deleted.

193 changes: 193 additions & 0 deletions tests/metatensor/test_base_metatensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import pytest
import torch
from metatensor.torch import Labels, TensorBlock
from metatensor.torch.atomistic import System
from packaging import version

from meshlode.metatensor.base import CalculatorBaseMetatensor


class CalculatorTest(CalculatorBaseMetatensor):
def _compute_single_system(
self, positions, charges, cell, neighbor_indices, neighbor_shifts
):
return charges


@pytest.mark.parametrize("method_name", ["compute", "forward"])
def test_compute_output_shapes_single(method_name):
system = System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

charges = torch.tensor([1.0, -1.0]).reshape(-1, 1)
data = TensorBlock(
values=charges,
samples=Labels.range("atom", charges.shape[0]),
components=[],
properties=Labels.range("charge", charges.shape[1]),
)

system.add_data(name="charges", data=data)

calculator = CalculatorTest(exponent=1.0)
method = getattr(calculator, method_name)
result = method(system)

assert isinstance(result, torch.ScriptObject)
if version.parse(torch.__version__) >= version.parse("2.1"):
assert result._type().name() == "TensorMap"

assert len(result) == 1
assert result[0].samples.names == ["system", "atom"]
assert result[0].components == []
assert result[0].properties.names == ["charges_channel"]

assert tuple(result[0].values.shape) == (len(system), 1)


def test_compute_output_shapes_multiple():

system = System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

charges = torch.tensor([1.0, -1.0]).reshape(-1, 1)
data = TensorBlock(
values=charges,
samples=Labels.range("atom", charges.shape[0]),
components=[],
properties=Labels.range("charge", charges.shape[1]),
)

system.add_data(name="charges", data=data)

calculator = CalculatorTest(exponent=1.0)
result = calculator.compute([system, system])

assert isinstance(result, torch.ScriptObject)
if version.parse(torch.__version__) >= version.parse("2.1"):
assert result._type().name() == "TensorMap"

assert len(result) == 1
assert result[0].samples.names == ["system", "atom"]
assert result[0].components == []
assert result[0].properties.names == ["charges_channel"]

assert tuple(result[0].values.shape) == (2 * len(system), 1)


def test_wrong_system_dtype():
system1 = System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

system2 = System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], dtype=torch.float64),
cell=torch.zeros([3, 3], dtype=torch.float64),
)

calculator = CalculatorTest(exponent=1.0)

match = r"`dtype` of all systems must be the same, got 7 and 6"
with pytest.raises(ValueError, match=match):
calculator.compute([system1, system2])


def test_wrong_system_device():
system1 = System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

system2 = System(
types=torch.tensor([1, 1], device="meta"),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], device="meta"),
cell=torch.zeros([3, 3], device="meta"),
)

calculator = CalculatorTest(exponent=1.0)

match = r"`device` of all systems must be the same, got meta and cpu"
with pytest.raises(ValueError, match=match):
calculator.compute([system1, system2])


def test_wrong_system_not_all_charges():
system1 = System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

charges = torch.tensor([1.0, -1.0]).reshape(-1, 1)
data = TensorBlock(
values=charges,
samples=Labels.range("atom", charges.shape[0]),
components=[],
properties=Labels.range("charge", charges.shape[1]),
)

system1.add_data(name="charges", data=data)

system2 = System(
types=torch.tensor([1, 1],),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

calculator = CalculatorTest(exponent=1.0)

match = r"`systems` do not consistently contain `charges` data"
with pytest.raises(ValueError, match=match):
calculator.compute([system1, system2])


def test_different_number_charge_channles():
system1 = System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

charges1 = torch.tensor([1.0, -1.0]).reshape(-1, 1)
data1 = TensorBlock(
values=charges1,
samples=Labels.range("atom", charges1.shape[0]),
components=[],
properties=Labels.range("charge", charges1.shape[1]),
)

system1.add_data(name="charges", data=data1)

system2 = System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

charges2 = torch.tensor([[1.0, 2.0], [-1.0, -2.0]])
data2 = TensorBlock(
values=charges2,
samples=Labels.range("atom", charges2.shape[0]),
components=[],
properties=Labels.range("charge", charges2.shape[1]),
)
system2.add_data(name="charges", data=data2)

calculator = CalculatorTest(exponent=1.0)

match = (
r"number of charges-channels in system index 1 \(2\) is inconsistent with "
r"first system \(1\)"
)
with pytest.raises(ValueError, match=match):
calculator.compute([system1, system2])
Loading

0 comments on commit 85bad7d

Please sign in to comment.