From 2b35d0bcb3ccb9e7fa4ab223359e4259ccb7ecd2 Mon Sep 17 00:00:00 2001 From: E-Rum Date: Tue, 9 Apr 2024 12:13:21 +0000 Subject: [PATCH] Refactor charge encoding in MeshPotential class --- src/meshlode/calculators/meshpotential.py | 26 ++++++++++++---- src/meshlode/metatensor/meshpotential.py | 6 ++-- tests/calculators/test_meshpotential.py | 36 +++++++++++++---------- 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/src/meshlode/calculators/meshpotential.py b/src/meshlode/calculators/meshpotential.py index de3b8bc8..214f129c 100644 --- a/src/meshlode/calculators/meshpotential.py +++ b/src/meshlode/calculators/meshpotential.py @@ -203,13 +203,13 @@ def compute( charges = [] for types_single, positions_single in zip(types, positions): # One-hot encoding of charge information - charges_single = torch.zeros( - (len(types_single), n_types), - dtype=positions_single.dtype, - device=positions_single.device, + charges_single = self._one_hot_charges( + types_single, + requested_types, + n_types, + positions_single.dtype, + positions_single.device, ) - for i_type, atomic_type in enumerate(requested_types): - charges_single[types_single == atomic_type, i_type] = 1.0 charges.append(charges_single) # If charges are provided, we need to make sure that they are consistent with @@ -352,3 +352,17 @@ def _compute_single_system( interpolated_potential -= charges * self_contrib return interpolated_potential + + def _one_hot_charges( + self, + types: torch.Tensor, + requested_types: List[int], + n_types: int, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + one_hot_charges = torch.zeros((len(types), n_types), dtype=dtype, device=device) + for i_type, atomic_type in enumerate(requested_types): + one_hot_charges[types == atomic_type, i_type] = 1.0 + + return one_hot_charges diff --git a/src/meshlode/metatensor/meshpotential.py b/src/meshlode/metatensor/meshpotential.py index e020343a..3954c0ee 100644 --- a/src/meshlode/metatensor/meshpotential.py +++ b/src/meshlode/metatensor/meshpotential.py @@ -162,11 +162,9 @@ def compute( charges = system.get_data("charges").values else: # One-hot encoding of charge information - charges = torch.zeros( - (len(system), n_types), dtype=dtype, device=device + charges = self._one_hot_charges( + system.types, requested_types, n_types, dtype, device ) - for i_specie, atomic_type in enumerate(requested_types): - charges[system.types == atomic_type, i_specie] = 1.0 # Compute the potentials potential = self._compute_single_system( diff --git a/tests/calculators/test_meshpotential.py b/tests/calculators/test_meshpotential.py index 811c2f45..58e73808 100644 --- a/tests/calculators/test_meshpotential.py +++ b/tests/calculators/test_meshpotential.py @@ -25,12 +25,9 @@ def cscl_system(): def cscl_system_with_charges(): - """CsCl crystal. Same as in the madelung test. Version with explicit charges""" - types = torch.tensor([55, 17]) - positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - cell = torch.eye(3) - - return types, positions, cell, torch.tensor([[0.0, 1.0], [1.0, 0]]) + """CsCl crystal with charges.""" + charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) + return cscl_system() + (charges,) # Initialize the calculators. For now, only the MeshPotential is implemented. @@ -102,6 +99,7 @@ def test_single_frame(): rtol=1e-5, ) + # Test with explicit charges def test_single_frame_with_charges(): values = descriptor().compute(*cscl_system_with_charges()) @@ -155,6 +153,7 @@ def test_positions_error(): with pytest.raises(ValueError, match=match): descriptor().compute(types=types, positions=positions, cell=cell) + def test_charges_error_dimension_mismatch(): types = torch.tensor([1, 2]) positions = torch.zeros((2, 3)) @@ -162,23 +161,28 @@ def test_charges_error_dimension_mismatch(): charges = torch.zeros((1, 2)) # This should have the same first dimension as types match = ( - "The first dimension of `charges` must be the same as the length of `types`, got 1 and 2." + "The first dimension of `charges` must be the same as the length " + "of `types`, got 1 and 2." ) with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell, charges=charges) + descriptor().compute( + types=types, positions=positions, cell=cell, charges=charges + ) + def test_charges_error_length_mismatch(): types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] cell = torch.eye(3) - charges = [torch.zeros(2,1)] # This should have the same length as types - match = ( - "The number of `types` and `charges` tensors must be the same, got 2 and 1." - ) + charges = [torch.zeros(2, 1)] # This should have the same length as types + match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell, charges=charges) + descriptor().compute( + types=types, positions=positions, cell=cell, charges=charges + ) + def test_cell_error(): types = torch.tensor([1, 2, 3]) @@ -253,6 +257,7 @@ def test_inconsistent_device(): with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell) + def test_inconsistent_device_charges(): """Test if the cell and positions have inconsistent device and error is raised.""" types = torch.tensor([1], device="cpu") @@ -262,12 +267,11 @@ def test_inconsistent_device_charges(): MP = MeshPotential(atomic_smearing=0.2) - match = ( - "`charges` must be on the same device as `positions`, got meta and cpu." - ) + match = "`charges` must be on the same device as `positions`, got meta and cpu." with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell, charges=charges) + def test_inconsistent_dtype_charges(): """Test if the cell and positions have inconsistent dtype and error is raised.""" types = torch.tensor([1], dtype=torch.float32)