Skip to content

Commit

Permalink
Refactor charge encoding in MeshPotential class
Browse files Browse the repository at this point in the history
  • Loading branch information
E-Rum committed Apr 9, 2024
1 parent 159c249 commit 2b35d0b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 26 deletions.
26 changes: 20 additions & 6 deletions src/meshlode/calculators/meshpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,13 @@ def compute(
charges = []
for types_single, positions_single in zip(types, positions):
# One-hot encoding of charge information
charges_single = torch.zeros(
(len(types_single), n_types),
dtype=positions_single.dtype,
device=positions_single.device,
charges_single = self._one_hot_charges(
types_single,
requested_types,
n_types,
positions_single.dtype,
positions_single.device,
)
for i_type, atomic_type in enumerate(requested_types):
charges_single[types_single == atomic_type, i_type] = 1.0
charges.append(charges_single)

# If charges are provided, we need to make sure that they are consistent with
Expand Down Expand Up @@ -352,3 +352,17 @@ def _compute_single_system(
interpolated_potential -= charges * self_contrib

return interpolated_potential

def _one_hot_charges(
self,
types: torch.Tensor,
requested_types: List[int],
n_types: int,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
one_hot_charges = torch.zeros((len(types), n_types), dtype=dtype, device=device)
for i_type, atomic_type in enumerate(requested_types):
one_hot_charges[types == atomic_type, i_type] = 1.0

return one_hot_charges
6 changes: 2 additions & 4 deletions src/meshlode/metatensor/meshpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,9 @@ def compute(
charges = system.get_data("charges").values
else:
# One-hot encoding of charge information
charges = torch.zeros(
(len(system), n_types), dtype=dtype, device=device
charges = self._one_hot_charges(
system.types, requested_types, n_types, dtype, device
)
for i_specie, atomic_type in enumerate(requested_types):
charges[system.types == atomic_type, i_specie] = 1.0

# Compute the potentials
potential = self._compute_single_system(
Expand Down
36 changes: 20 additions & 16 deletions tests/calculators/test_meshpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,9 @@ def cscl_system():


def cscl_system_with_charges():
"""CsCl crystal. Same as in the madelung test. Version with explicit charges"""
types = torch.tensor([55, 17])
positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])
cell = torch.eye(3)

return types, positions, cell, torch.tensor([[0.0, 1.0], [1.0, 0]])
"""CsCl crystal with charges."""
charges = torch.tensor([[0.0, 1.0], [1.0, 0]])
return cscl_system() + (charges,)


# Initialize the calculators. For now, only the MeshPotential is implemented.
Expand Down Expand Up @@ -102,6 +99,7 @@ def test_single_frame():
rtol=1e-5,
)


# Test with explicit charges
def test_single_frame_with_charges():
values = descriptor().compute(*cscl_system_with_charges())
Expand Down Expand Up @@ -155,30 +153,36 @@ def test_positions_error():
with pytest.raises(ValueError, match=match):
descriptor().compute(types=types, positions=positions, cell=cell)


def test_charges_error_dimension_mismatch():
types = torch.tensor([1, 2])
positions = torch.zeros((2, 3))
cell = torch.eye(3)
charges = torch.zeros((1, 2)) # This should have the same first dimension as types

match = (
"The first dimension of `charges` must be the same as the length of `types`, got 1 and 2."
"The first dimension of `charges` must be the same as the length "
"of `types`, got 1 and 2."
)

with pytest.raises(ValueError, match=match):
descriptor().compute(types=types, positions=positions, cell=cell, charges=charges)
descriptor().compute(
types=types, positions=positions, cell=cell, charges=charges
)


def test_charges_error_length_mismatch():
types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])]
positions = [torch.zeros((2, 3)), torch.zeros((3, 3))]
cell = torch.eye(3)
charges = [torch.zeros(2,1)] # This should have the same length as types
match = (
"The number of `types` and `charges` tensors must be the same, got 2 and 1."
)
charges = [torch.zeros(2, 1)] # This should have the same length as types
match = "The number of `types` and `charges` tensors must be the same, got 2 and 1."

with pytest.raises(ValueError, match=match):
descriptor().compute(types=types, positions=positions, cell=cell, charges=charges)
descriptor().compute(
types=types, positions=positions, cell=cell, charges=charges
)


def test_cell_error():
types = torch.tensor([1, 2, 3])
Expand Down Expand Up @@ -253,6 +257,7 @@ def test_inconsistent_device():
with pytest.raises(ValueError, match=match):
MP.compute(types=types, positions=positions, cell=cell)


def test_inconsistent_device_charges():
"""Test if the cell and positions have inconsistent device and error is raised."""
types = torch.tensor([1], device="cpu")
Expand All @@ -262,12 +267,11 @@ def test_inconsistent_device_charges():

MP = MeshPotential(atomic_smearing=0.2)

match = (
"`charges` must be on the same device as `positions`, got meta and cpu."
)
match = "`charges` must be on the same device as `positions`, got meta and cpu."
with pytest.raises(ValueError, match=match):
MP.compute(types=types, positions=positions, cell=cell, charges=charges)


def test_inconsistent_dtype_charges():
"""Test if the cell and positions have inconsistent dtype and error is raised."""
types = torch.tensor([1], dtype=torch.float32)
Expand Down

0 comments on commit 2b35d0b

Please sign in to comment.