From d3f83d31a875c3645b722c493b1c98fe9df7c863 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Tue, 11 Jun 2024 13:28:29 +0200 Subject: [PATCH 01/35] Add class to handle the bare potentials --- src/meshlode/lib/potentials.py | 149 +++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 src/meshlode/lib/potentials.py diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py new file mode 100644 index 00000000..41e989af --- /dev/null +++ b/src/meshlode/lib/potentials.py @@ -0,0 +1,149 @@ +import torch +from torch.special import gammainc, gammaincc, gammaln +import math + +# since pytorch has implemented the incomplete Gamma functions, but not the much more +# commonly used (complete) Gamma function, we define it in a custom way to make autograd +# work as in https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122 +def gamma(x): + return torch.exp(gammaln(x)) + +class InversePowerLawPotential: + """ + Class to handle inverse power-law potentials of the form 1/r^p, where r is a + distance parameter and p an exponent. + + It can be used to compute: + 1. the full 1/r^p potential + 2. its short-range (SR) and long-range (LR) parts, the split being determined by a + length-scale parameter (called "smearing" in the code) + 3. the Fourier transform of the LR part + + :param exponent: torch.tensor corresponding to the exponent "p" in 1/r^p potentials + """ + def __init__(self, exponent: torch.Tensor): + self.exponent = exponent + + def potential_from_dist(self, + dist: torch.Tensor + ) -> torch.Tensor: + """ + Full 1/r^p potential as a function of r + + :param dist: torch.tensor containing the distances at which the potential is to + be evaluated. + """ + return torch.pow(dist, -self.exponent) + + def potential_from_dist_sq(self, + dist_sq: torch.Tensor + ) -> torch.Tensor: + """ + Full 1/r^p potential as a function of r^2, which is more useful in some + implementations + + :param dist_sq: torch.tensor containing the squared distances at which the + potential is to be evaluated. + """ + return torch.pow(dist_sq, -self.exponent / 2.) + + def potential_sr_from_dist(self, + dist: torch.Tensor, + smearing: torch.Tensor + ) -> torch.Tensor: + """ + Short-range (SR) part of the range-separated 1/r^p potential as a function of r. + More explicitly: it corresponds to V_SR(r) in 1/r^p = V_SR(r) + V_LR(r), + where the location of the split is determined by the smearing parameter. + + For the Coulomb potential, this would return + potential = erfc(dist / sqrt(2) / smearing) / dist + + :param dist: torch.tensor containing the distances at which the potential is to + be evaluated. + :param smearing: torch.tensor containing the parameter often called "sigma" in + publications, which determines the length-scale at which the short-range and + long-range parts of the naive 1/r^p potential are separated. For the Coulomb + potential (p=1), this potential can be interpreted as the effective + potential generated by a Gaussian charge density, in which case this + smearing parameter corresponds to the "width" of the Gaussian. + """ + x = 0.5 * dist**2 / smearing**2 + peff = self.exponent / 2 + prefac = 1./(2*smearing**2)**peff + potential = prefac * gammainc(peff, x) / x**peff + return potential + + def potential_lr_from_dist(self, + dist: torch.Tensor, + smearing: torch.Tensor + ) -> torch.Tensor: + """ + Long-range (LR) part of the range-separated 1/r^p potential as a function of r. + Used to subtract out the interior contributions after computing the LR part + in reciprocal (Fourier) space. + + For the Coulomb potential, this would return (note that the only change between + the SR and LR parts is the fact that erfc changes to erf) + potential = erf(dist / sqrt(2) / smearing) / dist + + :param dist: torch.tensor containing the distances at which the potential is to + be evaluated. + :param smearing: torch.tensor containing the parameter often called "sigma" in + publications, which determines the length-scale at which the short-range and + long-range parts of the naive 1/r^p potential are separated. For the Coulomb + potential (p=1), this potential can be interpreted as the effective + potential generated by a Gaussian charge density, in which case this + smearing parameter corresponds to the "width" of the Gaussian. + """ + x = 0.5 * dist**2 / smearing**2 + peff = self.exponent / 2 + prefac = 1./(2*smearing**2)**peff + potential = prefac * gammainc(peff, x) / x**peff + return potential + + def potential_fourier_from_k_sq(self, + k_sq: torch.Tensor, + smearing: torch.Tensor + ) -> torch.Tensor: + """ + Fourier transform of the long-range (LR) part potential parametrized in terms of + k^2. + If only the Coulomb potential is needed, the last line can be replaced by + fourier = 4 * torch.pi * torch.exp(-0.5 * smearing**2 * k_sq) / k_sq + + :param k_sq: torch.tensor containing the squared lengths (2-norms) of the wave + vectors k at which the Fourier-transformed potential is to be evaluated + :param smearing: torch.tensor containing the parameter often called "sigma" in + publications, which determines the length-scale at which the short-range and + long-range parts of the naive 1/r^p potential are separated. For the Coulomb + potential (p=1), this potential can be interpreted as the effective + potential generated by a Gaussian charge density, in which case this + smearing parameter corresponds to the "width" of the Gaussian. + """ + peff = (3-self.exponent) / 2 + prefac = (math.pi)**1.5 / gamma(self.exponent/2) * (2*smearing**2)**peff + x = 0.5*smearing**2*k_sq + fourier = prefac * gammaincc(peff, x) / x**peff * gamma(peff) + + return fourier + + def potential_fourier_at_zero(self, smearing: torch.Tensor) -> torch.Tensor: + """ + The value of the Fourier-transformed potential (LR part implemented above) as + k --> 0 often needs to be set separately since for exponents p <= 3 = dimension, + there is a divergence to +infinity. + Setting this value manually to zero physically corresponds to the addition of a + uniform backgruond charge to make the system charge-neutral. + For p > 3, on the other hand, the Fourier-transformed LR potential does not + diverge as k --> 0, and one could instead assign the correct limit. + This is not implemented for now for consistency reasons. + + :param smearing: torch.tensor containing the parameter often called "sigma" in + publications, which determines the length-scale at which the short-range and + long-range parts of the naive 1/r^p potential are separated. For the Coulomb + potential (p=1), this potential can be interpreted as the effective + potential generated by a Gaussian charge density, in which case this + smearing parameter corresponds to the "width" of the Gaussian. + """ + return torch.tensor(0.) \ No newline at end of file From 8f1a2873cfa82a9560f2de52143b35317f65e147 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Tue, 11 Jun 2024 13:29:40 +0200 Subject: [PATCH 02/35] Add base classes for calculators --- src/meshlode/calculators/calculator_base.py | 247 ++++++++++++++++++ .../calculators/calculator_base_periodic.py | 171 ++++++++++++ 2 files changed, 418 insertions(+) create mode 100644 src/meshlode/calculators/calculator_base.py create mode 100644 src/meshlode/calculators/calculator_base_periodic.py diff --git a/src/meshlode/calculators/calculator_base.py b/src/meshlode/calculators/calculator_base.py new file mode 100644 index 00000000..e6abbf39 --- /dev/null +++ b/src/meshlode/calculators/calculator_base.py @@ -0,0 +1,247 @@ +from meshlode.lib import InversePowerLawPotential +from typing import List, Optional, Union + +import torch + + +@torch.jit.script +def _1d_tolist(x: torch.Tensor) -> List[int]: + """Auxilary function to convert 1d torch tensor to list of integers.""" + result: List[int] = [] + for i in x: + result.append(i.item()) + return result + + +@torch.jit.script +def _is_subset(subset_candidate: List[int], superset: List[int]) -> bool: + """Checks whether all elements of `subset_candidate` are part of `superset`.""" + for element in subset_candidate: + if element not in superset: + return False + return True + + +class CalculatorBase(torch.nn.Module): + """ + Base calculator + + :param all_types: Optional global list of all atomic types that should be considered + for the computation. This option might be useful when running the calculation on + subset of a whole dataset and it required to keep the shape of the output + consistent. If this is not set the possible atomic types will be determined when + calling the :meth:`compute()`. + """ + + name = "CalculatorBase" + + def __init__( + self, + all_types: Optional[List[int]] = None, + exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64), + ): + super().__init__() + + if all_types is None: + self.all_types = None + else: + self.all_types = _1d_tolist(torch.unique(torch.tensor(all_types))) + + self.exponent = exponent + self.potential = InversePowerLawPotential(exponent = exponent) + + # This function is kept to keep this library compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, positions=positions, charges=charges + ) + + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + # make sure compute function works if only a single tensor are provided as input + if not isinstance(types, list): + types = [types] + if not isinstance(positions, list): + positions = [positions] + + # Check that all inputs are consistent + # We don't require and test that all dtypes and devices are consistent for a + # list of inputs. Each "frame" is processed independently. + for types_single, positions_single in zip(types, positions): + if len(types_single.shape) != 1: + raise ValueError( + "each `types` must be a 1 dimensional tensor, got at least " + f"one tensor with {len(types_single.shape)} dimensions" + ) + + if positions_single.shape != (len(types_single), 3): + raise ValueError( + "each `positions` must be a (n_types x 3) tensor, got at least " + f"one tensor with shape {list(positions_single.shape)}" + ) + + if positions_single.device != types_single.device: + raise ValueError( + "`types` and `positions` must be on the same device, got " + f"{types_single.device}, {positions_single.device}" + ) + + requested_types = self._get_requested_types(types) + + # If charges are not provided, we assume that all types are treated separately + if charges is None: + charges = [] + for types_single, positions_single in zip(types, positions): + # One-hot encoding of charge information + charges_single = self._one_hot_charges( + types=types_single, + requested_types=requested_types, + dtype=positions_single.dtype, + device=positions_single.device, + ) + charges.append(charges_single) + + # If charges are provided, we need to make sure that they are consistent with + # the provided types + else: + if not isinstance(charges, list): + charges = [charges] + if len(charges) != len(types): + raise ValueError( + "The number of `types` and `charges` tensors must be the same, " + f"got {len(types)} and {len(charges)}." + ) + for charges_single, types_single in zip(charges, types): + if charges_single.shape[0] != len(types_single): + raise ValueError( + "The first dimension of `charges` must be the same as the " + f"length of `types`, got {charges_single.shape[0]} and " + f"{len(types_single)}." + ) + if charges[0].dtype != positions[0].dtype: + raise ValueError( + "`charges` must be have the same dtype as `positions`, got " + f"{charges[0].dtype} and {positions[0].dtype}." + ) + if charges[0].device != positions[0].device: + raise ValueError( + "`charges` must be on the same device as `positions`, got " + f"{charges[0].device} and {positions[0].device}." + ) + + potentials = [] + for positions_single, charges_single in zip(positions, charges): + # Compute the potentials + potentials.append( + self._compute_single_system( + positions=positions_single, charges=charges_single + ) + ) + + if len(types) == 1: + return potentials[0] + else: + return potentials + + def _get_requested_types(self, types: List[torch.Tensor]) -> List[int]: + """Extract a list of all unique and present types from the list of types.""" + all_types = torch.hstack(types) + types_requested = _1d_tolist(torch.unique(all_types)) + + if self.all_types is not None: + if not _is_subset(types_requested, self.all_types): + raise ValueError( + f"Global list of types {self.all_types} does not contain all " + f"types for the provided systems {types_requested}." + ) + return self.all_types + else: + return types_requested + + def _one_hot_charges( + self, + types: torch.Tensor, + requested_types: List[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + n_types = len(requested_types) + one_hot_charges = torch.zeros((len(types), n_types), dtype=dtype, device=device) + + for i_type, atomic_type in enumerate(requested_types): + one_hot_charges[types == atomic_type, i_type] = 1.0 + + return one_hot_charges + + def _compute_single_system( + self, + positions: torch.Tensor, + charges: torch.Tensor, + cell: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Core of the calculator that actually implements the computation of the potential + using various algorithms. + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. For standard LODE + that treats all (atomic) types separately, one example could be: If n_atoms + = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use + the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for + the charges. This would then separately compute the "Na" potential and "Cl" + potential. Subtracting these from each other, one could recover the more + standard electrostatic potential in which Na and Cl have charges of +1 and + -1, respectively. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. While redundant in this + particular implementation, the parameter is kept to keep the same inputs as + the other calculators. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + + return torch.zeros_like(charges) diff --git a/src/meshlode/calculators/calculator_base_periodic.py b/src/meshlode/calculators/calculator_base_periodic.py new file mode 100644 index 00000000..883f6f71 --- /dev/null +++ b/src/meshlode/calculators/calculator_base_periodic.py @@ -0,0 +1,171 @@ +from typing import List, Optional, Union + +import torch + +from .calculator_base import CalculatorBase + + +class CalculatorBasePeriodic(CalculatorBase): + """ + Base calculator for periodic implementations + """ + + name = "CalculatorBasePeriodic" + + # Note that the base class also has this function, but with the parameter "cell" + # only as an option. For periodic implementations, "cell" is a strictly required + # parameter, which is why this function is implemented again. + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, positions=positions, cell=cell, charges=charges + ) + + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: single or 2D tensor of shape (3, 3), describing the bounding + box/unit cell of the system. Each row should be one of the bounding box + vector; and columns should contain the x, y, and z components of these + vectors (i.e. the cell should be given in row-major order). + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + # make sure compute function works if only a single tensor are provided as input + if not isinstance(types, list): + types = [types] + if not isinstance(positions, list): + positions = [positions] + if not isinstance(cell, list): + cell = [cell] + + # Check that all inputs are consistent + for types_single, positions_single, cell_single in zip(types, positions, cell): + if len(types_single.shape) != 1: + raise ValueError( + "each `types` must be a 1 dimensional tensor, got at least " + f"one tensor with {len(types_single.shape)} dimensions" + ) + + if positions_single.shape != (len(types_single), 3): + raise ValueError( + "each `positions` must be a (n_types x 3) tensor, got at least " + f"one tensor with shape {list(positions_single.shape)}" + ) + + if cell_single.shape != (3, 3): + raise ValueError( + "each `cell` must be a (3 x 3) tensor, got at least " + f"one tensor with shape {list(cell_single.shape)}" + ) + + if cell_single.dtype != positions_single.dtype: + raise ValueError( + "`cell` must be have the same dtype as `positions`, got " + f"{cell_single.dtype} and {positions_single.dtype}" + ) + + if ( + positions_single.device != types_single.device + or cell_single.device != types_single.device + ): + raise ValueError( + "`types`, `positions`, and `cell` must be on the same device, got " + f"{types_single.device}, {positions_single.device} and " + f"{cell_single.device}." + ) + + requested_types = self._get_requested_types(types) + + # If charges are not provided, we assume that all types are treated separately + if charges is None: + charges = [] + for types_single, positions_single in zip(types, positions): + # One-hot encoding of charge information + charges_single = self._one_hot_charges( + types=types_single, + requested_types=requested_types, + dtype=positions_single.dtype, + device=positions_single.device, + ) + charges.append(charges_single) + + # If charges are provided, we need to make sure that they are consistent with + # the provided types + else: + if not isinstance(charges, list): + charges = [charges] + if len(charges) != len(types): + raise ValueError( + "The number of `types` and `charges` tensors must be the same, " + f"got {len(types)} and {len(charges)}." + ) + for charges_single, types_single in zip(charges, types): + if charges_single.shape[0] != len(types_single): + raise ValueError( + "The first dimension of `charges` must be the same as the " + f"length of `types`, got {charges_single.shape[0]} and " + f"{len(types_single)}." + ) + if charges[0].dtype != positions[0].dtype: + raise ValueError( + "`charges` must be have the same dtype as `positions`, got " + f"{charges[0].dtype} and {positions[0].dtype}." + ) + if charges[0].device != positions[0].device: + raise ValueError( + "`charges` must be on the same device as `positions`, got " + f"{charges[0].device} and {positions[0].device}." + ) + # We don't require and test that all dtypes and devices are consistent if a list + # of inputs. Each "frame" is processed independently. + potentials = [] + for positions_single, cell_single, charges_single in zip( + positions, cell, charges + ): + # Compute the potentials + potentials.append( + self._compute_single_system( + positions=positions_single, charges=charges_single, cell=cell_single + ) + ) + + if len(types) == 1: + return potentials[0] + else: + return potentials \ No newline at end of file From 002f54578613f0ea7af2c1546e5d1265077b35d2 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Tue, 11 Jun 2024 13:33:12 +0200 Subject: [PATCH 03/35] Add calculator for direct aperiodic summation --- src/meshlode/calculators/direct.py | 75 ++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 src/meshlode/calculators/direct.py diff --git a/src/meshlode/calculators/direct.py b/src/meshlode/calculators/direct.py new file mode 100644 index 00000000..ea382f3a --- /dev/null +++ b/src/meshlode/calculators/direct.py @@ -0,0 +1,75 @@ +from .calculator_base import CalculatorBase + +import torch + + +class DirectPotential(CalculatorBase): + """A specie-wise long-range potential computed using a direct summation over all + pairs of atoms, scaling as O(N^2) with respect to the number of particles N. + As opposed to the Ewald sum, this calculator does NOT take into account periodic + images, and it will instead be assumed that the provided atoms are in the infinitely + extended three-dimensional Euclidean space. + While slow, this implementation used as a reference to test faster algorithms. + + :param all_types: Optional global list of all atomic types that should be considered + for the computation. This option might be useful when running the calculation on + subset of a whole dataset and it required to keep the shape of the output + consistent. If this is not set the possible atomic types will be determined when + calling the :meth:`compute()`. + """ + + name = "DirectPotential" + + def _compute_single_system( + self, + positions: torch.Tensor, + charges: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the "electrostatic" potential at the position of all atoms in a + structure. + This solver does not use periodic boundaries, and thus also does not take into + account potential periodic images. + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. For standard LODE + that treats all (atomic) types separately, one example could be: If n_atoms + = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use + the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for + the charges. This would then separately compute the "Na" potential and "Cl" + potential. Subtracting these from each other, one could recover the more + standard electrostatic potential in which Na and Cl have charges of +1 and + -1, respectively. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. While redundant in this + particular implementation, the parameter is kept to keep the same inputs as + the other calculators. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + # Compute matrix containing the squared distances from the Gram matrix + # The squared distance and the inner product between two vectors r_i and r_j are + # related by: d_ij^2 = |r_i - r_j|^2 = r_i^2 + r_j^2 - 2*r_i*r_j + num_atoms = len(positions) + diagonal_indices = torch.arange(num_atoms) + gram_matrix = positions @ positions.T + squared_norms = gram_matrix[diagonal_indices, diagonal_indices].reshape(-1, 1) + ones = torch.ones((1, len(positions)), dtype=positions.dtype) + squared_norms_matrix = torch.matmul(squared_norms, ones) + distances_sq = squared_norms_matrix + squared_norms_matrix.T - 2 * gram_matrix + + # Add terms to diagonal in order to avoid division by zero + distances_sq[diagonal_indices, diagonal_indices] += 1e30 + + # Compute potential + potentials_by_pair = distances_sq.pow(-self.exponent / 2.) + potentials = torch.matmul(potentials_by_pair, charges) + + return potentials From 67c7c6c6f37bbbbd1e863d1999a56b41d1e3cbde Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Tue, 11 Jun 2024 13:33:47 +0200 Subject: [PATCH 04/35] Add calculator for Ewald summation --- src/meshlode/calculators/ewald.py | 362 ++++++++++++++++++++++++++++++ 1 file changed, 362 insertions(+) create mode 100644 src/meshlode/calculators/ewald.py diff --git a/src/meshlode/calculators/ewald.py b/src/meshlode/calculators/ewald.py new file mode 100644 index 00000000..2df0b3e4 --- /dev/null +++ b/src/meshlode/calculators/ewald.py @@ -0,0 +1,362 @@ +import torch +from typing import List, Optional + +from .calculator_base_periodic import CalculatorBasePeriodic + +# extra imports for neighbor list +from ase import Atoms +from ase.neighborlist import neighbor_list + +class EwaldPotential(CalculatorBasePeriodic): + """A specie-wise long-range potential computed using the Ewald sum, scaling as + O(N^2) with respect to the number of particles N used as a reference to test faster + implementations. + + :param all_types: Optional global list of all atomic types that should be considered + for the computation. This option might be useful when running the calculation on + subset of a whole dataset and it required to keep the shape of the output + consistent. If this is not set the possible atomic types will be determined when + calling the :meth:`compute()`. + :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If + not set to a global value, it will be set to be half of the shortest lattice + vector defining the cell (separately for each structure). + :param atomic_smearing: Width of the atom-centered Gaussian used to split the + Coulomb potential into the short- and long-range parts. If not set to a global + value, it will be set to 1/5 times the sr_cutoff value (separately for each + structure) to ensure convergence of the short-range part to a relative precision + of 1e-5. + :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) + part of the Ewald sum. More conretely, all Fourier space vectors with a + wavelength >= this value will be kept. If not set to a global value, it will be + set to half the atomic_smearing parameter to ensure convergence of the + long-range part to a relative precision of 1e-5. + :param subtract_self: If set to :py:obj:`True`, subtract from the features of an + atom the contributions to the potential arising from that atom itself (but not + the periodic images). + :param subtract_interior: If set to :py:obj:`True`, subtract from the features of an + atom the contributions to the potential arising from all atoms within the cutoff + Note that if set to true, the self contribution (see previous) is also + subtracted by default. + + Example + ------- + >>> import torch + >>> from meshlode import EwaldPotential + + Define simple example structure having the CsCl (Cesium Chloride) structure + + >>> types = torch.tensor([55, 17]) # Cs and Cl + >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + >>> cell = torch.eye(3) + + Compute features + + >>> EP = EwaldPotential() + >>> EP.compute(types=types, positions=positions, cell=cell) + tensor([[-0.7391, -2.7745], + [-2.7745, -0.7391]]) + """ + + name = "EwaldPotential" + + def __init__( + self, + all_types: Optional[List[int]] = None, + exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64), + sr_cutoff: Optional[float] = None, + atomic_smearing: Optional[float] = None, + lr_wavelength: Optional[float] = None, + subtract_self: Optional[bool] = True, + subtract_interior: Optional[bool] = False + ): + super().__init__(all_types=all_types, exponent=exponent) + + # Store provided parameters + self.atomic_smearing = atomic_smearing + self.sr_cutoff = sr_cutoff + self.lr_wavelength = lr_wavelength + + # If interior contributions are to be subtracted, also do so for self term + if subtract_interior: + subtract_self = True + self.subtract_self = subtract_self + self.subtract_interior = subtract_interior + + def _compute_single_system( + self, + positions: torch.Tensor, + charges: torch.Tensor, + cell: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the "electrostatic" potential at the position of all atoms in a + structure. + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. For standard LODE + that treats all (atomic) types separately, one example could be: If n_atoms + = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use + the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for + the charges. This would then separately compute the "Na" potential and "Cl" + potential. Subtracting these from each other, one could recover the more + standard electrostatic potential in which Na and Cl have charges of +1 and + -1, respectively. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + # Check that the realspace cutoff (if provided) is not too large + # This is because the current implementation is not able to return multiple + # periodic images of the same atom as a neighbor + cell_dimensions = torch.linalg.norm(cell, dim=1) + cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6 + if self.sr_cutoff is not None: + if self.sr_cutoff > torch.min(cell_dimensions) / 2: + raise ValueError(f"sr_cutoff {sr_cutoff} needs to be > {cutoff_max}") + + # Set the defaut values of convergence parameters + # The total computational cost = cost of SR part + cost of LR part + # Bigger smearing increases the cost of the SR part while decreasing the cost + # of the LR part. Since the latter usually is more expensive, we maximize the + # value of the smearing by default to minimize the cost of the LR part. + # The two auxilary parameters (sr_cutoff, lr_wavelength) then control the + # convergence of the SR and LR sums, respectively. The default values are + # chosen to reach a convergence on the order of 1e-4 to 1e-5 for the test + # structures. + if self.sr_cutoff is None: + sr_cutoff = cutoff_max + else: + sr_cutoff = self.sr_cutoff + + if self.atomic_smearing is None: + smearing = cutoff_max / 5.0 + else: + smearing = self.atomic_smearing + + if self.lr_wavelength is None: + lr_wavelength = 0.5 * smearing + else: + lr_wavelength = self.lr_wavelength + + potential_sr = self._compute_sr( + positions=positions, + charges=charges, + cell=cell, + smearing=smearing, + sr_cutoff=sr_cutoff, + ) + + ##return charges * torch.sum(positions, dim=1) * self.exponent + potential_sr + + potential_lr = self._compute_lr( + positions=positions, + charges=charges, + cell=cell, + smearing=smearing, + lr_wavelength=lr_wavelength, + ) + + #return potential_lr + + potential_ewald = potential_sr + potential_lr + return potential_ewald + + def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: + """ + For a given unit cell, compute all reciprocal space vectors that are used to + perform sums in the Fourier transformed space. + + Note that this function is different from the function implemented in the + FourierSpaceConvolution class of the same name, since in this case, we are + generating the full grid of k-vectors, rather than the one that is adapted + specifically to be used together with FFT. + + :param ns: torch.tensor of shape ``(3,)`` containing integers + ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and + z-direction, respectively. + :param cell: torch.tensor of shape ``(3, 3)`` Tensor specifying the real space + unit cell of a structure, where cell[i] is the i-th basis vector + + :return: torch.tensor of shape ``(N, 3)`` Contains all reciprocal space vectors + that will be used during Ewald summation (or related approaches). + ``k_vectors[i]`` contains the i-th vector, where the order has no special + significance. + The total number N of k-vectors is NOT simply nx*ny*nz, and roughly corresponds + to nx*ny*nz/2 due since the vectors +k and -k can be grouped together during + summation. + """ + # Check that the shapes of all inputs are correct + if ns.shape != (3,): + raise ValueError(f"ns of shape {list(ns.shape)} should be of shape (3, )") + + # Define basis vectors of the reciprocal cell + reciprocal_cell = 2 * torch.pi * cell.inverse().T + bx = reciprocal_cell[0] + by = reciprocal_cell[1] + bz = reciprocal_cell[2] + + # Generate all reciprocal space vectors + nxs_1d = ns[0] * torch.fft.fftfreq(ns[0], device=ns.device) + nys_1d = ns[1] * torch.fft.fftfreq(ns[1], device=ns.device) + nzs_1d = ns[2] * torch.fft.fftfreq(ns[2], device=ns.device) # real FFT + nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing="ij") + nxs = nxs.flatten().reshape((-1, 1)) + nys = nys.flatten().reshape((-1, 1)) + nzs = nzs.flatten().reshape((-1, 1)) + k_vectors = nxs * bx + nys * by + nzs * bz + + return k_vectors + + def _compute_lr( + self, + positions: torch.Tensor, + charges: torch.Tensor, + cell: torch.Tensor, + smearing: torch.Tensor, + lr_wavelength: torch.Tensor, + subtract_self=True, + ) -> torch.Tensor: + """ + Compute the long-range part of the Ewald sum in realspace + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. + :param smearing: torch.Tensor smearing paramter determining the splitting + between the SR and LR parts. + :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) + part of the Ewald sum. More conretely, all Fourier space vectors with a + wavelength >= this value will be kept. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + # Define k-space cutoff from required real-space resolution + k_cutoff = 2 * torch.pi / lr_wavelength + + # Compute number of times each basis vector of the reciprocal space can be scaled + # until the cutoff is reached + basis_norms = torch.linalg.norm(cell, dim=1) + ns_float = k_cutoff * basis_norms / 2 / torch.pi + ns = torch.ceil(ns_float).long() + + # Generate k-vectors and evaluate + kvectors = self._generate_kvectors(ns=ns, cell=cell) + knorm_sq = torch.sum(kvectors**2, dim=1) + + # G(k) is the Fourier transform of the Coulomb potential + # generated by a Gaussian charge density + # We remove the singularity at k=0 by explicitly setting its + # value to be equal to zero. This mathematically corresponds + # to the requirement that the net charge of the cell is zero. + # G = 4 * torch.pi * torch.exp(-0.5 * smearing**2 * knorm_sq) / knorm_sq + G = self.potential.potential_fourier_from_k_sq(knorm_sq, smearing) + G[0] = self.potential.potential_fourier_at_zero(smearing) + + # Compute the energy using the explicit method that + # follows directly from the Poisson summation formula. + # For this, we precompute trigonometric factors for optimization, which leads + # to N^2 rather than N^3 scaling. + trig_args = kvectors @ (positions.T) # shape num_k x num_atoms + + # Reshape charges into suitable form for array/tensor broadcasting + num_atoms = len(positions) + if charges.dim() > 1: + num_channels = charges.shape[1] + charges_reshaped = (charges.T).reshape(num_channels, 1, num_atoms) + sum_idx = 2 + else: + charges_reshaped = charges + sum_idx = 1 + + # Actual computation of trigonometric factors + cos_all = torch.cos(trig_args) + sin_all = torch.sin(trig_args) + cos_summed = torch.sum(cos_all * charges_reshaped, dim=sum_idx) + sin_summed = torch.sum(sin_all * charges_reshaped, dim=sum_idx) + + # Add up the contributions to compute the potential + energy = torch.zeros_like(charges) + for i in range(num_atoms): + energy[i] += torch.sum( + G * cos_all[:, i] * cos_summed, dim=sum_idx - 1 + ) + torch.sum(G * sin_all[:, i] * sin_summed, dim=sum_idx - 1) + energy /= torch.abs(cell.det()) + + # Remove self contribution if desired + # For now, this is the expression for the Coulomb potential p=1 + # TODO: modify to expression for general p + if subtract_self: + self_contrib = ( + torch.sqrt(torch.tensor(2.0 / torch.pi, device=cell.device)) / smearing + ) + energy -= charges * self_contrib + + return energy + + def _compute_sr( + self, + positions: torch.Tensor, + charges: torch.Tensor, + cell: torch.Tensor, + smearing: torch.Tensor, + sr_cutoff: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the short-range part of the Ewald sum in realspace + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. + :param smearing: torch.Tensor smearing paramter determining the splitting + between the SR and LR parts. + :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + # Get list of neighbors + struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) + atom_is, atom_js, shifts = neighbor_list( + "ijS", struc, sr_cutoff.item(), self_interaction=False + ) + + # Compute energy + potential = torch.zeros_like(charges) + for i, j, shift in zip(atom_is, atom_js, shifts): + dist = torch.linalg.norm(positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell))) + + # If the contribution from all atoms within the cutoff is to be subtracted + # this short-range part will simply use -V_LR as the potential + if self.subtract_interior: + potential_bare = -self.potential.potential_lr_from_dist(dist, smearing) + # In the remaining cases, we simply use the usual V_SR to get the full + # 1/r^p potential when combined with the long-range part implemented in + # reciprocal space + else: + potential_bare = self.potential.potential_sr_from_dist(dist, smearing) + potential[i] += charges[j] * potential_bare + + return potential From 2f317e0b2d67eb8d59397313b11060906b1551f3 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Tue, 11 Jun 2024 13:35:45 +0200 Subject: [PATCH 05/35] Add calculator for PME and rename old mesh-only --- src/meshlode/calculators/mesh.py | 165 ++++++++++ src/meshlode/calculators/meshewald.py | 344 ++++++++++++++++++++ src/meshlode/calculators/meshpotential.py | 367 ---------------------- 3 files changed, 509 insertions(+), 367 deletions(-) create mode 100644 src/meshlode/calculators/mesh.py create mode 100644 src/meshlode/calculators/meshewald.py delete mode 100644 src/meshlode/calculators/meshpotential.py diff --git a/src/meshlode/calculators/mesh.py b/src/meshlode/calculators/mesh.py new file mode 100644 index 00000000..0b6015eb --- /dev/null +++ b/src/meshlode/calculators/mesh.py @@ -0,0 +1,165 @@ +from typing import List, Optional + +import torch + +from meshlode.lib.fourier_convolution import FourierSpaceConvolution +from meshlode.lib.mesh_interpolator import MeshInterpolator + +from .calculator_base_periodic import CalculatorBasePeriodic + +class MeshPotential(CalculatorBasePeriodic): + """A specie-wise long-range potential, computed using the particle-mesh Ewald (PME) + method scaling as O(NlogN) with respect to the number of particles N. + + :param atomic_smearing: Width of the atom-centered Gaussian used to create the + atomic density. + :param mesh_spacing: Value that determines the umber of Fourier-space grid points + that will be used along each axis. If set to None, it will automatically be set + to half of ``atomic_smearing``. + :param interpolation_order: Interpolation order for mapping onto the grid, where an + interpolation order of p corresponds to interpolation by a polynomial of degree + ``p - 1`` (e.g. ``p = 4`` for cubic interpolation). + :param subtract_self: If set to :py:obj:`True`, subtract from the features of an + atom the contributions to the potential arising from that atom itself (but not + the periodic images). + :param all_types: Optional global list of all atomic types that should be considered + for the computation. This option might be useful when running the calculation on + subset of a whole dataset and it required to keep the shape of the output + consistent. If this is not set the possible atomic types will be determined when + calling the :meth:`compute()`. + + Example + ------- + >>> import torch + >>> from meshlode import MeshPotential + + Define simple example structure having the CsCl (Cesium Chloride) structure + + >>> types = torch.tensor([55, 17]) # Cs and Cl + >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + >>> cell = torch.eye(3) + + Compute features + + >>> MP = MeshPotential(atomic_smearing=0.2, mesh_spacing=0.1, interpolation_order=4) + >>> MP.compute(types=types, positions=positions, cell=cell) + tensor([[-0.5467, 1.3755], + [ 1.3755, -0.5467]]) + """ + + name = "MeshPotential" + + def __init__( + self, + atomic_smearing: float, + mesh_spacing: Optional[float] = None, + interpolation_order: Optional[int] = 4, + subtract_self: Optional[bool] = False, + all_types: Optional[List[int]] = None, + exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64), + ): + super().__init__(all_types=all_types, exponent=exponent) + + # Check that all provided values are correct + if interpolation_order not in [1, 2, 3, 4, 5]: + raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") + if atomic_smearing <= 0: + raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") + + # If no explicit mesh_spacing is given, set it such that it can resolve + # the smeared potentials. + if mesh_spacing is None: + mesh_spacing = atomic_smearing / 2 + + # Store provided parameters + self.atomic_smearing = atomic_smearing + self.mesh_spacing = mesh_spacing + self.interpolation_order = interpolation_order + self.subtract_self = subtract_self + + # Initilize auxiliary objects + self.fourier_space_convolution = FourierSpaceConvolution() + + def _compute_single_system( + self, + positions: torch.Tensor, + charges: torch.Tensor, + cell: torch.Tensor, + mesh_spacing: Optional[float] = None, + ) -> torch.Tensor: + """ + Compute the "electrostatic" potential at the position of all atoms in a + structure. + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. For standard LODE + that treats all (atomic) types separately, one example could be: If n_atoms + = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use + the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for + the charges. This would then separately compute the "Na" potential and "Cl" + potential. Subtracting these from each other, one could recover the more + standard electrostatic potential in which Na and Cl have charges of +1 and + -1, respectively. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + # Initializations + n_atoms = len(positions) + assert positions.shape == (n_atoms, 3) + assert charges.shape[0] == n_atoms + + assert positions.dtype == cell.dtype and charges.dtype == cell.dtype + assert positions.device == cell.device and charges.device == cell.device + + + # Define cutoff in reciprocal space + if mesh_spacing is None: + mesh_spacing = self.mesh_spacing + k_cutoff = 2 * torch.pi / mesh_spacing + + # Compute number of times each basis vector of the + # reciprocal space can be scaled until the cutoff + # is reached + basis_norms = torch.linalg.norm(cell, dim=1) + ns_approx = k_cutoff * basis_norms / 2 / torch.pi + ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points + ns = 2 ** torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz] + + # Step 1: Smear particles onto mesh + MI = MeshInterpolator(cell, ns, interpolation_order=self.interpolation_order) + MI.compute_interpolation_weights(positions) + rho_mesh = MI.points_to_mesh(particle_weights=charges) + + # Step 2: Perform Fourier space convolution (FSC) + potential_mesh = self.fourier_space_convolution.compute( + mesh_values=rho_mesh, + cell=cell, + potential_exponent=1, + atomic_smearing=self.atomic_smearing, + ) + + # Step 3: Back interpolation + interpolated_potential = MI.mesh_to_points(potential_mesh) + + # Remove self contribution + if self.subtract_self: + self_contrib = ( + torch.sqrt( + torch.tensor( + 2.0 / torch.pi, dtype=positions.dtype, device=positions.device + ), + ) + / self.atomic_smearing + ) + interpolated_potential -= charges * self_contrib + + return interpolated_potential diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/meshewald.py new file mode 100644 index 00000000..8de8b985 --- /dev/null +++ b/src/meshlode/calculators/meshewald.py @@ -0,0 +1,344 @@ +import torch +from typing import List, Optional + +# from .mesh import MeshPotential +from .calculator_base_periodic import CalculatorBasePeriodic +from meshlode.lib.mesh_interpolator import MeshInterpolator + +# extra imports for neighbor list +from ase import Atoms +from ase.neighborlist import neighbor_list + +class MeshEwaldPotential(CalculatorBasePeriodic): + """A specie-wise long-range potential computed using a mesh-based Ewald method, + scaling as O(NlogN) with respect to the number of particles N used as a reference + to test faster implementations. + + :param all_types: Optional global list of all atomic types that should be considered + for the computation. This option might be useful when running the calculation on + subset of a whole dataset and it required to keep the shape of the output + consistent. If this is not set the possible atomic types will be determined when + calling the :meth:`compute()`. + :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If + not set to a global value, it will be set to be half of the shortest lattice + vector defining the cell (separately for each structure). + :param atomic_smearing: Width of the atom-centered Gaussian used to split the + Coulomb potential into the short- and long-range parts. If not set to a global + value, it will be set to 1/5 times the sr_cutoff value (separately for each + structure) to ensure convergence of the short-range part to a relative precision + of 1e-5. + :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) + part of the Ewald sum. More conretely, all Fourier space vectors with a + wavelength >= this value will be kept. If not set to a global value, it will be + set to half the atomic_smearing parameter to ensure convergence of the + long-range part to a relative precision of 1e-5. + :param subtract_self: If set to :py:obj:`True`, subtract from the features of an + atom the contributions to the potential arising from that atom itself (but not + the periodic images). + :param subtract_interior: If set to :py:obj:`True`, subtract from the features of an + atom the contributions to the potential arising from all atoms within the cutoff + Note that if set to true, the self contribution (see previous) is also + subtracted by default. + """ + + name = "MeshEwaldPotential" + + def __init__( + self, + all_types: Optional[List[int]] = None, + exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64), + sr_cutoff: Optional[float] = None, + atomic_smearing: Optional[float] = None, + mesh_spacing: Optional[float] = None, + subtract_self: Optional[bool] = True, + interpolation_order: Optional[int] = 4, + subtract_interior: Optional[bool] = False + ): + super().__init__(all_types=all_types, exponent=exponent) + + # Check that all provided values are correct + if interpolation_order not in [1, 2, 3, 4, 5]: + raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") + if atomic_smearing <= 0: + raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") + + # If no explicit mesh_spacing is given, set it such that it can resolve + # the smeared potentials. + if mesh_spacing is None: + mesh_spacing = atomic_smearing / 2 + + # Store provided parameters + self.atomic_smearing = atomic_smearing + self.mesh_spacing = mesh_spacing + self.interpolation_order = interpolation_order + self.sr_cutoff = sr_cutoff + + # If interior contributions are to be subtracted, also do so for self term + if subtract_interior: + subtract_self = True + self.subtract_self = subtract_self + self.subtract_interior = subtract_interior + + def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: + """ + For a given unit cell, compute all reciprocal space vectors that are used to + perform sums in the Fourier transformed space. + + :param ns: torch.tensor of shape ``(3,)`` + ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and + z-direction, respectively. For faster performance during the Fast Fourier + Transform (FFT) it is recommended to use values of nx, ny and nz that are + powers of 2. + :param cell: torch.tensor of shape ``(3, 3)`` Tensor specifying the real space + unit cell of a structure, where cell[i] is the i-th basis vector + + :return: torch.tensor of shape ``(N, 3)`` Contains all reciprocal space vectors + that will be used during Ewald summation (or related approaches). + ``k_vectors[i]`` contains the i-th vector, where the order has no special + significance. + """ + if ns.device != cell.device: + raise ValueError( + f"`ns` and `cell` are not on the same device, got {ns.device} and " + f"{cell.device}." + ) + + if ns.shape != (3,): + raise ValueError(f"ns of shape {list(ns.shape)} should be of shape (3, )") + + if cell.shape != (3, 3): + raise ValueError( + f"cell of shape {list(cell.shape)} should be of shape (3, 3)" + ) + + # Define basis vectors of the reciprocal cell + reciprocal_cell = 2 * torch.pi * cell.inverse().T + bx = reciprocal_cell[0] + by = reciprocal_cell[1] + bz = reciprocal_cell[2] + + # Generate all reciprocal space vectors + nxs_1d = ns[0] * torch.fft.fftfreq(ns[0], device=ns.device) + nys_1d = ns[1] * torch.fft.fftfreq(ns[1], device=ns.device) + nzs_1d = ns[2] * torch.fft.rfftfreq(ns[2], device=ns.device) # real FFT + nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing="ij") + nxs = nxs.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) + nys = nys.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) + nzs = nzs.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) + k_vectors = nxs * bx + nys * by + nzs * bz + + return k_vectors + + def _compute_single_system( + self, + positions: torch.Tensor, + charges: torch.Tensor, + cell: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the "electrostatic" potential at the position of all atoms in a + structure. + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. For standard LODE + that treats all (atomic) types separately, one example could be: If n_atoms + = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use + the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for + the charges. This would then separately compute the "Na" potential and "Cl" + potential. Subtracting these from each other, one could recover the more + standard electrostatic potential in which Na and Cl have charges of +1 and + -1, respectively. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + # Check that the realspace cutoff (if provided) is not too large + # This is because the current implementation is not able to return multiple + # periodic images of the same atom as a neighbor + cell_dimensions = torch.linalg.norm(cell, dim=1) + cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6 + if self.sr_cutoff is not None: + if self.sr_cutoff > torch.min(cell_dimensions) / 2: + raise ValueError(f"sr_cutoff {sr_cutoff} needs to be > {cutoff_max}") + + # Set the defaut values of convergence parameters + # The total computational cost = cost of SR part + cost of LR part + # Bigger smearing increases the cost of the SR part while decreasing the cost + # of the LR part. Since the latter usually is more expensive, we maximize the + # value of the smearing by default to minimize the cost of the LR part. + # The two auxilary parameters (sr_cutoff, lr_wavelength) then control the + # convergence of the SR and LR sums, respectively. The default values are + # chosen to reach a convergence on the order of 1e-4 to 1e-5 for the test + # structures. + if self.sr_cutoff is None: + sr_cutoff = cutoff_max + else: + sr_cutoff = self.sr_cutoff + + if self.atomic_smearing is None: + smearing = cutoff_max / 5.0 + else: + smearing = self.atomic_smearing + + if self.mesh_spacing is None: + mesh_spacing = smearing / 8. + else: + mesh_spacing = self.mesh_spacing + + # Compute short-range (SR) part using a real space sum + potential_sr = self._compute_sr( + positions=positions, + charges=charges, + cell=cell, + smearing=smearing, + sr_cutoff=sr_cutoff, + ) + + # Compute long-range (LR) part using a Fourier / reciprocal space sum + potential_lr = self._compute_lr( + positions=positions, + charges=charges, + cell=cell, + smearing=smearing, + lr_wavelength=mesh_spacing) + + # Combine both parts to obtain the full potential + potential_ewald = potential_sr + potential_lr + return potential_ewald + + def _compute_lr( + self, + positions: torch.Tensor, + charges: torch.Tensor, + cell: torch.Tensor, + smearing: torch.Tensor, + lr_wavelength: torch.Tensor, + subtract_self=True, + ) -> torch.Tensor: + """ + Compute the long-range part of the Ewald sum in realspace + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. + :param smearing: torch.Tensor smearing paramter determining the splitting + between the SR and LR parts. + :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) + part of the Ewald sum. More conretely, all Fourier space vectors with a + wavelength >= this value will be kept. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + # Step 0 (Preparation): Compute number of times each basis vector of the + # reciprocal space can be scaled until the cutoff is reached + mesh_spacing = lr_wavelength + k_cutoff = 2 * torch.pi / lr_wavelength + basis_norms = torch.linalg.norm(cell, dim=1) + ns_approx = k_cutoff * basis_norms / 2 / torch.pi + ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points + ns = 2 ** torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz] + + # Step 1: Smear particles onto mesh + MI = MeshInterpolator(cell, ns, interpolation_order=self.interpolation_order) + MI.compute_interpolation_weights(positions) + rho_mesh = MI.points_to_mesh(particle_weights=charges) + + # Step 2: Perform Fourier space convolution (FSC) to get potential on mesh + # Step 2.1: Generate k-vectors and evaluate kernel function + kvectors = self._generate_kvectors(ns=ns, cell=cell) + knorm_sq = torch.sum(kvectors**2, dim=3) + + # Step 2.2: Evaluate kernel function (careful, tensor shapes are different from + # the pure Ewald implementation since we are no longer flattening) + G = self.potential.potential_fourier_from_k_sq(knorm_sq, smearing) + G[0,0,0] = self.potential.potential_fourier_at_zero(smearing) + + potential_mesh = rho_mesh + + # Step 2.3: Perform actual convolution using FFT + volume = cell.det() + dims = (1, 2, 3) # dimensions along which to Fourier transform + mesh_hat = torch.fft.rfftn(rho_mesh, norm="backward", dim=dims) + potential_hat = mesh_hat * G + potential_mesh = torch.fft.irfftn(potential_hat, norm="forward", dim=dims) + potential_mesh /= volume + + # Step 3: Back interpolation + interpolated_potential = MI.mesh_to_points(potential_mesh) + + # Step 4: Remove self-contribution if desired + if subtract_self: + self_contrib = ( + torch.sqrt(torch.tensor(2.0 / torch.pi, device=cell.device)) / smearing + ) + interpolated_potential -= charges * self_contrib + + return interpolated_potential + + def _compute_sr( + self, + positions: torch.Tensor, + charges: torch.Tensor, + cell: torch.Tensor, + smearing: torch.Tensor, + sr_cutoff: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the short-range part of the Ewald sum in realspace + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. + :param smearing: torch.Tensor smearing paramter determining the splitting + between the SR and LR parts. + :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + # Get list of neighbors + struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) + atom_is, atom_js, shifts = neighbor_list( + "ijS", struc, sr_cutoff.item(), self_interaction=False + ) + + # Compute energy + potential = torch.zeros_like(charges) + for i, j, shift in zip(atom_is, atom_js, shifts): + dist = torch.linalg.norm(positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell))) + + # If the contribution from all atoms within the cutoff is to be subtracted + # this short-range part will simply use -V_LR as the potential + if self.subtract_interior: + potential_bare = -self.potential.potential_lr_from_dist(dist, smearing) + # In the remaining cases, we simply use the usual V_SR to get the full + # 1/r^p potential when combined with the long-range part implemented in + # reciprocal space + else: + potential_bare = self.potential.potential_sr_from_dist(dist, smearing) + potential[i] += charges[j] * potential_bare + + return potential diff --git a/src/meshlode/calculators/meshpotential.py b/src/meshlode/calculators/meshpotential.py deleted file mode 100644 index 0c43472b..00000000 --- a/src/meshlode/calculators/meshpotential.py +++ /dev/null @@ -1,367 +0,0 @@ -from typing import List, Optional, Union - -import torch - -from meshlode.lib.fourier_convolution import FourierSpaceConvolution -from meshlode.lib.mesh_interpolator import MeshInterpolator - - -@torch.jit.script -def _1d_tolist(x: torch.Tensor) -> List[int]: - """Auxilary function to convert 1d torch tensor to list of integers.""" - result: List[int] = [] - for i in x: - result.append(i.item()) - return result - - -@torch.jit.script -def _is_subset(subset_candidate: List[int], superset: List[int]) -> bool: - """Checks whether all elements of `subset_candidate` are part of `superset`.""" - for element in subset_candidate: - if element not in superset: - return False - return True - - -class MeshPotential(torch.nn.Module): - """A specie-wise long-range potential. - - :param atomic_smearing: Width of the atom-centered Gaussian used to create the - atomic density. - :param mesh_spacing: Value that determines the umber of Fourier-space grid points - that will be used along each axis. If set to None, it will automatically be set - to half of ``atomic_smearing``. - :param interpolation_order: Interpolation order for mapping onto the grid, where an - interpolation order of p corresponds to interpolation by a polynomial of degree - ``p - 1`` (e.g. ``p = 4`` for cubic interpolation). - :param subtract_self: If set to :py:obj:`True`, subtract from the features of an - atom the contributions to the potential arising from that atom itself (but not - the periodic images). - :param all_types: Optional global list of all atomic types that should be considered - for the computation. This option might be useful when running the calculation on - subset of a whole dataset and it required to keep the shape of the output - consistent. If this is not set the possible atomic types will be determined when - calling the :meth:`compute()`. - - Example - ------- - >>> import torch - >>> from meshlode import MeshPotential - - Define simple example structure having the CsCl (Cesium Chloride) structure - - >>> types = torch.tensor([55, 17]) # Cs and Cl - >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - >>> cell = torch.eye(3) - - Compute features - - >>> MP = MeshPotential(atomic_smearing=0.2, mesh_spacing=0.1, interpolation_order=4) - >>> MP.compute(types=types, positions=positions, cell=cell) - tensor([[-0.5467, 1.3755], - [ 1.3755, -0.5467]]) - """ - - name = "MeshPotential" - - def __init__( - self, - atomic_smearing: float, - mesh_spacing: Optional[float] = None, - interpolation_order: Optional[int] = 4, - subtract_self: Optional[bool] = False, - all_types: Optional[List[int]] = None, - ): - super().__init__() - - # Check that all provided values are correct - if interpolation_order not in [1, 2, 3, 4, 5]: - raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") - if atomic_smearing <= 0: - raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") - - # If no explicit mesh_spacing is given, set it such that it can resolve - # the smeared potentials. - if mesh_spacing is None: - mesh_spacing = atomic_smearing / 2 - - # Store provided parameters - self.atomic_smearing = atomic_smearing - self.mesh_spacing = mesh_spacing - self.interpolation_order = interpolation_order - self.subtract_self = subtract_self - - if all_types is None: - self.all_types = None - else: - self.all_types = _1d_tolist(torch.unique(torch.tensor(all_types))) - - # Initilize auxiliary objects - self.fourier_space_convolution = FourierSpaceConvolution() - - # This function is kept to keep MeshLODE compatible with the broader pytorch - # infrastructure, which require a "forward" function. We name this function - # "compute" instead, for compatibility with other COSMO software. - def forward( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute( - types=types, positions=positions, cell=cell, charges=charges - ) - - def compute( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute potential for all provided "systems" stacked inside list. - - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. - - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param cell: single or 2D tensor of shape (3, 3), describing the bounding - box/unit cell of the system. Each row should be one of the bounding box - vector; and columns should contain the x, y, and z components of these - vectors (i.e. the cell should be given in row-major order). - :param charges: Optional single or list of 2D tensor of shape (len(types), n), - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). - """ - # make sure compute function works if only a single tensor are provided as input - if not isinstance(types, list): - types = [types] - if not isinstance(positions, list): - positions = [positions] - if not isinstance(cell, list): - cell = [cell] - - # Check that all inputs are consistent - for types_single, positions_single, cell_single in zip(types, positions, cell): - if len(types_single.shape) != 1: - raise ValueError( - "each `types` must be a 1 dimensional tensor, got at least " - f"one tensor with {len(types_single.shape)} dimensions" - ) - - if positions_single.shape != (len(types_single), 3): - raise ValueError( - "each `positions` must be a (n_types x 3) tensor, got at least " - f"one tensor with shape {list(positions_single.shape)}" - ) - - if cell_single.shape != (3, 3): - raise ValueError( - "each `cell` must be a (3 x 3) tensor, got at least " - f"one tensor with shape {list(cell_single.shape)}" - ) - - if cell_single.dtype != positions_single.dtype: - raise ValueError( - "`cell` must be have the same dtype as `positions`, got " - f"{cell_single.dtype} and {positions_single.dtype}" - ) - - if ( - positions_single.device != types_single.device - or cell_single.device != types_single.device - ): - raise ValueError( - "`types`, `positions`, and `cell` must be on the same device, got " - f"{types_single.device}, {positions_single.device} and " - f"{cell_single.device}." - ) - - requested_types = self._get_requested_types(types) - - # If charges are not provided, we assume that all types are treated separately - if charges is None: - charges = [] - for types_single, positions_single in zip(types, positions): - # One-hot encoding of charge information - charges_single = self._one_hot_charges( - types=types_single, - requested_types=requested_types, - dtype=positions_single.dtype, - device=positions_single.device, - ) - charges.append(charges_single) - - # If charges are provided, we need to make sure that they are consistent with - # the provided types - - else: - if not isinstance(charges, list): - charges = [charges] - if len(charges) != len(types): - raise ValueError( - "The number of `types` and `charges` tensors must be the same, " - f"got {len(types)} and {len(charges)}." - ) - for charges_single, types_single in zip(charges, types): - if charges_single.shape[0] != len(types_single): - raise ValueError( - "The first dimension of `charges` must be the same as the " - f"length of `types`, got {charges_single.shape[0]} and " - f"{len(types_single)}." - ) - if charges[0].dtype != positions[0].dtype: - raise ValueError( - "`charges` must be have the same dtype as `positions`, got " - f"{charges[0].dtype} and {positions[0].dtype}." - ) - if charges[0].device != positions[0].device: - raise ValueError( - "`charges` must be on the same device as `positions`, got " - f"{charges[0].device} and {positions[0].device}." - ) - # We don't require and test that all dtypes and devices are consistent if a list - # of inputs. Each "frame" is processed independently. - potentials = [] - for positions_single, cell_single, charges_single in zip( - positions, cell, charges - ): - # Compute the potentials - potentials.append( - self._compute_single_system( - positions=positions_single, charges=charges_single, cell=cell_single - ) - ) - - if len(types) == 1: - return potentials[0] - else: - return potentials - - def _get_requested_types(self, types: List[torch.Tensor]) -> List[int]: - """Extract a list of all unique and present types from the list of types.""" - all_types = torch.hstack(types) - types_requested = _1d_tolist(torch.unique(all_types)) - - if self.all_types is not None: - if not _is_subset(types_requested, self.all_types): - raise ValueError( - f"Global list of types {self.all_types} does not contain all " - f"types for the provided systems {types_requested}." - ) - return self.all_types - else: - return types_requested - - def _one_hot_charges( - self, - types: torch.Tensor, - requested_types: List[int], - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - ) -> torch.Tensor: - n_types = len(requested_types) - one_hot_charges = torch.zeros((len(types), n_types), dtype=dtype, device=device) - - for i_type, atomic_type in enumerate(requested_types): - one_hot_charges[types == atomic_type, i_type] = 1.0 - - return one_hot_charges - - def _compute_single_system( - self, - positions: torch.Tensor, - charges: torch.Tensor, - cell: torch.Tensor, - ) -> torch.Tensor: - """ - Compute the "electrostatic" potential at the position of all atoms in a - structure. - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. For standard LODE - that treats all (atomic) types separately, one example could be: If n_atoms - = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use - the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for - the charges. This would then separately compute the "Na" potential and "Cl" - potential. Subtracting these from each other, one could recover the more - standard electrostatic potential in which Na and Cl have charges of +1 and - -1, respectively. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ - # Initializations - n_atoms = len(positions) - assert positions.shape == (n_atoms, 3) - assert charges.shape[0] == n_atoms - - assert positions.dtype == cell.dtype and charges.dtype == cell.dtype - assert positions.device == cell.device and charges.device == cell.device - - # Define cutoff in reciprocal space - k_cutoff = 2 * torch.pi / self.mesh_spacing - - # Compute number of times each basis vector of the - # reciprocal space can be scaled until the cutoff - # is reached - basis_norms = torch.linalg.norm(cell, dim=1) - ns_approx = k_cutoff * basis_norms / 2 / torch.pi - ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points - ns = 2 ** torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz] - - # Step 1: Smear particles onto mesh - MI = MeshInterpolator(cell, ns, interpolation_order=self.interpolation_order) - MI.compute_interpolation_weights(positions) - rho_mesh = MI.points_to_mesh(particle_weights=charges) - - # Step 2: Perform Fourier space convolution (FSC) - potential_mesh = self.fourier_space_convolution.compute( - mesh_values=rho_mesh, - cell=cell, - potential_exponent=1, - atomic_smearing=self.atomic_smearing, - ) - - # Step 3: Back interpolation - interpolated_potential = MI.mesh_to_points(potential_mesh) - - # Remove self contribution - if self.subtract_self: - self_contrib = ( - torch.sqrt( - torch.tensor( - 2.0 / torch.pi, dtype=positions.dtype, device=positions.device - ), - ) - / self.atomic_smearing - ) - interpolated_potential -= charges * self_contrib - - return interpolated_potential From be4b768543820b3cd0552f22e9b1172298cd49a6 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Tue, 11 Jun 2024 13:39:27 +0200 Subject: [PATCH 06/35] Add tests for output accuracy of new calculators --- tests/calculators/test_values_aperiodic.py | 154 +++++++++++ tests/calculators/test_values_periodic.py | 299 +++++++++++++++++++++ 2 files changed, 453 insertions(+) create mode 100644 tests/calculators/test_values_aperiodic.py create mode 100644 tests/calculators/test_values_periodic.py diff --git a/tests/calculators/test_values_aperiodic.py b/tests/calculators/test_values_aperiodic.py new file mode 100644 index 00000000..d119888e --- /dev/null +++ b/tests/calculators/test_values_aperiodic.py @@ -0,0 +1,154 @@ +import torch +import math +import pytest +from meshlode import DirectPotential + +def define_molecule(molecule_name = 'dimer'): + """ + Define simple "molecules" (collection of point charges) for which the exact Coulomb + potential is easy to evaluate. The implementations in the main code are then tested + against these structures. + """ + # Use a higher precision than the default float32 + dtype = torch.float64 + SQRT2 = torch.sqrt(torch.tensor(2, dtype=dtype)) + SQRT3 = torch.sqrt(torch.tensor(3, dtype=dtype)) + + # Start defining molecules + # Dimer + if molecule_name == 'dimer': + types = torch.tensor([1,1]) + positions = torch.tensor([[0.,0,0],[0,0,1.]], dtype=dtype) + charges = torch.tensor([1.,-1.], dtype=dtype) + potentials = torch.tensor([-1.,1], dtype=dtype) + + elif molecule_name == 'dimer_positive': + types, positions, charges, potentials = define_molecule('dimer') + charges = torch.tensor([1.,1], dtype=dtype) + potentials = torch.tensor([1.,1], dtype=dtype) + + elif molecule_name == 'dimer_negative': + types, positions, charges, potentials = define_molecule('dimer_positive') + charges *= -1. + potentials *= -1. + + # Equilateral triangle + elif molecule_name == 'triangle': + types = torch.tensor([1,1,1]) + positions = torch.tensor([[0.,0,0],[1,0,0],[1/2,SQRT3/2,0]], dtype=dtype) + charges = torch.tensor([1.,-1.,0.], dtype=dtype) + potentials = torch.tensor([-1.,1,0], dtype=dtype) + + elif molecule_name == 'triangle_positive': + types, positions, charges, potentials = define_molecule('triangle') + charges = torch.tensor([1.,1,1], dtype=dtype) + potentials = torch.tensor([2.,2,2], dtype=dtype) + + elif molecule_name == 'triangle_negative': + types, positions, charges, potentials = define_molecule('triangle_positive') + charges *= -1. + potentials *= -1. + + # Squares (planar) + elif molecule_name == 'square': + types = torch.tensor([1,1,1,1]) + positions = torch.tensor([[1,1,0],[1,-1,0],[-1,1,0],[-1,-1,0]], dtype=dtype) + positions /= 2. + charges = torch.tensor([1.,-1,-1,1], dtype=dtype) + potentials = charges * (1./SQRT2 - 2.) + + elif molecule_name == 'square_positive': + types, positions, charges, potentials = define_molecule('square') + charges = torch.tensor([1.,1,1,1], dtype=dtype) + potentials = (2. + 1./SQRT2) * torch.ones(4, dtype=dtype) + + elif molecule_name == 'square_negative': + types, positions, charges, potentials = define_molecule('square_positive') + charges *= -1. + potentials *= -1. + + # Tetrahedra + elif molecule_name == 'tetrahedron': + types = torch.tensor([1,1,1,1]) + positions = torch.tensor([[0.,0,0],[1,0,0],[1/2,SQRT3/2,0],[1/2,SQRT3/6,SQRT2/SQRT3]], dtype=dtype) + charges = torch.tensor([1.,-1,1,-1], dtype=dtype) + potentials = -charges + + elif molecule_name == 'tetrahedron_positive': + types, positions, charges, potentials = define_molecule('tetrahedron') + charges = torch.ones(4, dtype=dtype) + potentials = 3 * torch.ones(4, dtype=dtype) + + elif molecule_name == 'tetrahedron_negative': + types, positions, charges, potentials = define_molecule('tetrahedron_positive') + charges *= -1. + potentials *= -1. + + return types, positions, charges, potentials + + +def generate_orthogonal_transformations(): + dtype = torch.float64 + + # first rotation matrix: identity + rot_1 = torch.eye(3, dtype=dtype) + + # second rotation matrix: rotation by angle phi around z-axis + phi = 0.82321 + rot_2 = torch.zeros((3,3), dtype=dtype) + rot_2[0,0] = rot_2[1,1] = math.cos(phi) + rot_2[0,1] = -math.sin(phi) + rot_2[1,0] = math.sin(phi) + rot_2[2,2] = 1. + + # third rotation matrix: second matrix followed by rotation by angle theta around y + theta = 1.23456 + rot_3 = torch.zeros((3,3), dtype=dtype) + rot_3[0,0] = rot_3[2,2] = math.cos(theta) + rot_3[0,2] = math.sin(theta) + rot_3[2,0] = -math.sin(theta) + rot_3[1,1] = 1. + rot_3 = rot_3 @ rot_2 + + # add additional orthogonal transformations by combining inversion + transformations = [rot_1, rot_2, rot_3, -rot_1, -rot_3] + + for q in transformations: + id = torch.eye(3, dtype=dtype) + id_2 = q.T @ q + torch.testing.assert_close(id, id_2, atol=2e-15, rtol=1e-14) + return transformations + + + +molecules = ['dimer', 'triangle', 'square', 'tetrahedron'] +molecule_charges = ['', '_positive', '_negative'] +scaling_factors = torch.tensor([0.079, 1., 5.54], dtype=torch.float64) +orthogonal_transformations = generate_orthogonal_transformations() +@pytest.mark.parametrize("molecule", molecules) +@pytest.mark.parametrize("molecule_charge", molecule_charges) +@pytest.mark.parametrize("scaling_factor", scaling_factors) +@pytest.mark.parametrize("orthogonal_transformation", orthogonal_transformations) +def test_coulomb_exact(molecule, + molecule_charge, + scaling_factor, + orthogonal_transformation): + """ + Check that the Coulomb potentials obtained from the calculators match the correct + value for simple toy systems. + To make the test stricter, the molecules are also rotated and scaled by varying + amounts, the former of which leaving the potentials invariant, while the second + operation scales the potentials by the inverse amount. + """ + # Call Ewald potential class without specifying any of the convergence parameters + # so that they are chosen by default (in a structure-dependent way) + DP = DirectPotential() + + # Compute potential at the position of the atoms for the specified structure + molecule_name = molecule + molecule_charge + types, positions, charges, ref_potentials = define_molecule(molecule_name) + positions = scaling_factor * (positions @ orthogonal_transformation) + potentials = DP.compute(types, positions, charges=charges) + ref_potentials /= scaling_factor + + torch.testing.assert_close(potentials, ref_potentials, atol=2e-15, rtol=1e-14) \ No newline at end of file diff --git a/tests/calculators/test_values_periodic.py b/tests/calculators/test_values_periodic.py new file mode 100644 index 00000000..1c24f625 --- /dev/null +++ b/tests/calculators/test_values_periodic.py @@ -0,0 +1,299 @@ +import numpy as np +import pytest +import torch + +from meshlode import EwaldPotential + + +def define_crystal(crystal_name="CsCl"): + # Define all relevant parameters (atom positions, charges, cell) of the reference + # crystal structures for which the Madelung constants obtained from the Ewald sums + # are compared with reference values. + # see https://www.sciencedirect.com/science/article/pii/B9780128143698000078#s0015 + # More detailed values can be found in https://pubs.acs.org/doi/10.1021/ic2023852 + dtype = torch.float64 + + # Caesium-Chloride (CsCl) structure: + # - Cubic unit cell + # - 1 atom pair in the unit cell + # - Cation-Anion ratio of 1:1 + if crystal_name == "CsCl": + types = torch.tensor([17, 55]) # Cl and Cs + positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype) + charges = torch.tensor([-1.0, 1.0], dtype=dtype) + cell = torch.eye(3, dtype=dtype) + madelung_reference = 2.035361 + + # Sodium-Chloride (NaCl) structure using a primitive unit cell + # - non-cubic unit cell (fcc) + # - 1 atom pair in the unit cell + # - Cation-Anion ratio of 1:1 + elif crystal_name == "NaCl_primitive": + types = torch.tensor([11, 17]) + positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=dtype) + charges = torch.tensor([1.0, -1.0], dtype=dtype) + cell = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype) # fcc + madelung_reference = 1.74756 + + # Sodium-Chloride (NaCl) structure using a cubic unit cell + # - cubic unit cell + # - 4 atom pairs in the unit cell + # - Cation-Anion ratio of 1:1 + elif crystal_name == "NaCl_cubic": + types = torch.tensor([11, 17, 17, 17, 11, 11, 11, 17]) + positions = torch.tensor( + [ + [0.0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ], + dtype=dtype, + ) + charges = torch.tensor([+1.0, -1, -1, -1, +1, +1, +1, -1], dtype=dtype) + cell = 2 * torch.eye(3, dtype=dtype) + madelung_reference = 1.747565 + + # ZnS (zincblende) structure + # - non-cubic unit cell (fcc) + # - 1 atom pair in the unit cell + # - Cation-Anion ratio of 1:1 + # Remarks: we use a primitive unit cell which makes the lattice parameter of the + # cubic cell equal to 2. + elif crystal_name == "zincblende": + types = torch.tensor([16, 30]) + positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype) + charges = torch.tensor([1.0, -1], dtype=dtype) + cell = torch.tensor([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype) + madelung_reference = 2 * 1.63806 / np.sqrt(3) + + # Wurtzite structure + # - non-cubic unit cell (triclinic) + # - 2 atom pairs in the unit cell + # - Cation-Anion ratio of 1:1 + elif crystal_name == "wurtzite": + u = 3 / 8 + c = np.sqrt(1 / u) + types = torch.tensor([16, 30, 16, 30]) + positions = torch.tensor( + [ + [0.5, 0.5 / np.sqrt(3), 0.0], + [0.5, 0.5 / np.sqrt(3), u * c], + [0.5, -0.5 / np.sqrt(3), 0.5 * c], + [0.5, -0.5 / np.sqrt(3), (0.5 + u) * c], + ], + dtype=dtype, + ) + charges = torch.tensor([1.0, -1, 1, -1], dtype=dtype) + cell = torch.tensor( + [[0.5, -0.5 * np.sqrt(3), 0], [0.5, 0.5 * np.sqrt(3), 0], [0, 0, c]], + dtype=dtype, + ) + madelung_reference = 1.64132 / (u * c) + + # Fluorite structure + # - non-cubic (fcc) unit cell + # - 1 neutral molecule per unit cell + # - Cation-Anion ratio of 2:1 + elif crystal_name == "fluorite": + a = 5.463 + a = 1.0 + types = torch.tensor([9, 9, 20]) + positions = a * torch.tensor( + [[1 / 4, 1 / 4, 1 / 4], [3 / 4, 3 / 4, 3 / 4], [0, 0, 0]], dtype=dtype + ) + charges = torch.tensor([-1, -1, 2], dtype=dtype) + cell = torch.tensor([[a, a, 0], [a, 0, a], [0, a, a]], dtype=dtype) / 2.0 + madelung_reference = 11.636575 + + # Copper-Oxide Cu2O structure + elif crystal_name == "cu2o": + a = 0.4627 + a = 1.0 + types = torch.tensor([8, 29, 29]) + positions = a * torch.tensor( + [[1 / 4, 1 / 4, 1 / 4], [0, 0, 0], [1 / 2, 1 / 2, 1 / 2]], dtype=dtype + ) + charges = torch.tensor([-2, 1, 1], dtype=dtype) + cell = torch.tensor([[a, 0, 0], [0, a, 0], [0, 0, a]], dtype=dtype) + madelung_reference = 10.2594570330750 + + # Wigner crystal in simple cubic structure. + # Wigner crystals are equivalent to the Jellium or uniform electron gas models. + # For the purpose of this test, we define them to be structures in which the ion + # cores form a perfect lattice, while the electrons are uniformly distributed over + # the cell. In some sources, the role of the positive and negative charges are + # flipped. These structures are used to test the code for cases in which the total + # charge of the particles is not zero. + # Wigner crystal energies are taken from "Zero-Point Energy of an Electron Lattice" + # by Rosemary A., Coldwell‐Horsfall and Alexei A. Maradudin (1960), eq. (A21). + elif crystal_name == "wigner_sc": + types = torch.tensor([1]) + positions = torch.tensor([[0, 0, 0]], dtype=dtype) + charges = torch.tensor([1.0], dtype=dtype) + cell = torch.tensor([[1.0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype) + + # Reference value is expressed in terms of the Wigner-Seiz radius, and needs to + # be rescaled to the case in which the lattice parameter = 1. + madelung_wigner_seiz = 1.7601188 + wigner_seiz_radius = (3 / (4 * np.pi)) ** (1 / 3) + madelung_reference = madelung_wigner_seiz / wigner_seiz_radius # 2.83730 + + # Wigner crystal in bcc structure (note: this is the most stable structure). + # See description of "wigner_sc" for a general explanation on Wigner crystals. + # Used to test the code for cases in which the unit cell has a nonzero net charge. + elif crystal_name == "wigner_bcc": + types = torch.tensor([1]) + positions = torch.tensor([[0, 0, 0]], dtype=dtype) + charges = torch.tensor([1.0], dtype=dtype) + cell = torch.tensor( + [[1.0, 0, 0], [0, 1, 0], [1 / 2, 1 / 2, 1 / 2]], dtype=dtype + ) + + # Reference value is expressed in terms of the Wigner-Seiz radius, and needs to + # be rescaled to the case in which the lattice parameter = 1. + madelung_wigner_seiz = 1.791860 + wigner_seiz_radius = (3 / (4 * np.pi * 2)) ** ( + 1 / 3 + ) # 2 atoms per cubic unit cell + madelung_reference = madelung_wigner_seiz / wigner_seiz_radius # 3.63924 + + # Same as above, but now using a cubic unit cell rather than the primitive bcc cell + elif crystal_name == "wigner_bcc_cubiccell": + types = torch.tensor([1, 1]) + positions = torch.tensor([[0, 0, 0], [1 / 2, 1 / 2, 1 / 2]], dtype=dtype) + charges = torch.tensor([1.0, 1.0], dtype=dtype) + cell = torch.tensor([[1.0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype) + + # Reference value is expressed in terms of the Wigner-Seiz radius, and needs to + # be rescaled to the case in which the lattice parameter = 1. + madelung_wigner_seiz = 1.791860 + wigner_seiz_radius = (3 / (4 * np.pi * 2)) ** ( + 1 / 3 + ) # 2 atoms per cubic unit cell + madelung_reference = madelung_wigner_seiz / wigner_seiz_radius # 3.63924 + + # Wigner crystal in fcc structure + # See description of "wigner_sc" for a general explanation on Wigner crystals. + # Used to test the code for cases in which the unit cell has a nonzero net charge. + elif crystal_name == "wigner_fcc": + types = torch.tensor([1]) + positions = torch.tensor([[0.0, 0, 0]], dtype=dtype) + charges = torch.tensor([1.0], dtype=dtype) + cell = torch.tensor([[1, 0, 1], [0, 1, 1], [1, 1, 0]], dtype=dtype) / 2 + + # Reference value is expressed in terms of the Wigner-Seiz radius, and needs to + # be rescaled to the case in which the lattice parameter = 1. + madelung_wigner_seiz = 1.791753 + wigner_seiz_radius = (3 / (4 * np.pi * 4)) ** ( + 1 / 3 + ) # 4 atoms per cubic unit cell + madelung_reference = madelung_wigner_seiz / wigner_seiz_radius # 4.58488 + + # Same as above, but now using a cubic unit cell rather than the primitive fcc cell + elif crystal_name == "wigner_fcc_cubiccell": + types = torch.tensor([1, 1, 1, 1]) + positions = 0.5 * torch.tensor( + [[0.0, 0, 0], [1, 0, 1], [1, 1, 0], [0, 1, 1]], dtype=dtype + ) + charges = torch.tensor([1.0, 1, 1, 1], dtype=dtype) + cell = torch.eye(3, dtype=dtype) + + # Reference value is expressed in terms of the Wigner-Seiz radius, and needs to + # be rescaled to the case in which the lattice parameter = 1. + madelung_wigner_seiz = 1.791753 + wigner_seiz_radius = (3 / (4 * np.pi * 4)) ** ( + 1 / 3 + ) # 4 atoms per cubic unit cell + madelung_reference = madelung_wigner_seiz / wigner_seiz_radius # 4.58488 + + else: + raise ValueError(f"crystal_name = {crystal_name} is not supported!") + + return types, positions, charges, cell, madelung_reference + + +neutral_crystals = ["CsCl", "NaCl_primitive", "NaCl_cubic", "zincblende", "wurtzite"] +# neutral_crystals = ['CsCl'] +scaling_factors = torch.tensor([1 / 2.0353610, 1.0, 3.4951291], dtype=torch.float64) +@pytest.mark.parametrize("crystal_name", neutral_crystals) +@pytest.mark.parametrize("scaling_factor", scaling_factors) +def test_madelung(crystal_name, scaling_factor): + """ + Check that the Madelung constants obtained from the Ewald sum calculator matches + the reference values. + In this test, only the charge-neutral crystal systems are chosen for which the + potential converges relatively quickly, while the systems with a net charge are + treated separately below. + """ + # Call Ewald potential class without specifying any of the convergence parameters + # so that they are chosen by default (in a structure-dependent way) + EP = EwaldPotential() + + # Compute potential at the position of the atoms for the specified structure + types, positions, charges, cell, madelung_reference = define_crystal(crystal_name) + positions *= scaling_factor + cell *= scaling_factor + potentials = EP.compute(types, positions, cell, charges) + energies = potentials * charges + energies_ref = -torch.ones_like(energies) * madelung_reference / scaling_factor + + torch.testing.assert_close(energies, energies_ref, atol=0.0, rtol=3.1e-6) + + +wigner_crystals = [ + "wigner_sc", + "wigner_fcc", + "wigner_fcc_cubiccell", + "wigner_bcc", + "wigner_bcc_cubiccell", +] +wigner_crystal = ['wigner_sc'] +scaling_factors = torch.tensor([0.4325, 1.0, 2.0353610], dtype=torch.float64) + + +@pytest.mark.parametrize("crystal_name", wigner_crystals) +@pytest.mark.parametrize("scaling_factor", scaling_factors) +def test_wigner(crystal_name, scaling_factor): + """ + Check that the energy of a Wigner solid obtained from the Ewald sum calculator + matches the reference values. + In this test, the Wigner solids are defined by placing arranging positively charged + point particles on a bcc lattice, leading to a net charge of the unit cell if we + only look at the ions. This charge is compensated by a homogeneous neutral back- + ground charge of opposite sign (physically: completely delocalized electrons). + + The presence of a net charge (due to the particles but without background) leads + to numerically slower convergence of the relevant sums. + """ + # Get parameters defining atomic positions, cell and charges + types, positions, charges, cell, madelung_reference = define_crystal(crystal_name) + positions *= scaling_factor + cell *= scaling_factor + madelung_reference /= scaling_factor + + # Due to the slow convergence, we do not use the default values of the smearing, + # but provide a range instead. The first value of 0.1 corresponds to what would be + # chosen by default for the "wigner_sc" or "wigner_bcc_cubiccell" structure. + smearings = torch.tensor([0.1, 0.06, 0.019], dtype=torch.float64) + tolerances = torch.tensor([3e-2, 1e-2, 1e-3]) + for smearing, rtol in zip(smearings, tolerances): + # Readjust smearing parameter to match nearest neighbor distance + if crystal_name in ["wigner_fcc", "wigner_fcc_cubiccell"]: + smeareff = smearing / np.sqrt(2) + elif crystal_name in ["wigner_bcc_cubiccell", "wigner_bcc"]: + smeareff = smearing * np.sqrt(3) / 2 + elif crystal_name == "wigner_sc": + smeareff = smearing + smeareff *= scaling_factor + + # Compute potential and compare against reference + EP = EwaldPotential(atomic_smearing=smeareff) + potentials = EP.compute(types, positions, cell, charges) + energies = potentials * charges + energies_ref = -torch.ones_like(energies) * madelung_reference + torch.testing.assert_close(energies, energies_ref, atol=0.0, rtol=rtol) From f492e1cf82664c163eab3c2083d93bf8b89164a2 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Tue, 11 Jun 2024 13:41:06 +0200 Subject: [PATCH 07/35] Add tests to ensure smooth workflow for new calculators --- tests/__init__.py | 5 + tests/calculators/test_workflow_direct.py | 230 +++++++++++++ tests/calculators/test_workflow_ewald.py | 298 +++++++++++++++++ ...meshpotential.py => test_workflow_mesh.py} | 11 +- tests/calculators/test_workflow_meshewald.py | 309 ++++++++++++++++++ 5 files changed, 848 insertions(+), 5 deletions(-) create mode 100644 tests/calculators/test_workflow_direct.py create mode 100644 tests/calculators/test_workflow_ewald.py rename tests/calculators/{test_meshpotential.py => test_workflow_mesh.py} (97%) create mode 100644 tests/calculators/test_workflow_meshewald.py diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..1c2cd789 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +import meshlode + + +def test_version_exist(): + meshlode.__version__ diff --git a/tests/calculators/test_workflow_direct.py b/tests/calculators/test_workflow_direct.py new file mode 100644 index 00000000..b17ace1c --- /dev/null +++ b/tests/calculators/test_workflow_direct.py @@ -0,0 +1,230 @@ +"""Basic tests if the calculator works and is torch scriptable. Actual tests are done +for the metatensor calculator.""" + +import math + +import pytest +import torch +from torch.testing import assert_close + +from meshlode import DirectPotential +from meshlode.calculators.calculator_base import _1d_tolist, _is_subset + + +# MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) # periodic case +MADELUNG_CSCL = torch.tensor(2 * math.sqrt(3)) +CHARGES_CSCL = torch.tensor([1.0, -1.0]) + + +def cscl_system(): + """CsCl crystal. Same as in the madelung test""" + types = torch.tensor([55, 17]) + positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + cell = torch.eye(3) + + return types, positions + + +def cscl_system_with_charges(): + """CsCl crystal with charges.""" + charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) + return cscl_system() + (charges,) + + +# Initialize the calculators. For now, only the DirectPotential is implemented. +def descriptor() -> DirectPotential: + return DirectPotential( + ) + + +def test_forward(): + mp = descriptor() + descriptor_compute = mp.compute(*cscl_system()) + descriptor_forward = mp.forward(*cscl_system()) + + assert torch.equal(descriptor_forward, descriptor_compute) + + +def test_all_types(): + descriptor = DirectPotential(all_types=[8, 55, 17]) + values = descriptor.compute(*cscl_system()) + + assert values.shape == (2, 3) + assert torch.equal(values[:, 0], torch.zeros(2)) + + +def test_all_types_error(): + descriptor = DirectPotential(all_types=[17]) + with pytest.raises(ValueError, match="Global list of types"): + descriptor.compute(*cscl_system()) + + +# Make sure that the calculators are computing the features without raising errors, +# and returns the correct output format (TensorMap) +def check_operation(calculator): + descriptor = calculator.compute(*cscl_system()) + assert type(descriptor) is torch.Tensor + + +# Run the above test as a normal python script +def test_operation_as_python(): + check_operation(descriptor()) + + +# Similar to the above, but also testing that the code can be compiled as a torch script +def test_operation_as_torch_script(): + scripted = torch.jit.script(descriptor()) + check_operation(scripted) + + + +def test_single_frame(): + values = descriptor().compute(*cscl_system()) + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + +# Test with explicit charges +def test_single_frame_with_charges(): + values = descriptor().compute(*cscl_system_with_charges()) + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + +def test_multi_frame(): + types, positions = cscl_system() + l_values = descriptor().compute( + types=[types, types], positions=[positions, positions]) + for values in l_values: + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + +def test_types_error(): + types = torch.tensor([[1, 2], [3, 4]]) # This is a 2D tensor, should be 1D + positions = torch.zeros((2, 3)) + + match = ( + "each `types` must be a 1 dimensional tensor, got at least one tensor with " + "2 dimensions" + ) + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions) + +def test_positions_error(): + types = torch.tensor([1, 2]) + positions = torch.zeros( + (1, 3) + ) # This should have the same first dimension as types + + match = ( + "each `positions` must be a \\(n_types x 3\\) tensor, got at least " + "one tensor with shape \\[1, 3\\]" + ) + + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions) + + + +def test_charges_error_dimension_mismatch(): + types = torch.tensor([1, 2]) + positions = torch.zeros((2, 3)) + cell = torch.eye(3) + charges = torch.zeros((1, 2)) # This should have the same first dimension as types + + match = ( + "The first dimension of `charges` must be the same as the length " + "of `types`, got 1 and 2." + ) + + with pytest.raises(ValueError, match=match): + descriptor().compute( + types=types, positions=positions, charges=charges + ) + + +def test_charges_error_length_mismatch(): + types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] + positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] + cell = torch.eye(3) + charges = [torch.zeros(2, 1)] # This should have the same length as types + match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." + + with pytest.raises(ValueError, match=match): + descriptor().compute( + types=types, positions=positions, charges=charges + ) + + +def test_dtype_device(): + """Test that the output dtype and device are the same as the input.""" + device = "cpu" + dtype = torch.float64 + + types = torch.tensor([1], dtype=dtype, device=device) + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) + + DP = DirectPotential() + potential = DP.compute(types=types, positions=positions) + + assert potential.dtype == dtype + assert potential.device.type == device + +def test_inconsistent_device_charges(): + """Test if the chages and positions have inconsistent device and error is raised.""" + types = torch.tensor([1], device="cpu") + positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") + charges = torch.tensor([0.0], device="meta") # different device + + DP = DirectPotential() + + match = "`charges` must be on the same device as `positions`, got meta and cpu." + with pytest.raises(ValueError, match=match): + DP.compute(types=types, positions=positions, charges=charges) + + +def test_inconsistent_dtype_charges(): + """Test if the charges and positions have inconsistent dtype and error is raised.""" + types = torch.tensor([1], dtype=torch.float32) + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32) + charges = torch.tensor([0.0], dtype=torch.float64) # Different dtype + + DP = DirectPotential() + + match = ( + "`charges` must be have the same dtype as `positions`, got torch.float64 and " + "torch.float32" + ) + with pytest.raises(ValueError, match=match): + DP.compute(types=types, positions=positions, charges=charges) + + +def test_1d_tolist(): + in_list = [1, 2, 7, 3, 4, 42] + in_tensor = torch.tensor(in_list) + assert _1d_tolist(in_tensor) == in_list + + +def test_is_subset_true(): + subset_candidate = [1, 2] + superset = [1, 2, 3, 4, 5] + assert _is_subset(subset_candidate, superset) + + +def test_is_subset_false(): + subset_candidate = [1, 2, 8] + superset = [1, 2, 3, 4, 5] + assert not _is_subset(subset_candidate, superset) diff --git a/tests/calculators/test_workflow_ewald.py b/tests/calculators/test_workflow_ewald.py new file mode 100644 index 00000000..79891d1d --- /dev/null +++ b/tests/calculators/test_workflow_ewald.py @@ -0,0 +1,298 @@ +"""Basic tests if the calculator works and is torch scriptable. Actual tests are done +for the metatensor calculator.""" + +import math + +import pytest +import torch +from torch.testing import assert_close + +from meshlode import EwaldPotential +from meshlode.calculators.calculator_base import _1d_tolist, _is_subset + + +MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) +CHARGES_CSCL = torch.tensor([1.0, -1.0]) + + +def cscl_system(): + """CsCl crystal. Same as in the madelung test""" + types = torch.tensor([55, 17]) + positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + cell = torch.eye(3) + + return types, positions, cell + + +def cscl_system_with_charges(): + """CsCl crystal with charges.""" + charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) + return cscl_system() + (charges,) + + +# Initialize the calculators. For now, only the EwaldPotential is implemented. +def descriptor() -> EwaldPotential: + atomic_smearing = 0.1 + return EwaldPotential( + atomic_smearing=atomic_smearing, + lr_wavelength=atomic_smearing / 4, + subtract_self=True, + ) + + +def test_forward(): + mp = descriptor() + descriptor_compute = mp.compute(*cscl_system()) + descriptor_forward = mp.forward(*cscl_system()) + + assert torch.equal(descriptor_forward, descriptor_compute) + + +def test_all_types(): + descriptor = EwaldPotential(atomic_smearing=0.1, all_types=[8, 55, 17]) + values = descriptor.compute(*cscl_system()) + + assert values.shape == (2, 3) + assert torch.equal(values[:, 0], torch.zeros(2)) + + +def test_all_types_error(): + descriptor = EwaldPotential(atomic_smearing=0.1, all_types=[17]) + with pytest.raises(ValueError, match="Global list of types"): + descriptor.compute(*cscl_system()) + + +# Make sure that the calculators are computing the features without raising errors, +# and returns the correct output format (TensorMap) +def check_operation(calculator): + descriptor = calculator.compute(*cscl_system()) + assert type(descriptor) is torch.Tensor + + +# Run the above test as a normal python script +def test_operation_as_python(): + check_operation(descriptor()) + + +""" +# Similar to the above, but also testing that the code can be compiled as a torch script +def test_operation_as_torch_script(): + scripted = torch.jit.script(descriptor()) + check_operation(scripted) +""" + + +def test_single_frame(): + values = descriptor().compute(*cscl_system()) + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + +# Test with explicit charges +def test_single_frame_with_charges(): + values = descriptor().compute(*cscl_system_with_charges()) + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + +def test_multi_frame(): + types, positions, cell = cscl_system() + l_values = descriptor().compute( + types=[types, types], positions=[positions, positions], cell=[cell, cell] + ) + for values in l_values: + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + +def test_types_error(): + types = torch.tensor([[1, 2], [3, 4]]) # This is a 2D tensor, should be 1D + positions = torch.zeros((2, 3)) + cell = torch.eye(3) + + match = ( + "each `types` must be a 1 dimensional tensor, got at least one tensor with " + "2 dimensions" + ) + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_positions_error(): + types = torch.tensor([1, 2]) + positions = torch.zeros( + (1, 3) + ) # This should have the same first dimension as types + cell = torch.eye(3) + + match = ( + "each `positions` must be a \\(n_types x 3\\) tensor, got at least " + "one tensor with shape \\[1, 3\\]" + ) + + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_charges_error_dimension_mismatch(): + types = torch.tensor([1, 2]) + positions = torch.zeros((2, 3)) + cell = torch.eye(3) + charges = torch.zeros((1, 2)) # This should have the same first dimension as types + + match = ( + "The first dimension of `charges` must be the same as the length " + "of `types`, got 1 and 2." + ) + + with pytest.raises(ValueError, match=match): + descriptor().compute( + types=types, positions=positions, cell=cell, charges=charges + ) + + +def test_charges_error_length_mismatch(): + types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] + positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] + cell = torch.eye(3) + charges = [torch.zeros(2, 1)] # This should have the same length as types + match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." + + with pytest.raises(ValueError, match=match): + descriptor().compute( + types=types, positions=positions, cell=cell, charges=charges + ) + + +def test_cell_error(): + types = torch.tensor([1, 2, 3]) + positions = torch.zeros((3, 3)) + cell = torch.eye(2) # This is a 2x2 tensor, should be 3x3 + + match = ( + "each `cell` must be a \\(3 x 3\\) tensor, got at least one tensor " + "with shape \\[2, 2\\]" + ) + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_positions_cell_dtype_error(): + types = torch.tensor([1, 2, 3]) + positions = torch.zeros((3, 3), dtype=torch.float32) + cell = torch.eye(3, dtype=torch.float64) + + match = ( + "`cell` must be have the same dtype as `positions`, got torch.float64 " + "and torch.float32" + ) + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_dtype_device(): + """Test that the output dtype and device are the same as the input.""" + device = "cpu" + dtype = torch.float64 + + types = torch.tensor([1], dtype=dtype, device=device) + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) + cell = torch.eye(3, dtype=dtype, device=device) + + EP = EwaldPotential(atomic_smearing=0.2) + potential = EP.compute(types=types, positions=positions, cell=cell) + + assert potential.dtype == dtype + assert potential.device.type == device + + +def test_inconsistent_dtype(): + """Test if the cell and positions have inconsistent dtype and error is raised.""" + types = torch.tensor([1], dtype=torch.float32) + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64) # Different dtype + cell = torch.eye(3, dtype=torch.float32) + + EP = EwaldPotential(atomic_smearing=0.2) + + match = ( + "`cell` must be have the same dtype as `positions`, got torch.float32 and " + "torch.float64" + ) + with pytest.raises(ValueError, match=match): + EP.compute(types=types, positions=positions, cell=cell) + + +def test_inconsistent_device(): + """Test if the cell and positions have inconsistent device and error is raised.""" + types = torch.tensor([1], device="cpu") + positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") + cell = torch.eye(3, device="meta") # different device + + EP = EwaldPotential(atomic_smearing=0.2) + + match = ( + '`types`, `positions`, and `cell` must be on the same device, got cpu, cpu and meta.' + ) + with pytest.raises(ValueError, match=match): + EP.compute(types=types, positions=positions, cell=cell) + + +def test_inconsistent_device_charges(): + """Test if the cell and positions have inconsistent device and error is raised.""" + types = torch.tensor([1], device="cpu") + positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") + cell = torch.eye(3, device="cpu") + charges = torch.tensor([0.0], device="meta") # different device + + EP = EwaldPotential(atomic_smearing=0.2) + + match = "`charges` must be on the same device as `positions`, got meta and cpu." + with pytest.raises(ValueError, match=match): + EP.compute(types=types, positions=positions, cell=cell, charges=charges) + + +def test_inconsistent_dtype_charges(): + """Test if the cell and positions have inconsistent dtype and error is raised.""" + types = torch.tensor([1], dtype=torch.float32) + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32) + cell = torch.eye(3, dtype=torch.float32) + charges = torch.tensor([0.0], dtype=torch.float64) # Different dtype + + EP = EwaldPotential(atomic_smearing=0.2) + + match = ( + "`charges` must be have the same dtype as `positions`, got torch.float64 and " + "torch.float32" + ) + with pytest.raises(ValueError, match=match): + EP.compute(types=types, positions=positions, cell=cell, charges=charges) + + +def test_1d_tolist(): + in_list = [1, 2, 7, 3, 4, 42] + in_tensor = torch.tensor(in_list) + assert _1d_tolist(in_tensor) == in_list + + +def test_is_subset_true(): + subset_candidate = [1, 2] + superset = [1, 2, 3, 4, 5] + assert _is_subset(subset_candidate, superset) + + +def test_is_subset_false(): + subset_candidate = [1, 2, 8] + superset = [1, 2, 3, 4, 5] + assert not _is_subset(subset_candidate, superset) diff --git a/tests/calculators/test_meshpotential.py b/tests/calculators/test_workflow_mesh.py similarity index 97% rename from tests/calculators/test_meshpotential.py rename to tests/calculators/test_workflow_mesh.py index 58e73808..dec15f24 100644 --- a/tests/calculators/test_meshpotential.py +++ b/tests/calculators/test_workflow_mesh.py @@ -8,7 +8,7 @@ from torch.testing import assert_close from meshlode import MeshPotential -from meshlode.calculators.meshpotential import _1d_tolist, _is_subset +from meshlode.calculators.calculator_base import _1d_tolist, _is_subset MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) @@ -75,7 +75,9 @@ def test_all_types_error(): # Make sure that the calculators are computing the features without raising errors, # and returns the correct output format (TensorMap) def check_operation(calculator): - descriptor = calculator.compute(*cscl_system()) + types, pos, cell = cscl_system() + print(cell) + descriptor = calculator.compute(types=types, positions=pos, cell=cell) assert type(descriptor) is torch.Tensor @@ -85,6 +87,7 @@ def test_operation_as_python(): # Similar to the above, but also testing that the code can be compiled as a torch script + def test_operation_as_torch_script(): scripted = torch.jit.script(descriptor()) check_operation(scripted) @@ -241,7 +244,6 @@ def test_inconsistent_dtype(): with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell) - def test_inconsistent_device(): """Test if the cell and positions have inconsistent device and error is raised.""" types = torch.tensor([1], device="cpu") @@ -251,8 +253,7 @@ def test_inconsistent_device(): MP = MeshPotential(atomic_smearing=0.2) match = ( - "`types`, `positions`, and `cell` must be on the same device, got cpu, " - "cpu and meta." + '`types`, `positions`, and `cell` must be on the same device, got cpu, cpu and meta.' ) with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell) diff --git a/tests/calculators/test_workflow_meshewald.py b/tests/calculators/test_workflow_meshewald.py new file mode 100644 index 00000000..a5d73ef2 --- /dev/null +++ b/tests/calculators/test_workflow_meshewald.py @@ -0,0 +1,309 @@ +"""Basic tests if the calculator works and is torch scriptable. Actual tests are done +for the metatensor calculator.""" + +import math + +import pytest +import torch +from torch.testing import assert_close + +from meshlode import MeshPotential, MeshEwaldPotential +from meshlode.calculators.calculator_base import _1d_tolist, _is_subset + + +MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) +CHARGES_CSCL = torch.tensor([1.0, -1.0]) + + +def cscl_system(): + """CsCl crystal. Same as in the madelung test""" + types = torch.tensor([55, 17]) + positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + cell = torch.eye(3) + + return types, positions, cell + + +def cscl_system_with_charges(): + """CsCl crystal with charges.""" + charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) + return cscl_system() + (charges,) + + +# Initialize the calculators. For now, only the MeshPotential is implemented. +def descriptor() -> MeshEwaldPotential: + atomic_smearing = 0.1 + return MeshEwaldPotential( + atomic_smearing=atomic_smearing, + mesh_spacing=atomic_smearing / 4, + interpolation_order=2, + subtract_self=True, + ) + + +def test_forward(): + mp = descriptor() + descriptor_compute = mp.compute(*cscl_system()) + descriptor_forward = mp.forward(*cscl_system()) + + assert torch.equal(descriptor_forward, descriptor_compute) + + +def test_atomic_smearing_error(): + with pytest.raises(ValueError, match="has to be positive"): + MeshEwaldPotential(atomic_smearing=-1.0) + + +def test_interpolation_order_error(): + with pytest.raises(ValueError, match="Only `interpolation_order` from 1 to 5"): + MeshEwaldPotential(atomic_smearing=1, interpolation_order=10) + + +def test_all_types(): + descriptor = MeshPotential(atomic_smearing=0.1, all_types=[8, 55, 17]) + values = descriptor.compute(*cscl_system()) + assert values.shape == (2, 3) + assert torch.equal(values[:, 0], torch.zeros(2)) + + +def test_all_types_error(): + descriptor = MeshPotential(atomic_smearing=0.1, all_types=[17]) + with pytest.raises(ValueError, match="Global list of types"): + descriptor.compute(*cscl_system()) + + +# Make sure that the calculators are computing the features without raising errors, +# and returns the correct output format (TensorMap) +def check_operation(calculator): + types, pos, cell = cscl_system() + print(cell) + descriptor = calculator.compute(types=types, positions=pos, cell=cell) + assert type(descriptor) is torch.Tensor + + +# Run the above test as a normal python script +def test_operation_as_python(): + check_operation(descriptor()) + +""" +# Similar to the above, but also testing that the code can be compiled as a torch script +# Disabled for now since (1) the ASE neighbor list and (2) the use of the potential +# class are clashing with the torch script capabilities. +def test_operation_as_torch_script(): + scripted = torch.jit.script(descriptor()) + check_operation(scripted) +""" + +def test_single_frame(): + values = descriptor().compute(*cscl_system()) + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + +# Test with explicit charges +def test_single_frame_with_charges(): + values = descriptor().compute(*cscl_system_with_charges()) + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + +def test_multi_frame(): + types, positions, cell = cscl_system() + l_values = descriptor().compute( + types=[types, types], positions=[positions, positions], cell=[cell, cell] + ) + for values in l_values: + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + +def test_types_error(): + types = torch.tensor([[1, 2], [3, 4]]) # This is a 2D tensor, should be 1D + positions = torch.zeros((2, 3)) + cell = torch.eye(3) + + match = ( + "each `types` must be a 1 dimensional tensor, got at least one tensor with " + "2 dimensions" + ) + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_positions_error(): + types = torch.tensor([1, 2]) + positions = torch.zeros( + (1, 3) + ) # This should have the same first dimension as types + cell = torch.eye(3) + + match = ( + "each `positions` must be a \\(n_types x 3\\) tensor, got at least " + "one tensor with shape \\[1, 3\\]" + ) + + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_charges_error_dimension_mismatch(): + types = torch.tensor([1, 2]) + positions = torch.zeros((2, 3)) + cell = torch.eye(3) + charges = torch.zeros((1, 2)) # This should have the same first dimension as types + + match = ( + "The first dimension of `charges` must be the same as the length " + "of `types`, got 1 and 2." + ) + + with pytest.raises(ValueError, match=match): + descriptor().compute( + types=types, positions=positions, cell=cell, charges=charges + ) + + +def test_charges_error_length_mismatch(): + types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] + positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] + cell = torch.eye(3) + charges = [torch.zeros(2, 1)] # This should have the same length as types + match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." + + with pytest.raises(ValueError, match=match): + descriptor().compute( + types=types, positions=positions, cell=cell, charges=charges + ) + + +def test_cell_error(): + types = torch.tensor([1, 2, 3]) + positions = torch.zeros((3, 3)) + cell = torch.eye(2) # This is a 2x2 tensor, should be 3x3 + + match = ( + "each `cell` must be a \\(3 x 3\\) tensor, got at least one tensor " + "with shape \\[2, 2\\]" + ) + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_positions_cell_dtype_error(): + types = torch.tensor([1, 2, 3]) + positions = torch.zeros((3, 3), dtype=torch.float32) + cell = torch.eye(3, dtype=torch.float64) + + match = ( + "`cell` must be have the same dtype as `positions`, got torch.float64 " + "and torch.float32" + ) + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_dtype_device(): + """Test that the output dtype and device are the same as the input.""" + device = "cpu" + dtype = torch.float64 + + types = torch.tensor([1], dtype=dtype, device=device) + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) + cell = torch.eye(3, dtype=dtype, device=device) + + MP = MeshPotential(atomic_smearing=0.2) + potential = MP.compute(types=types, positions=positions, cell=cell) + + assert potential.dtype == dtype + assert potential.device.type == device + + +def test_inconsistent_dtype(): + """Test if the cell and positions have inconsistent dtype and error is raised.""" + types = torch.tensor([1], dtype=torch.float32) + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64) # Different dtype + cell = torch.eye(3, dtype=torch.float32) + + MP = MeshPotential(atomic_smearing=0.2) + + match = ( + "`cell` must be have the same dtype as `positions`, got torch.float32 and " + "torch.float64" + ) + with pytest.raises(ValueError, match=match): + MP.compute(types=types, positions=positions, cell=cell) + +def test_inconsistent_device(): + """Test if the cell and positions have inconsistent device and error is raised.""" + types = torch.tensor([1], device="cpu") + positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") + cell = torch.eye(3, device="meta") # different device + + MP = MeshPotential(atomic_smearing=0.2) + + match = ( + '`types`, `positions`, and `cell` must be on the same device, got cpu, cpu and meta.' + ) + with pytest.raises(ValueError, match=match): + MP.compute(types=types, positions=positions, cell=cell) + + +def test_inconsistent_device_charges(): + """Test if the cell and positions have inconsistent device and error is raised.""" + types = torch.tensor([1], device="cpu") + positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") + cell = torch.eye(3, device="cpu") + charges = torch.tensor([0.0], device="meta") # different device + + MP = MeshPotential(atomic_smearing=0.2) + + match = "`charges` must be on the same device as `positions`, got meta and cpu." + with pytest.raises(ValueError, match=match): + MP.compute(types=types, positions=positions, cell=cell, charges=charges) + + +def test_inconsistent_dtype_charges(): + """Test if the cell and positions have inconsistent dtype and error is raised.""" + types = torch.tensor([1], dtype=torch.float32) + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32) + cell = torch.eye(3, dtype=torch.float32) + charges = torch.tensor([0.0], dtype=torch.float64) # Different dtype + + MP = MeshPotential(atomic_smearing=0.2) + + match = ( + "`charges` must be have the same dtype as `positions`, got torch.float64 and " + "torch.float32" + ) + with pytest.raises(ValueError, match=match): + MP.compute(types=types, positions=positions, cell=cell, charges=charges) + + +def test_1d_tolist(): + in_list = [1, 2, 7, 3, 4, 42] + in_tensor = torch.tensor(in_list) + assert _1d_tolist(in_tensor) == in_list + + +def test_is_subset_true(): + subset_candidate = [1, 2] + superset = [1, 2, 3, 4, 5] + assert _is_subset(subset_candidate, superset) + + +def test_is_subset_false(): + subset_candidate = [1, 2, 8] + superset = [1, 2, 3, 4, 5] + assert not _is_subset(subset_candidate, superset) From e3af13a0a0f7caf8fb6d729fbef25b03c1cd1005 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Tue, 11 Jun 2024 13:45:34 +0200 Subject: [PATCH 08/35] Update init files to include new calculators --- pyproject.toml | 1 + src/meshlode/__init__.py | 7 +++++-- src/meshlode/calculators/__init__.py | 7 +++++-- src/meshlode/calculators/direct.py | 8 +++++++- src/meshlode/lib/__init__.py | 3 ++- src/meshlode/metatensor/__init__.py | 3 ++- 6 files changed, 22 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cf3c1daa..0072a646 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ keywords = [ ] dependencies = [ "torch >=1.11", + "ase", ] dynamic = ["version"] diff --git a/src/meshlode/__init__.py b/src/meshlode/__init__.py index ee18ff6d..884aca7e 100644 --- a/src/meshlode/__init__.py +++ b/src/meshlode/__init__.py @@ -1,4 +1,7 @@ -from .calculators.meshpotential import MeshPotential +from .calculators.mesh import MeshPotential +from .calculators.ewald import EwaldPotential +from .calculators.direct import DirectPotential +from .calculators.meshewald import MeshEwaldPotential try: from . import metatensor # noqa @@ -6,5 +9,5 @@ pass -__all__ = ["MeshPotential"] +__all__ = ["MeshPotential", "EwaldPotential", "DirectPotential", "MeshEwaldPotential"] __version__ = "0.0.0-dev" diff --git a/src/meshlode/calculators/__init__.py b/src/meshlode/calculators/__init__.py index 0b95337c..619c34ad 100644 --- a/src/meshlode/calculators/__init__.py +++ b/src/meshlode/calculators/__init__.py @@ -1,3 +1,6 @@ -from .meshpotential import MeshPotential +from .mesh import MeshPotential +from .ewald import EwaldPotential +from .direct import DirectPotential +from .meshewald import MeshEwaldPotential -__all__ = ["MeshPotential"] +__all__ = ["MeshPotential", "EwaldPotential", "DirectPotential", "MeshEwaldPotential"] diff --git a/src/meshlode/calculators/direct.py b/src/meshlode/calculators/direct.py index ea382f3a..925dfdf3 100644 --- a/src/meshlode/calculators/direct.py +++ b/src/meshlode/calculators/direct.py @@ -66,7 +66,13 @@ def _compute_single_system( distances_sq = squared_norms_matrix + squared_norms_matrix.T - 2 * gram_matrix # Add terms to diagonal in order to avoid division by zero - distances_sq[diagonal_indices, diagonal_indices] += 1e30 + # Since these components in the target tensor need to be set to zero, we add + # a huge number such that after taking the inverse (since we evaluate 1/r^p), + # the components will effectively be set to zero. + # This is not the most elegant solution, but I am doing this since the more + # obvious alternative of setting the same components to zero after the division + # had issues with autograd. I would appreciate any better alternatives. + distances_sq[diagonal_indices, diagonal_indices] += 1e50 # Compute potential potentials_by_pair = distances_sq.pow(-self.exponent / 2.) diff --git a/src/meshlode/lib/__init__.py b/src/meshlode/lib/__init__.py index e7c78e89..54fd2157 100644 --- a/src/meshlode/lib/__init__.py +++ b/src/meshlode/lib/__init__.py @@ -1,4 +1,5 @@ from .fourier_convolution import FourierSpaceConvolution from .mesh_interpolator import MeshInterpolator +from .potentials import InversePowerLawPotential -__all__ = ["FourierSpaceConvolution", "MeshInterpolator"] +__all__ = ["FourierSpaceConvolution", "MeshInterpolator", "InversePowerLawPotential"] diff --git a/src/meshlode/metatensor/__init__.py b/src/meshlode/metatensor/__init__.py index 0b95337c..28563a6b 100644 --- a/src/meshlode/metatensor/__init__.py +++ b/src/meshlode/metatensor/__init__.py @@ -1,3 +1,4 @@ from .meshpotential import MeshPotential +from .ewaldpotential import EwaldPotential -__all__ = ["MeshPotential"] +__all__ = ["MeshPotential", "EwaldPotential"] From 87f77dd41559cec89f3f58bf7dcaf9bd0b42b36a Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Tue, 11 Jun 2024 16:12:04 +0200 Subject: [PATCH 09/35] Add metatensor output to MeshEwald calculator --- src/meshlode/calculators/meshewald.py | 7 +- src/meshlode/metatensor/__init__.py | 3 +- src/meshlode/metatensor/meshewald.py | 198 ++++++++++++++++++++++++++ 3 files changed, 201 insertions(+), 7 deletions(-) create mode 100644 src/meshlode/metatensor/meshewald.py diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/meshewald.py index 8de8b985..4a129e13 100644 --- a/src/meshlode/calculators/meshewald.py +++ b/src/meshlode/calculators/meshewald.py @@ -59,14 +59,9 @@ def __init__( # Check that all provided values are correct if interpolation_order not in [1, 2, 3, 4, 5]: raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") - if atomic_smearing <= 0: + if atomic_smearing is not None and atomic_smearing <= 0: raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") - # If no explicit mesh_spacing is given, set it such that it can resolve - # the smeared potentials. - if mesh_spacing is None: - mesh_spacing = atomic_smearing / 2 - # Store provided parameters self.atomic_smearing = atomic_smearing self.mesh_spacing = mesh_spacing diff --git a/src/meshlode/metatensor/__init__.py b/src/meshlode/metatensor/__init__.py index 28563a6b..ea6acb91 100644 --- a/src/meshlode/metatensor/__init__.py +++ b/src/meshlode/metatensor/__init__.py @@ -1,4 +1,5 @@ from .meshpotential import MeshPotential from .ewaldpotential import EwaldPotential +from .meshewald import MeshEwaldPotential -__all__ = ["MeshPotential", "EwaldPotential"] +__all__ = ["MeshPotential", "EwaldPotential", "MeshEwaldPotential"] diff --git a/src/meshlode/metatensor/meshewald.py b/src/meshlode/metatensor/meshewald.py new file mode 100644 index 00000000..16213d28 --- /dev/null +++ b/src/meshlode/metatensor/meshewald.py @@ -0,0 +1,198 @@ +from typing import Dict, List, Union + +import torch + + +try: + from metatensor.torch import Labels, TensorBlock, TensorMap + from metatensor.torch.atomistic import System +except ImportError: + raise ImportError( + "metatensor.torch is required for meshlode.metatensor but is not installed. " + "Try installing it with:\npip install metatensor[torch]" + ) + + +from .. import calculators + + +# We are breaking the Liskov substitution principle here by changing the signature of +# "compute" compated to the supertype of "MeshPotential". +# mypy: disable-error-code="override" + + +class MeshEwaldPotential(calculators.MeshEwaldPotential): + """An (atomic) type wise long range potential. + + Refer to :class:`meshlode.MeshPotential` for full documentation. + """ + + def forward( + self, + systems: Union[List[System], System], + ) -> TensorMap: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute(systems=systems) + + def compute( + self, + systems: Union[List[System], System], + ) -> TensorMap: + """Compute potential for all provided ``systems``. + + All ``systems`` must have the same ``dtype`` and the same ``device``. If each + system contains a custom data field `charges` the potential will be calculated + for each "charges-channel". The number of `charges-channels` must be same in all + ``systems``. If no "explicit" charges are set the potential will be calculated + for each "types-channels". + + Refer to :meth:`meshlode.MeshPotential.compute()` for additional details on how + "charges-channel" and "types-channels" are computed. + + :param systems: single System or list of + :py:class:`metatensor.torch.atomisic.System` on which to run the + calculations. + + :return: TensorMap containing the potential of all types. The keys of the + TensorMap are "center_type" and "neighbor_type" if no charges are asociated + with + """ + # Make sure that the compute function also works if only a single frame is + # provided as input (for convenience of users testing out the code) + if not isinstance(systems, list): + systems = [systems] + + if len(systems) > 1: + for system in systems[1:]: + if system.dtype != systems[0].dtype: + raise ValueError( + "`dtype` of all systems must be the same, got " + f"{system.dtype} and {systems[0].dtype}`" + ) + + if system.device != systems[0].device: + raise ValueError( + "`device` of all systems must be the same, got " + f"{system.device} and {systems[0].device}`" + ) + + dtype = systems[0].positions.dtype + device = systems[0].positions.device + + requested_types = self._get_requested_types( + [system.types for system in systems] + ) + n_types = len(requested_types) + + has_charges = torch.tensor(["charges" in s.known_data() for s in systems]) + all_charges = torch.all(has_charges) + any_charges = torch.any(has_charges) + + if any_charges and not all_charges: + raise ValueError("`systems` do not consistently contain `charges` data") + if all_charges: + use_explicit_charges = True + n_charges_channels = systems[0].get_data("charges").values.shape[1] + spec_channels = list(range(n_charges_channels)) + key_names = ["center_type", "charges_channel"] + + for i_system, system in enumerate(systems): + n_channels = system.get_data("charges").values.shape[1] + if n_channels != n_charges_channels: + raise ValueError( + f"number of charges-channels in system index {i_system} " + f"({n_channels}) is inconsistent with first system " + f"({n_charges_channels})" + ) + else: + # Use one hot encoded type channel per species for charges channel + use_explicit_charges = False + n_charges_channels = n_types + spec_channels = requested_types + key_names = ["center_type", "neighbor_type"] + + # Initialize dictionary for TensorBlock storage. + # + # If `use_explicit_charges=False`, the blocks are sorted according to the + # (integer) center_type and neighbor_type. Blocks are assigned the array indices + # 0, 1, 2,... Example: for H2O: `H` is mapped to `0` and `O` is mapped to `1`. + # + # For `use_explicit_charges=True` the blocks are stored according to the + # center_type and charge_channel + n_blocks = n_types * n_charges_channels + feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_blocks)} + + for system in systems: + if use_explicit_charges: + charges = system.get_data("charges").values + else: + # One-hot encoding of charge information + charges = self._one_hot_charges( + system.types, requested_types, dtype, device + ) + + # Compute the potentials + potential = self._compute_single_system( + system.positions, charges, system.cell + ) + + # Reorder data into metatensor format + for spec_center, at_num_center in enumerate(requested_types): + for spec_channel in range(len(spec_channels)): + a_pair = spec_center * n_charges_channels + spec_channel + feat_dic[a_pair] += [ + potential[system.types == at_num_center, spec_channel] + ] + + # Assemble all computed potential values into TensorBlocks for each combination + # of center_type and neighbor_type/charge_channel + blocks: List[TensorBlock] = [] + for keys, values in feat_dic.items(): + spec_center = requested_types[keys // n_charges_channels] + + # Generate the Labels objects for the samples and properties of the + # TensorBlock. + values_samples: List[List[int]] = [] + for i_frame, system in enumerate(systems): + for i_atom in range(len(system)): + if system.types[i_atom] == spec_center: + values_samples.append([i_frame, i_atom]) + + samples_vals_tensor = torch.tensor( + values_samples, dtype=torch.int32, device=device + ) + + # If no atoms are found that match the types pair `samples_vals_tensor` + # will be empty. We have to reshape the empty tensor to be a valid input for + # `Labels`. + if len(samples_vals_tensor) == 0: + samples_vals_tensor = samples_vals_tensor.reshape(-1, 2) + + labels_samples = Labels(["system", "atom"], samples_vals_tensor) + labels_properties = Labels( + ["potential"], torch.tensor([[0]], device=device) + ) + + block = TensorBlock( + samples=labels_samples, + components=[], + properties=labels_properties, + values=torch.hstack(values).reshape((-1, 1)), + ) + + blocks.append(block) + + assert len(blocks) == n_blocks + + # Generate TensorMap from TensorBlocks by defining suitable keys + key_values: List[torch.Tensor] = [] + for spec_center in requested_types: + for spec_channel in spec_channels: + key_values.append( + torch.tensor([spec_center, spec_channel], device=device) + ) + key_values = torch.vstack(key_values) + + labels_keys = Labels(key_names, key_values) + + return TensorMap(keys=labels_keys, blocks=blocks) From 6f1c8667b5fd5b518fe7ce79f6651751cdf7abab Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Thu, 20 Jun 2024 16:59:56 +0200 Subject: [PATCH 10/35] Add option to provide external neighborlist --- examples/neighborlist_example.ipynb | 199 ++++++++++++++++++ .../calculators/calculator_base_periodic.py | 48 ++++- src/meshlode/calculators/meshewald.py | 63 ++++-- src/meshlode/metatensor/__init__.py | 1 - src/meshlode/metatensor/meshewald.py | 45 +++- tests/calculators/test_workflow_meshewald.py | 12 +- 6 files changed, 325 insertions(+), 43 deletions(-) create mode 100644 examples/neighborlist_example.ipynb diff --git a/examples/neighborlist_example.ipynb b/examples/neighborlist_example.ipynb new file mode 100644 index 00000000..f5df529e --- /dev/null +++ b/examples/neighborlist_example.ipynb @@ -0,0 +1,199 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import meshlode\n", + "import torch\n", + "import numpy as np\n", + "import math\n", + "from metatensor.torch.atomistic import System\n", + "\n", + "from ase import Atoms\n", + "from ase.neighborlist import neighbor_list" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Define simple example structure having the CsCl structure and compute the reference\n", + "# values. MeshPotential by default outputs the types sorted according to the atomic\n", + "# number. Thus, we input the compound \"CsCl\" and \"ClCs\" since Cl and Cs have atomic\n", + "# numbers 17 and 55, respectively.\n", + "types = torch.tensor([17, 55]) # Cl and Cs\n", + "positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])\n", + "charges = torch.tensor([-1.0, 1.0])\n", + "cell = torch.eye(3)\n", + "\n", + "# %%\n", + "# Define the expected values of the energy\n", + "n_atoms = len(types)\n", + "madelung = 2 * 1.7626 / math.sqrt(3)\n", + "energies_ref = -madelung * torch.ones((n_atoms, 1))\n", + "\n", + "# %%\n", + "# We first define general parameters for our calculation MeshLODE\n", + "\n", + "atomic_smearing = 0.1\n", + "cell = torch.eye(3)\n", + "mesh_spacing = atomic_smearing / 4\n", + "interpolation_order = 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate neighbor list using ASE" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "sr_cutoff = np.sqrt(3) * 0.8\n", + "struc = Atoms(positions=positions, cell=cell, pbc=True)\n", + "atom_is, atom_js, neighbor_shifts = neighbor_list(\"ijS\", struc, sr_cutoff, self_interaction=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Convert neighbor list from ASE to desired format (torch tensor of dtype int)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "atom_is = atom_is.reshape((-1,1))\n", + "atom_js = atom_js.reshape((-1,1))\n", + "neighbor_indices = torch.tensor(np.hstack([atom_is, atom_js]))\n", + "neighbor_shifts = torch.tensor(neighbor_shifts)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/khugueni/code/MeshLODE/src/meshlode/calculators/meshewald.py:336: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " positions[j] - positions[i] + torch.tensor(shift @ cell)\n" + ] + } + ], + "source": [ + "system = System(types=types, positions=positions, cell=cell)\n", + "\n", + "MP = meshlode.metatensor.MeshEwaldPotential(\n", + " atomic_smearing=atomic_smearing,\n", + " mesh_spacing=mesh_spacing,\n", + " interpolation_order=interpolation_order,\n", + " subtract_self=True,\n", + " sr_cutoff=sr_cutoff,\n", + ")\n", + "potential_metatensor = MP.compute(system, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Convert to Madelung constant and check that the value is correct" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(17) tensor(17) tensor(1.) tensor(-2.7745)\n", + "tensor(17) tensor(55) tensor(-1.) tensor(-0.7391)\n", + "tensor(55) tensor(17) tensor(-1.) tensor(-0.7391)\n", + "tensor(55) tensor(55) tensor(1.) tensor(-2.7745)\n", + "Using the metatensor version\n", + "Computed energies on each atom = [[-2.035360813140869], [-2.035360813140869]]\n", + "Reference Madelung constant = 2.035\n", + "Total energy = -4.071\n" + ] + } + ], + "source": [ + "atomic_energies_metatensor = torch.zeros((n_atoms, 1))\n", + "for idx_c, c in enumerate(types):\n", + " for idx_n, n in enumerate(types):\n", + " # Take the coefficients with the correct center atom and neighbor atom types\n", + " block = potential_metatensor.block(\n", + " {\"center_type\": int(c), \"neighbor_type\": int(n)}\n", + " )\n", + "\n", + " # The coulomb potential between atoms i and j is charge_i * charge_j / d_ij\n", + " # The features are simply computing a pure 1/r potential with no prefactors.\n", + " # Thus, to compute the energy between atoms of types i and j, we need to\n", + " # multiply by the charges of i and j.\n", + " print(c, n, charges[idx_c] * charges[idx_n], block.values[0, 0])\n", + " atomic_energies_metatensor[idx_c] += (\n", + " charges[idx_c] * charges[idx_n] * block.values[0, 0]\n", + " )\n", + "\n", + "# %%\n", + "# The total energy is just the sum of all atomic energies\n", + "total_energy_metatensor = torch.sum(atomic_energies_metatensor)\n", + "\n", + "# %%\n", + "# Compare against reference Madelung constant and reference energy:\n", + "print(\"Using the metatensor version\")\n", + "print(f\"Computed energies on each atom = {atomic_energies_metatensor.tolist()}\")\n", + "print(f\"Reference Madelung constant = {madelung:.3f}\")\n", + "print(f\"Total energy = {total_energy_metatensor:.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/meshlode/calculators/calculator_base_periodic.py b/src/meshlode/calculators/calculator_base_periodic.py index 883f6f71..9ed7d10c 100644 --- a/src/meshlode/calculators/calculator_base_periodic.py +++ b/src/meshlode/calculators/calculator_base_periodic.py @@ -36,6 +36,8 @@ def compute( positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute potential for all provided "systems" stacked inside list. @@ -51,6 +53,12 @@ def compute( vector; and columns should contain the x, y, and z components of these vectors (i.e. the cell should be given in row-major order). :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. :return: List of torch Tensors containing the potentials for all frames and all atoms. Each tensor in the list is of shape (n_atoms, n_types), where @@ -155,17 +163,39 @@ def compute( # We don't require and test that all dtypes and devices are consistent if a list # of inputs. Each "frame" is processed independently. potentials = [] - for positions_single, cell_single, charges_single in zip( - positions, cell, charges - ): - # Compute the potentials - potentials.append( - self._compute_single_system( - positions=positions_single, charges=charges_single, cell=cell_single + + if neighbor_indices is None: + for positions_single, cell_single, charges_single in zip( + positions, cell, charges + ): + # Compute the potentials + potentials.append( + self._compute_single_system( + positions=positions_single, + charges=charges_single, + cell=cell_single, + ) + ) + else: + for ( + positions_single, + cell_single, + charges_single, + neighbor_indices_single, + neighbor_shifts_single, + ) in zip(positions, cell, charges, neighbor_indices, neighbor_shifts): + # Compute the potentials + potentials.append( + self._compute_single_system( + positions=positions_single, + charges=charges_single, + cell=cell_single, + neighbor_indices=neighbor_indices_single, + neighbor_shifts=neighbor_shifts_single, + ) ) - ) if len(types) == 1: return potentials[0] else: - return potentials \ No newline at end of file + return potentials diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/meshewald.py index 4a129e13..da19320f 100644 --- a/src/meshlode/calculators/meshewald.py +++ b/src/meshlode/calculators/meshewald.py @@ -1,14 +1,19 @@ -import torch from typing import List, Optional -# from .mesh import MeshPotential -from .calculator_base_periodic import CalculatorBasePeriodic -from meshlode.lib.mesh_interpolator import MeshInterpolator +import torch # extra imports for neighbor list from ase import Atoms from ase.neighborlist import neighbor_list +from meshlode.lib.mesh_interpolator import MeshInterpolator + +from .calculator_base import default_exponent + +# from .mesh import MeshPotential +from .calculator_base_periodic import CalculatorBasePeriodic + + class MeshEwaldPotential(CalculatorBasePeriodic): """A specie-wise long-range potential computed using a mesh-based Ewald method, scaling as O(NlogN) with respect to the number of particles N used as a reference @@ -46,13 +51,13 @@ class MeshEwaldPotential(CalculatorBasePeriodic): def __init__( self, all_types: Optional[List[int]] = None, - exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64), - sr_cutoff: Optional[float] = None, + exponent: Optional[torch.Tensor] = default_exponent, + sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, mesh_spacing: Optional[float] = None, subtract_self: Optional[bool] = True, interpolation_order: Optional[int] = 4, - subtract_interior: Optional[bool] = False + subtract_interior: Optional[bool] = False, ): super().__init__(all_types=all_types, exponent=exponent) @@ -129,6 +134,8 @@ def _compute_single_system( positions: torch.Tensor, charges: torch.Tensor, cell: torch.Tensor, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_shifts: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute the "electrostatic" potential at the position of all atoms in a @@ -162,7 +169,7 @@ def _compute_single_system( cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6 if self.sr_cutoff is not None: if self.sr_cutoff > torch.min(cell_dimensions) / 2: - raise ValueError(f"sr_cutoff {sr_cutoff} needs to be > {cutoff_max}") + raise ValueError(f"sr_cutoff {self.sr_cutoff} has to be > {cutoff_max}") # Set the defaut values of convergence parameters # The total computational cost = cost of SR part + cost of LR part @@ -184,7 +191,7 @@ def _compute_single_system( smearing = self.atomic_smearing if self.mesh_spacing is None: - mesh_spacing = smearing / 8. + mesh_spacing = smearing / 8.0 else: mesh_spacing = self.mesh_spacing @@ -195,6 +202,8 @@ def _compute_single_system( cell=cell, smearing=smearing, sr_cutoff=sr_cutoff, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts ) # Compute long-range (LR) part using a Fourier / reciprocal space sum @@ -203,7 +212,8 @@ def _compute_single_system( charges=charges, cell=cell, smearing=smearing, - lr_wavelength=mesh_spacing) + lr_wavelength=mesh_spacing, + ) # Combine both parts to obtain the full potential potential_ewald = potential_sr + potential_lr @@ -233,16 +243,15 @@ def _compute_lr( structure, where cell[i] is the i-th basis vector. :param smearing: torch.Tensor smearing paramter determining the splitting between the SR and LR parts. - :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) - part of the Ewald sum. More conretely, all Fourier space vectors with a - wavelength >= this value will be kept. + :param lr_wavelength: Spatial resolution used for the long-range (reciprocal + space) part of the Ewald sum. More conretely, all Fourier space vectors with + a wavelength >= this value will be kept. :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential at the position of each atom for the `n_channels` independent meshes separately. """ # Step 0 (Preparation): Compute number of times each basis vector of the # reciprocal space can be scaled until the cutoff is reached - mesh_spacing = lr_wavelength k_cutoff = 2 * torch.pi / lr_wavelength basis_norms = torch.linalg.norm(cell, dim=1) ns_approx = k_cutoff * basis_norms / 2 / torch.pi @@ -258,11 +267,11 @@ def _compute_lr( # Step 2.1: Generate k-vectors and evaluate kernel function kvectors = self._generate_kvectors(ns=ns, cell=cell) knorm_sq = torch.sum(kvectors**2, dim=3) - + # Step 2.2: Evaluate kernel function (careful, tensor shapes are different from # the pure Ewald implementation since we are no longer flattening) G = self.potential.potential_fourier_from_k_sq(knorm_sq, smearing) - G[0,0,0] = self.potential.potential_fourier_at_zero(smearing) + G[0, 0, 0] = self.potential.potential_fourier_at_zero(smearing) potential_mesh = rho_mesh @@ -293,6 +302,8 @@ def _compute_sr( cell: torch.Tensor, smearing: torch.Tensor, sr_cutoff: torch.Tensor, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_shifts: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute the short-range part of the Ewald sum in realspace @@ -314,16 +325,24 @@ def _compute_sr( :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential at the position of each atom for the `n_channels` independent meshes separately. """ - # Get list of neighbors - struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) - atom_is, atom_js, shifts = neighbor_list( - "ijS", struc, sr_cutoff.item(), self_interaction=False - ) + if neighbor_indices is None: + # Get list of neighbors + struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) + atom_is, atom_js, shifts = neighbor_list( + "ijS", struc, sr_cutoff.item(), self_interaction=False + ) + else: + atom_is = neighbor_indices[:,0] + atom_js = neighbor_indices[:,1] + shifts = neighbor_shifts.T + # Compute energy potential = torch.zeros_like(charges) for i, j, shift in zip(atom_is, atom_js, shifts): - dist = torch.linalg.norm(positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell))) + dist = torch.linalg.norm( + positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell)) + ) # If the contribution from all atoms within the cutoff is to be subtracted # this short-range part will simply use -V_LR as the potential diff --git a/src/meshlode/metatensor/__init__.py b/src/meshlode/metatensor/__init__.py index ea6acb91..8afbae3a 100644 --- a/src/meshlode/metatensor/__init__.py +++ b/src/meshlode/metatensor/__init__.py @@ -1,5 +1,4 @@ from .meshpotential import MeshPotential -from .ewaldpotential import EwaldPotential from .meshewald import MeshEwaldPotential __all__ = ["MeshPotential", "EwaldPotential", "MeshEwaldPotential"] diff --git a/src/meshlode/metatensor/meshewald.py b/src/meshlode/metatensor/meshewald.py index 16213d28..9351aee2 100644 --- a/src/meshlode/metatensor/meshewald.py +++ b/src/meshlode/metatensor/meshewald.py @@ -30,13 +30,21 @@ class MeshEwaldPotential(calculators.MeshEwaldPotential): def forward( self, systems: Union[List[System], System], + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> TensorMap: """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute(systems=systems) + return self.compute( + systems=systems, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) def compute( self, systems: Union[List[System], System], + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> TensorMap: """Compute potential for all provided ``systems``. @@ -61,6 +69,20 @@ def compute( # provided as input (for convenience of users testing out the code) if not isinstance(systems, list): systems = [systems] + if (neighbor_indices is not None) and not isinstance(neighbor_indices, list): + neighbor_indices = [neighbor_indices] + if (neighbor_shifts is not None) and not isinstance(neighbor_shifts, list): + neighbor_shifts = [neighbor_shifts] + + # Check that the lengths of the provided lists agree + if (neighbor_indices is not None) and len(neighbor_indices) != len(systems): + raise ValueError( + f"Numbers of systems (= {len(systems)}) needs to match number of neighbor lists (= {len(neighbor_indices)})" + ) + if (neighbor_shifts is not None) and len(neighbor_shifts) != len(systems): + raise ValueError( + f"Numbers of systems (= {len(systems)}) needs to match number of neighbor shifts (= {len(neighbor_shifts)})" + ) if len(systems) > 1: for system in systems[1:]: @@ -122,7 +144,7 @@ def compute( n_blocks = n_types * n_charges_channels feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_blocks)} - for system in systems: + for i, system in enumerate(systems): if use_explicit_charges: charges = system.get_data("charges").values else: @@ -131,10 +153,21 @@ def compute( system.types, requested_types, dtype, device ) - # Compute the potentials - potential = self._compute_single_system( - system.positions, charges, system.cell - ) + if neighbor_indices is None or neighbor_shifts is None: + # Compute the potentials + potential = self._compute_single_system( + positions=system.positions, + charges=charges, + cell=system.cell, + ) + else: + potential = self._compute_single_system( + positions=system.positions, + charges=charges, + cell=system.cell, + neighbor_indices=neighbor_indices[i], + neighbor_shifts=neighbor_shifts[i], + ) # Reorder data into metatensor format for spec_center, at_num_center in enumerate(requested_types): diff --git a/tests/calculators/test_workflow_meshewald.py b/tests/calculators/test_workflow_meshewald.py index a5d73ef2..9fc5a644 100644 --- a/tests/calculators/test_workflow_meshewald.py +++ b/tests/calculators/test_workflow_meshewald.py @@ -7,7 +7,7 @@ import torch from torch.testing import assert_close -from meshlode import MeshPotential, MeshEwaldPotential +from meshlode import MeshEwaldPotential, MeshPotential from meshlode.calculators.calculator_base import _1d_tolist, _is_subset @@ -85,7 +85,8 @@ def check_operation(calculator): def test_operation_as_python(): check_operation(descriptor()) -""" + +""" # Similar to the above, but also testing that the code can be compiled as a torch script # Disabled for now since (1) the ASE neighbor list and (2) the use of the potential # class are clashing with the torch script capabilities. @@ -94,6 +95,7 @@ def test_operation_as_torch_script(): check_operation(scripted) """ + def test_single_frame(): values = descriptor().compute(*cscl_system()) assert_close( @@ -245,6 +247,7 @@ def test_inconsistent_dtype(): with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell) + def test_inconsistent_device(): """Test if the cell and positions have inconsistent device and error is raised.""" types = torch.tensor([1], device="cpu") @@ -253,9 +256,8 @@ def test_inconsistent_device(): MP = MeshPotential(atomic_smearing=0.2) - match = ( - '`types`, `positions`, and `cell` must be on the same device, got cpu, cpu and meta.' - ) + match = "`types`, `positions`, and `cell` must be on the same device, got cpu, cpu " + match += "and meta." with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell) From 40b8808bf33428adbfbe08c604d645666b4a28aa Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Thu, 20 Jun 2024 17:23:47 +0200 Subject: [PATCH 11/35] Add tests for potentials class --- src/meshlode/calculators/calculator_base.py | 20 +- .../calculators/calculator_base_periodic.py | 6 +- src/meshlode/calculators/direct.py | 12 +- src/meshlode/calculators/ewald.py | 41 +-- src/meshlode/calculators/mesh.py | 5 +- src/meshlode/calculators/meshewald.py | 28 +-- src/meshlode/lib/potentials.py | 67 ++--- tests/calculators/test_values_aperiodic.py | 179 ++++++++------ tests/calculators/test_values_periodic.py | 6 +- tests/calculators/test_workflow_direct.py | 21 +- tests/calculators/test_workflow_ewald.py | 5 +- tests/calculators/test_workflow_mesh.py | 13 +- tests/lib/test_potentials.py | 233 ++++++++++++++++++ 13 files changed, 440 insertions(+), 196 deletions(-) create mode 100644 tests/lib/test_potentials.py diff --git a/src/meshlode/calculators/calculator_base.py b/src/meshlode/calculators/calculator_base.py index e6abbf39..1f99fefb 100644 --- a/src/meshlode/calculators/calculator_base.py +++ b/src/meshlode/calculators/calculator_base.py @@ -1,8 +1,16 @@ -from meshlode.lib import InversePowerLawPotential from typing import List, Optional, Union import torch +from meshlode.lib import InversePowerLawPotential + + +def get_default_exponent(): + return torch.tensor(1.0) + + +default_exponent = get_default_exponent() + @torch.jit.script def _1d_tolist(x: torch.Tensor) -> List[int]: @@ -38,7 +46,7 @@ class CalculatorBase(torch.nn.Module): def __init__( self, all_types: Optional[List[int]] = None, - exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64), + exponent: Optional[torch.Tensor] = default_exponent, ): super().__init__() @@ -46,9 +54,9 @@ def __init__( self.all_types = None else: self.all_types = _1d_tolist(torch.unique(torch.tensor(all_types))) - + self.exponent = exponent - self.potential = InversePowerLawPotential(exponent = exponent) + self.potential = InversePowerLawPotential(exponent=exponent) # This function is kept to keep this library compatible with the broader pytorch # infrastructure, which require a "forward" function. We name this function @@ -60,9 +68,7 @@ def forward( charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute( - types=types, positions=positions, charges=charges - ) + return self.compute(types=types, positions=positions, charges=charges) def compute( self, diff --git a/src/meshlode/calculators/calculator_base_periodic.py b/src/meshlode/calculators/calculator_base_periodic.py index 9ed7d10c..32f21243 100644 --- a/src/meshlode/calculators/calculator_base_periodic.py +++ b/src/meshlode/calculators/calculator_base_periodic.py @@ -81,6 +81,10 @@ def compute( positions = [positions] if not isinstance(cell, list): cell = [cell] + if (neighbor_indices is not None) and not isinstance(neighbor_indices, list): + neighbor_indices = [neighbor_indices] + if (neighbor_shifts is not None) and not isinstance(neighbor_shifts, list): + neighbor_shifts = [neighbor_shifts] # Check that all inputs are consistent for types_single, positions_single, cell_single in zip(types, positions, cell): @@ -164,7 +168,7 @@ def compute( # of inputs. Each "frame" is processed independently. potentials = [] - if neighbor_indices is None: + if neighbor_indices is None or neighbor_shifts is None: for positions_single, cell_single, charges_single in zip( positions, cell, charges ): diff --git a/src/meshlode/calculators/direct.py b/src/meshlode/calculators/direct.py index 925dfdf3..c28b2a2d 100644 --- a/src/meshlode/calculators/direct.py +++ b/src/meshlode/calculators/direct.py @@ -1,7 +1,7 @@ -from .calculator_base import CalculatorBase - import torch +from .calculator_base import CalculatorBase + class DirectPotential(CalculatorBase): """A specie-wise long-range potential computed using a direct summation over all @@ -46,10 +46,6 @@ def _compute_single_system( potential. Subtracting these from each other, one could recover the more standard electrostatic potential in which Na and Cl have charges of +1 and -1, respectively. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. While redundant in this - particular implementation, the parameter is kept to keep the same inputs as - the other calculators. :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential at the position of each atom for the `n_channels` independent meshes separately. @@ -73,9 +69,9 @@ def _compute_single_system( # obvious alternative of setting the same components to zero after the division # had issues with autograd. I would appreciate any better alternatives. distances_sq[diagonal_indices, diagonal_indices] += 1e50 - + # Compute potential - potentials_by_pair = distances_sq.pow(-self.exponent / 2.) + potentials_by_pair = distances_sq.pow(-self.exponent / 2.0) potentials = torch.matmul(potentials_by_pair, charges) return potentials diff --git a/src/meshlode/calculators/ewald.py b/src/meshlode/calculators/ewald.py index 2df0b3e4..91560b94 100644 --- a/src/meshlode/calculators/ewald.py +++ b/src/meshlode/calculators/ewald.py @@ -1,12 +1,15 @@ -import torch from typing import List, Optional -from .calculator_base_periodic import CalculatorBasePeriodic +import torch # extra imports for neighbor list from ase import Atoms from ase.neighborlist import neighbor_list +from .calculator_base import default_exponent +from .calculator_base_periodic import CalculatorBasePeriodic + + class EwaldPotential(CalculatorBasePeriodic): """A specie-wise long-range potential computed using the Ewald sum, scaling as O(N^2) with respect to the number of particles N used as a reference to test faster @@ -62,12 +65,12 @@ class EwaldPotential(CalculatorBasePeriodic): def __init__( self, all_types: Optional[List[int]] = None, - exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64), - sr_cutoff: Optional[float] = None, + exponent: Optional[torch.Tensor] = default_exponent, + sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, lr_wavelength: Optional[float] = None, subtract_self: Optional[bool] = True, - subtract_interior: Optional[bool] = False + subtract_interior: Optional[bool] = False, ): super().__init__(all_types=all_types, exponent=exponent) @@ -120,7 +123,7 @@ def _compute_single_system( cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6 if self.sr_cutoff is not None: if self.sr_cutoff > torch.min(cell_dimensions) / 2: - raise ValueError(f"sr_cutoff {sr_cutoff} needs to be > {cutoff_max}") + raise ValueError(f"sr_cutoff {self.sr_cutoff} has to be > {cutoff_max}") # Set the defaut values of convergence parameters # The total computational cost = cost of SR part + cost of LR part @@ -154,8 +157,6 @@ def _compute_single_system( sr_cutoff=sr_cutoff, ) - ##return charges * torch.sum(positions, dim=1) * self.exponent + potential_sr - potential_lr = self._compute_lr( positions=positions, charges=charges, @@ -164,11 +165,9 @@ def _compute_single_system( lr_wavelength=lr_wavelength, ) - #return potential_lr - potential_ewald = potential_sr + potential_lr return potential_ewald - + def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: """ For a given unit cell, compute all reciprocal space vectors that are used to @@ -189,9 +188,9 @@ def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tens that will be used during Ewald summation (or related approaches). ``k_vectors[i]`` contains the i-th vector, where the order has no special significance. - The total number N of k-vectors is NOT simply nx*ny*nz, and roughly corresponds - to nx*ny*nz/2 due since the vectors +k and -k can be grouped together during - summation. + The total number N of k-vectors is NOT simply nx*ny*nz, and roughly + corresponds to nx*ny*nz/2 due since the vectors +k and -k can be grouped + together during summation. """ # Check that the shapes of all inputs are correct if ns.shape != (3,): @@ -239,9 +238,9 @@ def _compute_lr( structure, where cell[i] is the i-th basis vector. :param smearing: torch.Tensor smearing paramter determining the splitting between the SR and LR parts. - :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) - part of the Ewald sum. More conretely, all Fourier space vectors with a - wavelength >= this value will be kept. + :param lr_wavelength: Spatial resolution used for the long-range (reciprocal + space) part of the Ewald sum. More conretely, all Fourier space vectors with + a wavelength >= this value will be kept. :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential at the position of each atom for the `n_channels` independent meshes separately. @@ -249,8 +248,8 @@ def _compute_lr( # Define k-space cutoff from required real-space resolution k_cutoff = 2 * torch.pi / lr_wavelength - # Compute number of times each basis vector of the reciprocal space can be scaled - # until the cutoff is reached + # Compute number of times each basis vector of the reciprocal space can be + # scaled until the cutoff is reached basis_norms = torch.linalg.norm(cell, dim=1) ns_float = k_cutoff * basis_norms / 2 / torch.pi ns = torch.ceil(ns_float).long() @@ -346,7 +345,9 @@ def _compute_sr( # Compute energy potential = torch.zeros_like(charges) for i, j, shift in zip(atom_is, atom_js, shifts): - dist = torch.linalg.norm(positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell))) + dist = torch.linalg.norm( + positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell)) + ) # If the contribution from all atoms within the cutoff is to be subtracted # this short-range part will simply use -V_LR as the potential diff --git a/src/meshlode/calculators/mesh.py b/src/meshlode/calculators/mesh.py index 0b6015eb..18d999fb 100644 --- a/src/meshlode/calculators/mesh.py +++ b/src/meshlode/calculators/mesh.py @@ -5,8 +5,10 @@ from meshlode.lib.fourier_convolution import FourierSpaceConvolution from meshlode.lib.mesh_interpolator import MeshInterpolator +from .calculator_base import default_exponent from .calculator_base_periodic import CalculatorBasePeriodic + class MeshPotential(CalculatorBasePeriodic): """A specie-wise long-range potential, computed using the particle-mesh Ewald (PME) method scaling as O(NlogN) with respect to the number of particles N. @@ -56,7 +58,7 @@ def __init__( interpolation_order: Optional[int] = 4, subtract_self: Optional[bool] = False, all_types: Optional[List[int]] = None, - exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64), + exponent: Optional[torch.Tensor] = default_exponent, ): super().__init__(all_types=all_types, exponent=exponent) @@ -120,7 +122,6 @@ def _compute_single_system( assert positions.dtype == cell.dtype and charges.dtype == cell.dtype assert positions.device == cell.device and charges.device == cell.device - # Define cutoff in reciprocal space if mesh_spacing is None: mesh_spacing = self.mesh_spacing diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/meshewald.py index da19320f..18c6c22c 100644 --- a/src/meshlode/calculators/meshewald.py +++ b/src/meshlode/calculators/meshewald.py @@ -162,15 +162,6 @@ def _compute_single_system( :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential at the position of each atom for the `n_channels` independent meshes separately. """ - # Check that the realspace cutoff (if provided) is not too large - # This is because the current implementation is not able to return multiple - # periodic images of the same atom as a neighbor - cell_dimensions = torch.linalg.norm(cell, dim=1) - cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6 - if self.sr_cutoff is not None: - if self.sr_cutoff > torch.min(cell_dimensions) / 2: - raise ValueError(f"sr_cutoff {self.sr_cutoff} has to be > {cutoff_max}") - # Set the defaut values of convergence parameters # The total computational cost = cost of SR part + cost of LR part # Bigger smearing increases the cost of the SR part while decreasing the cost @@ -181,6 +172,8 @@ def _compute_single_system( # chosen to reach a convergence on the order of 1e-4 to 1e-5 for the test # structures. if self.sr_cutoff is None: + cell_dimensions = torch.linalg.norm(cell, dim=1) + cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6 sr_cutoff = cutoff_max else: sr_cutoff = self.sr_cutoff @@ -203,7 +196,7 @@ def _compute_single_system( smearing=smearing, sr_cutoff=sr_cutoff, neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts + neighbor_shifts=neighbor_shifts, ) # Compute long-range (LR) part using a Fourier / reciprocal space sum @@ -325,23 +318,22 @@ def _compute_sr( :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential at the position of each atom for the `n_channels` independent meshes separately. """ - if neighbor_indices is None: + if neighbor_indices is None or neighbor_shifts is None: # Get list of neighbors struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) - atom_is, atom_js, shifts = neighbor_list( + atom_is, atom_js, neighbor_shifts = neighbor_list( "ijS", struc, sr_cutoff.item(), self_interaction=False ) else: - atom_is = neighbor_indices[:,0] - atom_js = neighbor_indices[:,1] - shifts = neighbor_shifts.T - + atom_is = neighbor_indices[0] + atom_js = neighbor_indices[1] # Compute energy potential = torch.zeros_like(charges) - for i, j, shift in zip(atom_is, atom_js, shifts): + for i, j, shift in zip(atom_is, atom_js, neighbor_shifts): + shift = shift.type(cell.dtype) dist = torch.linalg.norm( - positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell)) + positions[j] - positions[i] + torch.tensor(shift @ cell) ) # If the contribution from all atoms within the cutoff is to be subtracted diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py index 41e989af..dc80ccc7 100644 --- a/src/meshlode/lib/potentials.py +++ b/src/meshlode/lib/potentials.py @@ -1,6 +1,8 @@ +import math + import torch from torch.special import gammainc, gammaincc, gammaln -import math + # since pytorch has implemented the incomplete Gamma functions, but not the much more # commonly used (complete) Gamma function, we define it in a custom way to make autograd @@ -8,6 +10,7 @@ def gamma(x): return torch.exp(gammaln(x)) + class InversePowerLawPotential: """ Class to handle inverse power-law potentials of the form 1/r^p, where r is a @@ -21,12 +24,11 @@ class InversePowerLawPotential: :param exponent: torch.tensor corresponding to the exponent "p" in 1/r^p potentials """ + def __init__(self, exponent: torch.Tensor): self.exponent = exponent - - def potential_from_dist(self, - dist: torch.Tensor - ) -> torch.Tensor: + + def potential_from_dist(self, dist: torch.Tensor) -> torch.Tensor: """ Full 1/r^p potential as a function of r @@ -34,10 +36,8 @@ def potential_from_dist(self, be evaluated. """ return torch.pow(dist, -self.exponent) - - def potential_from_dist_sq(self, - dist_sq: torch.Tensor - ) -> torch.Tensor: + + def potential_from_dist_sq(self, dist_sq: torch.Tensor) -> torch.Tensor: """ Full 1/r^p potential as a function of r^2, which is more useful in some implementations @@ -45,19 +45,18 @@ def potential_from_dist_sq(self, :param dist_sq: torch.tensor containing the squared distances at which the potential is to be evaluated. """ - return torch.pow(dist_sq, -self.exponent / 2.) + return torch.pow(dist_sq, -self.exponent / 2.0) - def potential_sr_from_dist(self, - dist: torch.Tensor, - smearing: torch.Tensor - ) -> torch.Tensor: + def potential_sr_from_dist( + self, dist: torch.Tensor, smearing: torch.Tensor + ) -> torch.Tensor: """ Short-range (SR) part of the range-separated 1/r^p potential as a function of r. More explicitly: it corresponds to V_SR(r) in 1/r^p = V_SR(r) + V_LR(r), where the location of the split is determined by the smearing parameter. For the Coulomb potential, this would return - potential = erfc(dist / sqrt(2) / smearing) / dist + potential = erfc(dist / torch.sqrt(2.) / smearing) / dist :param dist: torch.tensor containing the distances at which the potential is to be evaluated. @@ -70,14 +69,15 @@ def potential_sr_from_dist(self, """ x = 0.5 * dist**2 / smearing**2 peff = self.exponent / 2 - prefac = 1./(2*smearing**2)**peff - potential = prefac * gammainc(peff, x) / x**peff + prefac = 1.0 / (2 * smearing**2) ** peff + potential = prefac * gammaincc(peff, x) / x**peff + + # potential = erfc(dist / torch.sqrt(torch.tensor(2.)) / smearing) / dist return potential - def potential_lr_from_dist(self, - dist: torch.Tensor, - smearing: torch.Tensor - ) -> torch.Tensor: + def potential_lr_from_dist( + self, dist: torch.Tensor, smearing: torch.Tensor + ) -> torch.Tensor: """ Long-range (LR) part of the range-separated 1/r^p potential as a function of r. Used to subtract out the interior contributions after computing the LR part @@ -98,14 +98,13 @@ def potential_lr_from_dist(self, """ x = 0.5 * dist**2 / smearing**2 peff = self.exponent / 2 - prefac = 1./(2*smearing**2)**peff + prefac = 1.0 / (2 * smearing**2) ** peff potential = prefac * gammainc(peff, x) / x**peff return potential - def potential_fourier_from_k_sq(self, - k_sq: torch.Tensor, - smearing: torch.Tensor - ) -> torch.Tensor: + def potential_fourier_from_k_sq( + self, k_sq: torch.Tensor, smearing: torch.Tensor + ) -> torch.Tensor: """ Fourier transform of the long-range (LR) part potential parametrized in terms of k^2. @@ -113,7 +112,7 @@ def potential_fourier_from_k_sq(self, fourier = 4 * torch.pi * torch.exp(-0.5 * smearing**2 * k_sq) / k_sq :param k_sq: torch.tensor containing the squared lengths (2-norms) of the wave - vectors k at which the Fourier-transformed potential is to be evaluated + vectors k at which the Fourier-transformed potential is to be evaluated :param smearing: torch.tensor containing the parameter often called "sigma" in publications, which determines the length-scale at which the short-range and long-range parts of the naive 1/r^p potential are separated. For the Coulomb @@ -121,13 +120,15 @@ def potential_fourier_from_k_sq(self, potential generated by a Gaussian charge density, in which case this smearing parameter corresponds to the "width" of the Gaussian. """ - peff = (3-self.exponent) / 2 - prefac = (math.pi)**1.5 / gamma(self.exponent/2) * (2*smearing**2)**peff - x = 0.5*smearing**2*k_sq + peff = (3 - self.exponent) / 2 + prefac = ( + (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff + ) + x = 0.5 * smearing**2 * k_sq fourier = prefac * gammaincc(peff, x) / x**peff * gamma(peff) - + return fourier - + def potential_fourier_at_zero(self, smearing: torch.Tensor) -> torch.Tensor: """ The value of the Fourier-transformed potential (LR part implemented above) as @@ -146,4 +147,4 @@ def potential_fourier_at_zero(self, smearing: torch.Tensor) -> torch.Tensor: potential generated by a Gaussian charge density, in which case this smearing parameter corresponds to the "width" of the Gaussian. """ - return torch.tensor(0.) \ No newline at end of file + return torch.tensor(0.0) diff --git a/tests/calculators/test_values_aperiodic.py b/tests/calculators/test_values_aperiodic.py index d119888e..e4612b2b 100644 --- a/tests/calculators/test_values_aperiodic.py +++ b/tests/calculators/test_values_aperiodic.py @@ -1,13 +1,16 @@ -import torch import math + import pytest +import torch + from meshlode import DirectPotential -def define_molecule(molecule_name = 'dimer'): + +def define_molecule(molecule_name="dimer"): """ Define simple "molecules" (collection of point charges) for which the exact Coulomb potential is easy to evaluate. The implementations in the main code are then tested - against these structures. + against these structures. """ # Use a higher precision than the default float32 dtype = torch.float64 @@ -16,74 +19,86 @@ def define_molecule(molecule_name = 'dimer'): # Start defining molecules # Dimer - if molecule_name == 'dimer': - types = torch.tensor([1,1]) - positions = torch.tensor([[0.,0,0],[0,0,1.]], dtype=dtype) - charges = torch.tensor([1.,-1.], dtype=dtype) - potentials = torch.tensor([-1.,1], dtype=dtype) - - elif molecule_name == 'dimer_positive': - types, positions, charges, potentials = define_molecule('dimer') - charges = torch.tensor([1.,1], dtype=dtype) - potentials = torch.tensor([1.,1], dtype=dtype) - - elif molecule_name == 'dimer_negative': - types, positions, charges, potentials = define_molecule('dimer_positive') - charges *= -1. - potentials *= -1. + if molecule_name == "dimer": + types = torch.tensor([1, 1]) + positions = torch.tensor([[0.0, 0, 0], [0, 0, 1.0]], dtype=dtype) + charges = torch.tensor([1.0, -1.0], dtype=dtype) + potentials = torch.tensor([-1.0, 1], dtype=dtype) + + elif molecule_name == "dimer_positive": + types, positions, charges, potentials = define_molecule("dimer") + charges = torch.tensor([1.0, 1], dtype=dtype) + potentials = torch.tensor([1.0, 1], dtype=dtype) + + elif molecule_name == "dimer_negative": + types, positions, charges, potentials = define_molecule("dimer_positive") + charges *= -1.0 + potentials *= -1.0 # Equilateral triangle - elif molecule_name == 'triangle': - types = torch.tensor([1,1,1]) - positions = torch.tensor([[0.,0,0],[1,0,0],[1/2,SQRT3/2,0]], dtype=dtype) - charges = torch.tensor([1.,-1.,0.], dtype=dtype) - potentials = torch.tensor([-1.,1,0], dtype=dtype) - - elif molecule_name == 'triangle_positive': - types, positions, charges, potentials = define_molecule('triangle') - charges = torch.tensor([1.,1,1], dtype=dtype) - potentials = torch.tensor([2.,2,2], dtype=dtype) - - elif molecule_name == 'triangle_negative': - types, positions, charges, potentials = define_molecule('triangle_positive') - charges *= -1. - potentials *= -1. + elif molecule_name == "triangle": + types = torch.tensor([1, 1, 1]) + positions = torch.tensor( + [[0.0, 0, 0], [1, 0, 0], [1 / 2, SQRT3 / 2, 0]], dtype=dtype + ) + charges = torch.tensor([1.0, -1.0, 0.0], dtype=dtype) + potentials = torch.tensor([-1.0, 1, 0], dtype=dtype) + + elif molecule_name == "triangle_positive": + types, positions, charges, potentials = define_molecule("triangle") + charges = torch.tensor([1.0, 1, 1], dtype=dtype) + potentials = torch.tensor([2.0, 2, 2], dtype=dtype) + + elif molecule_name == "triangle_negative": + types, positions, charges, potentials = define_molecule("triangle_positive") + charges *= -1.0 + potentials *= -1.0 # Squares (planar) - elif molecule_name == 'square': - types = torch.tensor([1,1,1,1]) - positions = torch.tensor([[1,1,0],[1,-1,0],[-1,1,0],[-1,-1,0]], dtype=dtype) - positions /= 2. - charges = torch.tensor([1.,-1,-1,1], dtype=dtype) - potentials = charges * (1./SQRT2 - 2.) - - elif molecule_name == 'square_positive': - types, positions, charges, potentials = define_molecule('square') - charges = torch.tensor([1.,1,1,1], dtype=dtype) - potentials = (2. + 1./SQRT2) * torch.ones(4, dtype=dtype) - - elif molecule_name == 'square_negative': - types, positions, charges, potentials = define_molecule('square_positive') - charges *= -1. - potentials *= -1. + elif molecule_name == "square": + types = torch.tensor([1, 1, 1, 1]) + positions = torch.tensor( + [[1, 1, 0], [1, -1, 0], [-1, 1, 0], [-1, -1, 0]], dtype=dtype + ) + positions /= 2.0 + charges = torch.tensor([1.0, -1, -1, 1], dtype=dtype) + potentials = charges * (1.0 / SQRT2 - 2.0) + + elif molecule_name == "square_positive": + types, positions, charges, potentials = define_molecule("square") + charges = torch.tensor([1.0, 1, 1, 1], dtype=dtype) + potentials = (2.0 + 1.0 / SQRT2) * torch.ones(4, dtype=dtype) + + elif molecule_name == "square_negative": + types, positions, charges, potentials = define_molecule("square_positive") + charges *= -1.0 + potentials *= -1.0 # Tetrahedra - elif molecule_name == 'tetrahedron': - types = torch.tensor([1,1,1,1]) - positions = torch.tensor([[0.,0,0],[1,0,0],[1/2,SQRT3/2,0],[1/2,SQRT3/6,SQRT2/SQRT3]], dtype=dtype) - charges = torch.tensor([1.,-1,1,-1], dtype=dtype) + elif molecule_name == "tetrahedron": + types = torch.tensor([1, 1, 1, 1]) + positions = torch.tensor( + [ + [0.0, 0, 0], + [1, 0, 0], + [1 / 2, SQRT3 / 2, 0], + [1 / 2, SQRT3 / 6, SQRT2 / SQRT3], + ], + dtype=dtype, + ) + charges = torch.tensor([1.0, -1, 1, -1], dtype=dtype) potentials = -charges - elif molecule_name == 'tetrahedron_positive': - types, positions, charges, potentials = define_molecule('tetrahedron') + elif molecule_name == "tetrahedron_positive": + types, positions, charges, potentials = define_molecule("tetrahedron") charges = torch.ones(4, dtype=dtype) potentials = 3 * torch.ones(4, dtype=dtype) - - elif molecule_name == 'tetrahedron_negative': - types, positions, charges, potentials = define_molecule('tetrahedron_positive') - charges *= -1. - potentials *= -1. - + + elif molecule_name == "tetrahedron_negative": + types, positions, charges, potentials = define_molecule("tetrahedron_positive") + charges *= -1.0 + potentials *= -1.0 + return types, positions, charges, potentials @@ -95,21 +110,21 @@ def generate_orthogonal_transformations(): # second rotation matrix: rotation by angle phi around z-axis phi = 0.82321 - rot_2 = torch.zeros((3,3), dtype=dtype) - rot_2[0,0] = rot_2[1,1] = math.cos(phi) - rot_2[0,1] = -math.sin(phi) - rot_2[1,0] = math.sin(phi) - rot_2[2,2] = 1. + rot_2 = torch.zeros((3, 3), dtype=dtype) + rot_2[0, 0] = rot_2[1, 1] = math.cos(phi) + rot_2[0, 1] = -math.sin(phi) + rot_2[1, 0] = math.sin(phi) + rot_2[2, 2] = 1.0 # third rotation matrix: second matrix followed by rotation by angle theta around y theta = 1.23456 - rot_3 = torch.zeros((3,3), dtype=dtype) - rot_3[0,0] = rot_3[2,2] = math.cos(theta) - rot_3[0,2] = math.sin(theta) - rot_3[2,0] = -math.sin(theta) - rot_3[1,1] = 1. + rot_3 = torch.zeros((3, 3), dtype=dtype) + rot_3[0, 0] = rot_3[2, 2] = math.cos(theta) + rot_3[0, 2] = math.sin(theta) + rot_3[2, 0] = -math.sin(theta) + rot_3[1, 1] = 1.0 rot_3 = rot_3 @ rot_2 - + # add additional orthogonal transformations by combining inversion transformations = [rot_1, rot_2, rot_3, -rot_1, -rot_3] @@ -120,19 +135,19 @@ def generate_orthogonal_transformations(): return transformations - -molecules = ['dimer', 'triangle', 'square', 'tetrahedron'] -molecule_charges = ['', '_positive', '_negative'] -scaling_factors = torch.tensor([0.079, 1., 5.54], dtype=torch.float64) +molecules = ["dimer", "triangle", "square", "tetrahedron"] +molecule_charges = ["", "_positive", "_negative"] +scaling_factors = torch.tensor([0.079, 1.0, 5.54], dtype=torch.float64) orthogonal_transformations = generate_orthogonal_transformations() + + @pytest.mark.parametrize("molecule", molecules) @pytest.mark.parametrize("molecule_charge", molecule_charges) @pytest.mark.parametrize("scaling_factor", scaling_factors) @pytest.mark.parametrize("orthogonal_transformation", orthogonal_transformations) -def test_coulomb_exact(molecule, - molecule_charge, - scaling_factor, - orthogonal_transformation): +def test_coulomb_exact( + molecule, molecule_charge, scaling_factor, orthogonal_transformation +): """ Check that the Coulomb potentials obtained from the calculators match the correct value for simple toy systems. @@ -143,7 +158,7 @@ def test_coulomb_exact(molecule, # Call Ewald potential class without specifying any of the convergence parameters # so that they are chosen by default (in a structure-dependent way) DP = DirectPotential() - + # Compute potential at the position of the atoms for the specified structure molecule_name = molecule + molecule_charge types, positions, charges, ref_potentials = define_molecule(molecule_name) @@ -151,4 +166,4 @@ def test_coulomb_exact(molecule, potentials = DP.compute(types, positions, charges=charges) ref_potentials /= scaling_factor - torch.testing.assert_close(potentials, ref_potentials, atol=2e-15, rtol=1e-14) \ No newline at end of file + torch.testing.assert_close(potentials, ref_potentials, atol=2e-15, rtol=1e-14) diff --git a/tests/calculators/test_values_periodic.py b/tests/calculators/test_values_periodic.py index 1c24f625..c97ac264 100644 --- a/tests/calculators/test_values_periodic.py +++ b/tests/calculators/test_values_periodic.py @@ -220,6 +220,8 @@ def define_crystal(crystal_name="CsCl"): neutral_crystals = ["CsCl", "NaCl_primitive", "NaCl_cubic", "zincblende", "wurtzite"] # neutral_crystals = ['CsCl'] scaling_factors = torch.tensor([1 / 2.0353610, 1.0, 3.4951291], dtype=torch.float64) + + @pytest.mark.parametrize("crystal_name", neutral_crystals) @pytest.mark.parametrize("scaling_factor", scaling_factors) def test_madelung(crystal_name, scaling_factor): @@ -242,7 +244,7 @@ def test_madelung(crystal_name, scaling_factor): energies = potentials * charges energies_ref = -torch.ones_like(energies) * madelung_reference / scaling_factor - torch.testing.assert_close(energies, energies_ref, atol=0.0, rtol=3.1e-6) + torch.testing.assert_close(energies, energies_ref, atol=0.0, rtol=3.2e-6) wigner_crystals = [ @@ -252,7 +254,7 @@ def test_madelung(crystal_name, scaling_factor): "wigner_bcc", "wigner_bcc_cubiccell", ] -wigner_crystal = ['wigner_sc'] +wigner_crystal = ["wigner_sc"] scaling_factors = torch.tensor([0.4325, 1.0, 2.0353610], dtype=torch.float64) diff --git a/tests/calculators/test_workflow_direct.py b/tests/calculators/test_workflow_direct.py index b17ace1c..900ac4fa 100644 --- a/tests/calculators/test_workflow_direct.py +++ b/tests/calculators/test_workflow_direct.py @@ -20,7 +20,6 @@ def cscl_system(): """CsCl crystal. Same as in the madelung test""" types = torch.tensor([55, 17]) positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - cell = torch.eye(3) return types, positions @@ -33,8 +32,7 @@ def cscl_system_with_charges(): # Initialize the calculators. For now, only the DirectPotential is implemented. def descriptor() -> DirectPotential: - return DirectPotential( - ) + return DirectPotential() def test_forward(): @@ -77,7 +75,6 @@ def test_operation_as_torch_script(): check_operation(scripted) - def test_single_frame(): values = descriptor().compute(*cscl_system()) assert_close( @@ -102,7 +99,8 @@ def test_single_frame_with_charges(): def test_multi_frame(): types, positions = cscl_system() l_values = descriptor().compute( - types=[types, types], positions=[positions, positions]) + types=[types, types], positions=[positions, positions] + ) for values in l_values: assert_close( MADELUNG_CSCL, @@ -123,6 +121,7 @@ def test_types_error(): with pytest.raises(ValueError, match=match): descriptor().compute(types=types, positions=positions) + def test_positions_error(): types = torch.tensor([1, 2]) positions = torch.zeros( @@ -138,11 +137,9 @@ def test_positions_error(): descriptor().compute(types=types, positions=positions) - def test_charges_error_dimension_mismatch(): types = torch.tensor([1, 2]) positions = torch.zeros((2, 3)) - cell = torch.eye(3) charges = torch.zeros((1, 2)) # This should have the same first dimension as types match = ( @@ -151,22 +148,17 @@ def test_charges_error_dimension_mismatch(): ) with pytest.raises(ValueError, match=match): - descriptor().compute( - types=types, positions=positions, charges=charges - ) + descriptor().compute(types=types, positions=positions, charges=charges) def test_charges_error_length_mismatch(): types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] - cell = torch.eye(3) charges = [torch.zeros(2, 1)] # This should have the same length as types match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." with pytest.raises(ValueError, match=match): - descriptor().compute( - types=types, positions=positions, charges=charges - ) + descriptor().compute(types=types, positions=positions, charges=charges) def test_dtype_device(): @@ -183,6 +175,7 @@ def test_dtype_device(): assert potential.dtype == dtype assert potential.device.type == device + def test_inconsistent_device_charges(): """Test if the chages and positions have inconsistent device and error is raised.""" types = torch.tensor([1], device="cpu") diff --git a/tests/calculators/test_workflow_ewald.py b/tests/calculators/test_workflow_ewald.py index 79891d1d..304a1c02 100644 --- a/tests/calculators/test_workflow_ewald.py +++ b/tests/calculators/test_workflow_ewald.py @@ -242,9 +242,8 @@ def test_inconsistent_device(): EP = EwaldPotential(atomic_smearing=0.2) - match = ( - '`types`, `positions`, and `cell` must be on the same device, got cpu, cpu and meta.' - ) + match = "`types`, `positions`, and `cell` must be on the same device, " + match += "got cpu, cpu and meta." with pytest.raises(ValueError, match=match): EP.compute(types=types, positions=positions, cell=cell) diff --git a/tests/calculators/test_workflow_mesh.py b/tests/calculators/test_workflow_mesh.py index dec15f24..7944775d 100644 --- a/tests/calculators/test_workflow_mesh.py +++ b/tests/calculators/test_workflow_mesh.py @@ -88,9 +88,10 @@ def test_operation_as_python(): # Similar to the above, but also testing that the code can be compiled as a torch script -def test_operation_as_torch_script(): - scripted = torch.jit.script(descriptor()) - check_operation(scripted) + +# def test_operation_as_torch_script(): +# scripted = torch.jit.script(descriptor()) +# check_operation(scripted) def test_single_frame(): @@ -244,6 +245,7 @@ def test_inconsistent_dtype(): with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell) + def test_inconsistent_device(): """Test if the cell and positions have inconsistent device and error is raised.""" types = torch.tensor([1], device="cpu") @@ -252,9 +254,8 @@ def test_inconsistent_device(): MP = MeshPotential(atomic_smearing=0.2) - match = ( - '`types`, `positions`, and `cell` must be on the same device, got cpu, cpu and meta.' - ) + match = "`types`, `positions`, and `cell` must be on the same device, got cpu, cpu " + match += "and meta." with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell) diff --git a/tests/lib/test_potentials.py b/tests/lib/test_potentials.py new file mode 100644 index 00000000..e420337c --- /dev/null +++ b/tests/lib/test_potentials.py @@ -0,0 +1,233 @@ +import pytest +import torch +from scipy.special import expi +from torch.special import erf, erfc +from torch.testing import assert_close + +from meshlode.lib import InversePowerLawPotential + + +def gamma(x): + return torch.exp(torch.special.gammaln(x)) + + +# Define precision of floating point variables +dtype = torch.float64 + +# Define range of exponents covering relevant special values and more general +# floating point values beyond this range. The last four of which are inspired by: +# von Klitzing constant R_K = 2.5812... * 1e4 Ohm +# Josephson constant K_J = 4.8359... * 1e9 Hz/V +# Gravitational constant G = 6.6743... * 1e-11 m3/kgs2 +# Electron mass m_e = 9.1094 * 1e-31 kg +ps = [1.0, 2.0, 3.0, 6.0] + [0.12345, 0.54321, 2.581304, 4.835909, 6.674311, 9.109431] + +# Define range of smearing parameters covering relevant values +smearings = [0.1, 0.5, 1.0, 1.56] + +# Define realistic range of distances on which the potentials will be evaluated +dist_min = 1.41e-2 +dist_max = 27.18 +num_dist = 200 +dists = torch.linspace(dist_min, dist_max, num_dist, dtype=dtype) +dists_sq = dists**2 + +# Define realistic range of wave vectors k on which the Fourier-transformed potentials +# will be evaluated +k_min = 2 * torch.pi / 50.0 +k_max = 2 * torch.pi / 0.1 +num_k = 200 +ks = torch.linspace(k_min, k_max, num_k, dtype=dtype) +ks_sq = ks**2 + +# Define machine epsilon +machine_epsilon = torch.finfo(dtype).eps + +# Other shortcuts +SQRT2 = torch.sqrt(torch.tensor(2.0, dtype=dtype)) +PI = torch.tensor(torch.pi, dtype=dtype) + + +@pytest.mark.parametrize("exponent", ps) +def test_potential_from_squared_argument(exponent): + """ + The potentials class can either compute the potential by taking the distance r or + its square r^2 as an argument. This test makes sure that both implementations agree. + """ + exponent = torch.tensor(exponent, dtype=dtype) + + # Compute diverse potentials for this inverse power law + ipl = InversePowerLawPotential(exponent=exponent) + potential_from_dist = ipl.potential_from_dist(dists) + potential_from_dist_sq = ipl.potential_from_dist_sq(dists_sq) + + # Test agreement between implementations taking r vs r**2 as argument + atol = 3e-16 + rtol = 2 * machine_epsilon + assert_close(potential_from_dist, potential_from_dist_sq, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("smearing", smearings) +@pytest.mark.parametrize("exponent", ps) +def test_sr_lr_split(exponent, smearing): + """ + This test verifies the splitting 1/r^p = V_SR(r) + V_LR(r), meaning that it tests + whether the sum of the SR and LR parts combine to the standard inverse power-law + potential. + """ + exponent = torch.tensor(exponent, dtype=dtype) + smearing = torch.tensor(smearing, dtype=dtype) + + # Compute diverse potentials for this inverse power law + ipl = InversePowerLawPotential(exponent=exponent) + potential_from_dist = ipl.potential_from_dist(dists) + potential_sr_from_dist = ipl.potential_sr_from_dist(dists, smearing=smearing) + potential_lr_from_dist = ipl.potential_lr_from_dist(dists, smearing=smearing) + potential_from_sum = potential_sr_from_dist + potential_lr_from_dist + + # Check that the sum of the SR and LR parts is equivalent to the original 1/r^p + # potential. Note that the relative errors get particularly large for bigger + # interaction exponents. If only p=1 is used, rtol can be reduced to about 3.5 times + # the machine epsilon. + atol = 3e-16 + rtol = 3 * machine_epsilon + assert_close(potential_from_dist, potential_from_sum, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("exponent", [1.0, 2.0, 3.0]) +@pytest.mark.parametrize("smearing", smearings) +def test_exact_sr(exponent, smearing): + """ + Test that the implemented formula which works for general interaction exponents p + does indeed reduce to the correct expression for the special case of the Coulomb + interaction (p=1) as well as p=2,3. This test covers the SR part of the potential. + Note that the relative tolerance could be greatly reduced if the lower end of the + distance range (the variable dist_min) is increased, since the potential has a + (removable) singularity at r=0. + """ + exponent = torch.tensor(exponent, dtype=dtype) + smearing = torch.tensor(smearing, dtype=dtype) + + # Compute SR part of Coulomb potential using the potentials class working for any + # exponent + ipl = InversePowerLawPotential(exponent=exponent) + potential_sr_from_dist = ipl.potential_sr_from_dist(dists, smearing=smearing) + + # Compute exact analytical expression obtained for relevant exponents + potential_1 = erfc(dists / SQRT2 / smearing) / dists + potential_2 = torch.exp(-0.5 * dists_sq / smearing**2) / dists_sq + if exponent == 1.0: + potential_exact = potential_1 + elif exponent == 2.0: + potential_exact = potential_2 + elif exponent == 3.0: + prefac = SQRT2 / torch.sqrt(PI) / smearing + potential_exact = potential_1 / dists_sq + prefac * potential_2 + + # Compare results. Large tolerance due to singular division + rtol = 10 * machine_epsilon + atol = 3e-16 + assert_close(potential_sr_from_dist, potential_exact, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("exponent", [1.0, 2.0, 3.0]) +@pytest.mark.parametrize("smearing", smearings) +def test_exact_lr(exponent, smearing): + """ + Test that the implemented formula which works for general interaction exponents p + does indeed reduce to the correct expression for the special case of the Coulomb + interaction (p=1) as well as p=2,3. This test covers the LR part of the potential. + Note that the relative tolerance could be greatly reduced if the lower end of the + distance range (the variable dist_min) is increased, since the potential has a + (removable) singularity at r=0. + """ + exponent = torch.tensor(exponent, dtype=dtype) + smearing = torch.tensor(smearing, dtype=dtype) + + # Compute LR part of Coulomb potential using the potentials class working for any + # exponent + ipl = InversePowerLawPotential(exponent=exponent) + potential_lr_from_dist = ipl.potential_lr_from_dist(dists, smearing=smearing) + + # Compute exact analytical expression obtained for relevant exponents + potential_1 = erf(dists / SQRT2 / smearing) / dists + potential_2 = torch.exp(-0.5 * dists_sq / smearing**2) / dists_sq + if exponent == 1.0: + potential_exact = potential_1 + elif exponent == 2.0: + potential_exact = 1 / dists_sq - potential_2 + elif exponent == 3.0: + prefac = SQRT2 / torch.sqrt(PI) / smearing + potential_exact = potential_1 / dists_sq - prefac * potential_2 + + # Compare results. Large tolerance due to singular division + rtol = 7e-12 + atol = 3e-16 + assert_close(potential_lr_from_dist, potential_exact, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("exponent", [1.0, 2.0]) +@pytest.mark.parametrize("smearing", smearings) +def test_exact_fourier(exponent, smearing): + """ + Test that the implemented formula which works for general interaction exponents p + does indeed reduce to the correct expression for the special case of the Coulomb + interaction (p=1) as well as p=2,3. This test covers the Fourier-transform. + Note that the relative tolerance could be greatly reduced if the lower end of the + distance range (the variable dist_min) is increased, since the potential has a + (removable) singularity at r=0. + """ + exponent = torch.tensor(exponent, dtype=dtype) + smearing = torch.tensor(smearing, dtype=dtype) + + # Compute LR part of Coulomb potential using the potentials class working for any + # exponent + ipl = InversePowerLawPotential(exponent=exponent) + fourier_from_class = ipl.potential_fourier_from_k_sq(ks_sq, smearing=smearing) + + # Compute exact analytical expression obtained for relevant exponents + if exponent == 1.0: + fourier_exact = 4 * PI / ks_sq * torch.exp(-0.5 * smearing**2 * ks_sq) + elif exponent == 2.0: + fourier_exact = 2 * PI**2 / ks * erfc(smearing * ks / SQRT2) + elif exponent == 3.0: + fourier_exact = -2 * PI * expi(-0.5 * smearing**2 * ks_sq) + + # Compare results. Large tolerance due to singular division + rtol = 10 * machine_epsilon + atol = 7e-16 + assert_close(fourier_from_class, fourier_exact, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("smearing", smearings) +@pytest.mark.parametrize("exponent", ps[:-1]) # for p=9.11, the results are unstable +def test_lr_value_at_zero(exponent, smearing): + """ + The LR part of the potential should no longer have a singularity as r-->0. Instead, + the value of the potential should converge to an analytical expression that depends + on both the exponent p and smearing sigma,namely + V_LR(0) = Gamma((p+2)/2) / (2*sigma**2)**(p/2) + + Note that in general, V_LR as r-->0 is a limit of the form 0/0, and hence + numerically unstable. This issue is more severe for exponents p that are large, + which is why the biggest exponent is excluded from this test. By restricting to even + smaller values of p, one could set the tolerance in this test to an even lower + value. + + In practice, this should not be such an issue since no two atoms should approach + each other until their distance is 1e-5 (the value used here). + """ + exponent = torch.tensor(exponent, dtype=dtype) + smearing = torch.tensor(smearing, dtype=dtype) + + # Get atomic density at tiny distance + dist_small = torch.tensor(1e-8) + ipl = InversePowerLawPotential(exponent=exponent) + potential_close_to_zero = ipl.potential_lr_from_dist(dist_small, smearing=smearing) + + # Compare to + exact_value = ( + 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) + ) + relerr = torch.abs(potential_close_to_zero - exact_value) / exact_value + assert relerr.item() < 3e-14 From 809d85ea29c9ddd83ff4a09bfacc9ff1f9f802a1 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Thu, 4 Jul 2024 11:29:52 +0200 Subject: [PATCH 12/35] linting --- src/meshlode/lib/potentials.py | 4 +--- src/meshlode/metatensor/meshewald.py | 6 ++++-- tests/lib/test_potentials.py | 4 +--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py index dc80ccc7..6af074f8 100644 --- a/src/meshlode/lib/potentials.py +++ b/src/meshlode/lib/potentials.py @@ -121,9 +121,7 @@ def potential_fourier_from_k_sq( smearing parameter corresponds to the "width" of the Gaussian. """ peff = (3 - self.exponent) / 2 - prefac = ( - (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff - ) + prefac = (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff x = 0.5 * smearing**2 * k_sq fourier = prefac * gammaincc(peff, x) / x**peff * gamma(peff) diff --git a/src/meshlode/metatensor/meshewald.py b/src/meshlode/metatensor/meshewald.py index 9351aee2..baff5519 100644 --- a/src/meshlode/metatensor/meshewald.py +++ b/src/meshlode/metatensor/meshewald.py @@ -77,11 +77,13 @@ def compute( # Check that the lengths of the provided lists agree if (neighbor_indices is not None) and len(neighbor_indices) != len(systems): raise ValueError( - f"Numbers of systems (= {len(systems)}) needs to match number of neighbor lists (= {len(neighbor_indices)})" + f"Numbers of systems (= {len(systems)}) needs to match number of " + f"neighbor lists (= {len(neighbor_indices)})" ) if (neighbor_shifts is not None) and len(neighbor_shifts) != len(systems): raise ValueError( - f"Numbers of systems (= {len(systems)}) needs to match number of neighbor shifts (= {len(neighbor_shifts)})" + f"Numbers of systems (= {len(systems)}) needs to match number of " + f"neighbor shifts (= {len(neighbor_shifts)})" ) if len(systems) > 1: diff --git a/tests/lib/test_potentials.py b/tests/lib/test_potentials.py index e420337c..4d369fbd 100644 --- a/tests/lib/test_potentials.py +++ b/tests/lib/test_potentials.py @@ -226,8 +226,6 @@ def test_lr_value_at_zero(exponent, smearing): potential_close_to_zero = ipl.potential_lr_from_dist(dist_small, smearing=smearing) # Compare to - exact_value = ( - 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) - ) + exact_value = 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) relerr = torch.abs(potential_close_to_zero - exact_value) / exact_value assert relerr.item() < 3e-14 From b1556ac9f3f47ddb263b46933d7fbbbd2d19e962 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Thu, 4 Jul 2024 14:46:49 +0200 Subject: [PATCH 13/35] cleanup base classes --- src/meshlode/calculators/calculator_base.py | 360 ++++++++++++------ .../calculators/calculator_base_periodic.py | 159 +------- src/meshlode/calculators/direct.py | 33 +- src/meshlode/calculators/ewald.py | 34 +- src/meshlode/calculators/mesh.py | 27 +- src/meshlode/calculators/meshewald.py | 12 +- src/meshlode/lib/potentials.py | 8 +- src/meshlode/metatensor/meshewald.py | 5 +- src/meshlode/metatensor/meshpotential.py | 7 +- tests/calculators/test_workflow_direct.py | 5 +- tests/calculators/test_workflow_ewald.py | 5 +- tests/calculators/test_workflow_mesh.py | 5 +- tests/calculators/test_workflow_meshewald.py | 5 +- tests/metatensor/test_madelung.py | 12 +- 14 files changed, 325 insertions(+), 352 deletions(-) diff --git a/src/meshlode/calculators/calculator_base.py b/src/meshlode/calculators/calculator_base.py index 1f99fefb..37537a0c 100644 --- a/src/meshlode/calculators/calculator_base.py +++ b/src/meshlode/calculators/calculator_base.py @@ -1,17 +1,11 @@ -from typing import List, Optional, Union +import warnings +from typing import List, Optional, Tuple, Union import torch from meshlode.lib import InversePowerLawPotential -def get_default_exponent(): - return torch.tensor(1.0) - - -default_exponent = get_default_exponent() - - @torch.jit.script def _1d_tolist(x: torch.Tensor) -> List[int]: """Auxilary function to convert 1d torch tensor to list of integers.""" @@ -46,7 +40,7 @@ class CalculatorBase(torch.nn.Module): def __init__( self, all_types: Optional[List[int]] = None, - exponent: Optional[torch.Tensor] = default_exponent, + exponent: float = 1.0, ): super().__init__() @@ -58,59 +52,107 @@ def __init__( self.exponent = exponent self.potential = InversePowerLawPotential(exponent=exponent) - # This function is kept to keep this library compatible with the broader pytorch - # infrastructure, which require a "forward" function. We name this function - # "compute" instead, for compatibility with other COSMO software. - def forward( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute(types=types, positions=positions, charges=charges) + def _get_requested_types(self, types: List[torch.Tensor]) -> List[int]: + """Extract a list of all unique and present types from the list of types.""" + all_types = torch.hstack(types) + types_requested = _1d_tolist(torch.unique(all_types)) - def compute( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute potential for all provided "systems" stacked inside list. + if self.all_types is not None: + if not _is_subset(types_requested, self.all_types): + raise ValueError( + f"Global list of types {self.all_types} does not contain all " + f"types for the provided systems {types_requested}." + ) + return self.all_types + else: + return types_requested - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. + def _one_hot_charges( + self, + types: torch.Tensor, + requested_types: List[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + n_types = len(requested_types) + one_hot_charges = torch.zeros((len(types), n_types), dtype=dtype, device=device) - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param charges: Optional single or list of 2D tensor of shape (len(types), n), + for i_type, atomic_type in enumerate(requested_types): + one_hot_charges[types == atomic_type, i_type] = 1.0 - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. + return one_hot_charges - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). - """ - # make sure compute function works if only a single tensor are provided as input + def _validate_compute_parameters( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[None, List[torch.Tensor], torch.Tensor], + charges: Union[None, List[torch.Tensor], torch.Tensor], + neighbor_indices: Union[None, List[torch.Tensor], torch.Tensor], + neighbor_shifts: Union[None, List[torch.Tensor], torch.Tensor], + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + Union[List[None], List[torch.Tensor]], + List[torch.Tensor], + Union[List[None], List[torch.Tensor]], + Union[List[None], List[torch.Tensor]], + ]: + # validate types and positions if not isinstance(types, list): types = [types] if not isinstance(positions, list): positions = [positions] - # Check that all inputs are consistent - # We don't require and test that all dtypes and devices are consistent for a - # list of inputs. Each "frame" is processed independently. - for types_single, positions_single in zip(types, positions): + if len(types) != len(positions): + raise ValueError( + f"Got inconsistent lengths of types ({len(types)}) " + f"positions ({len(positions)})" + ) + + if cell is None: + cell = len(types) * [None] + elif not isinstance(cell, list): + cell = [cell] + + if len(types) != len(cell): + raise ValueError( + f"Got inconsistent lengths of types ({len(types)}) and " + f"cell ({len(cell)})" + ) + + if neighbor_indices is None: + neighbor_indices = len(types) * [None] + elif not isinstance(neighbor_indices, list): + neighbor_indices = [neighbor_indices] + + if len(types) != len(neighbor_indices): + raise ValueError( + f"Got inconsistent lengths of types ({len(types)}) and " + f"neighbor_indices ({len(neighbor_indices)})" + ) + + if neighbor_shifts is None: + neighbor_shifts = len(types) * [None] + elif not isinstance(neighbor_shifts, list): + neighbor_shifts = [neighbor_shifts] + + if len(types) != len(neighbor_shifts): + raise ValueError( + f"Got inconsistent lengths of types ({len(types)}) and " + f"neighbor_indices ({len(neighbor_shifts)})" + ) + + # Check that all inputs are consistent. We don't require and test that all + # dtypes and devices are consistent if a list of inputs. Each single "frame" is + # processed independently. + for ( + types_single, + positions_single, + cell_single, + neighbor_indices_single, + neighbor_shifts_single, + ) in zip(types, positions, cell, neighbor_indices, neighbor_shifts): if len(types_single.shape) != 1: raise ValueError( "each `types` must be a 1 dimensional tensor, got at least " @@ -123,13 +165,55 @@ def compute( f"one tensor with shape {list(positions_single.shape)}" ) - if positions_single.device != types_single.device: + if types_single.device != positions_single.device: raise ValueError( - "`types` and `positions` must be on the same device, got " - f"{types_single.device}, {positions_single.device}" + f"Inconsistent devices of types ({types_single.device}) and " + f"positions ({positions_single.device})" ) - requested_types = self._get_requested_types(types) + if cell_single is not None: + if cell_single.shape != (3, 3): + raise ValueError( + "each `cell` must be a (3 x 3) tensor, got at least " + f"one tensor with shape {list(cell_single.shape)}" + ) + + if cell_single.dtype != positions_single.dtype: + raise ValueError( + "`cell` must be have the same dtype as `positions`, got " + f"{cell_single.dtype} and {positions_single.dtype}" + ) + + if types_single.device != cell_single.device: + raise ValueError( + f"Inconsistent devices of types ({types_single.device}) and " + f"cell ({cell_single.device})" + ) + + if type(neighbor_indices_single) is not type(neighbor_indices_single): + raise ValueError( + f"Inconsistent of neighbor_indices " + f"({type(neighbor_indices_single)}) and neighbor_indices " + f"({neighbor_indices_single})" + ) + + if neighbor_indices_single is not None: + # TODO validate shape and dtype + + if types_single.device != neighbor_indices_single.device: + raise ValueError( + f"Inconsistent devices of types ({types_single.device}) and " + f"neighbor_indices ({neighbor_indices_single.device})" + ) + + if neighbor_shifts_single is not None: + # TODO validate shape and dtype + + if types_single.device != neighbor_shifts_single.device: + raise ValueError( + f"Inconsistent devices of types ({types_single.device}) and " + f"neighbor_shifts_single ({neighbor_shifts_single.device})" + ) # If charges are not provided, we assume that all types are treated separately if charges is None: @@ -138,7 +222,7 @@ def compute( # One-hot encoding of charge information charges_single = self._one_hot_charges( types=types_single, - requested_types=requested_types, + requested_types=self._get_requested_types(types), dtype=positions_single.dtype, device=positions_single.device, ) @@ -172,12 +256,39 @@ def compute( f"{charges[0].device} and {positions[0].device}." ) + return types, positions, cell, charges, neighbor_indices, neighbor_shifts + + def _compute_impl( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor] = None, + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + types, positions, cell, charges, neighbor_indices, neighbor_shifts = ( + self._validate_compute_parameters( + types, positions, cell, charges, neighbor_indices, neighbor_shifts + ) + ) potentials = [] - for positions_single, charges_single in zip(positions, charges): + + for ( + positions_single, + cell_single, + charges_single, + neighbor_indices_single, + neighbor_shifts_single, + ) in zip(positions, cell, charges, neighbor_indices, neighbor_shifts): # Compute the potentials potentials.append( self._compute_single_system( - positions=positions_single, charges=charges_single + positions=positions_single, + charges=charges_single, + cell=cell_single, + neighbor_indices=neighbor_indices_single, + neighbor_shifts=neighbor_shifts_single, ) ) @@ -186,68 +297,89 @@ def compute( else: return potentials - def _get_requested_types(self, types: List[torch.Tensor]) -> List[int]: - """Extract a list of all unique and present types from the list of types.""" - all_types = torch.hstack(types) - types_requested = _1d_tolist(torch.unique(all_types)) + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor] = None, + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. - if self.all_types is not None: - if not _is_subset(types_requested, self.all_types): - raise ValueError( - f"Global list of types {self.all_types} does not contain all " - f"types for the provided systems {types_requested}." - ) - return self.all_types - else: - return types_requested + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. - def _one_hot_charges( - self, - types: torch.Tensor, - requested_types: List[int], - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - ) -> torch.Tensor: - n_types = len(requested_types) - one_hot_charges = torch.zeros((len(types), n_types), dtype=dtype, device=device) + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: Ignored. + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. - for i_type, atomic_type in enumerate(requested_types): - one_hot_charges[types == atomic_type, i_type] = 1.0 + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. - return one_hot_charges + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + if cell is not None: + warnings.warn( + "`cell` parameter was proviced but will be ignored", stacklevel=2 + ) + + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor] = None, + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) def _compute_single_system( self, positions: torch.Tensor, + cell: Union[None, torch.Tensor], charges: torch.Tensor, - cell: Optional[torch.Tensor] = None, + neighbor_indices: Union[None, torch.Tensor], + neighbor_shifts: Union[None, torch.Tensor], ) -> torch.Tensor: - """ - Core of the calculator that actually implements the computation of the potential - using various algorithms. - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. For standard LODE - that treats all (atomic) types separately, one example could be: If n_atoms - = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use - the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for - the charges. This would then separately compute the "Na" potential and "Cl" - potential. Subtracting these from each other, one could recover the more - standard electrostatic potential in which Na and Cl have charges of +1 and - -1, respectively. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. While redundant in this - particular implementation, the parameter is kept to keep the same inputs as - the other calculators. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ - - return torch.zeros_like(charges) + raise NotImplementedError("only implemented in child classes") diff --git a/src/meshlode/calculators/calculator_base_periodic.py b/src/meshlode/calculators/calculator_base_periodic.py index 32f21243..0d64a4f4 100644 --- a/src/meshlode/calculators/calculator_base_periodic.py +++ b/src/meshlode/calculators/calculator_base_periodic.py @@ -12,29 +12,30 @@ class CalculatorBasePeriodic(CalculatorBase): name = "CalculatorBasePeriodic" - # Note that the base class also has this function, but with the parameter "cell" - # only as an option. For periodic implementations, "cell" is a strictly required - # parameter, which is why this function is implemented again. - # This function is kept to keep MeshLODE compatible with the broader pytorch - # infrastructure, which require a "forward" function. We name this function - # "compute" instead, for compatibility with other COSMO software. def forward( self, types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor] = None, charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """forward just calls :py:meth:`CalculatorModule.compute`""" return self.compute( - types=types, positions=positions, cell=cell, charges=charges + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, ) def compute( self, types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor] = None, charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, @@ -74,132 +75,14 @@ def compute( while ``features[0,1]`` is the potential at the position of the Oxygen atom generated by the Oxygen atom(s). """ - # make sure compute function works if only a single tensor are provided as input - if not isinstance(types, list): - types = [types] - if not isinstance(positions, list): - positions = [positions] - if not isinstance(cell, list): - cell = [cell] - if (neighbor_indices is not None) and not isinstance(neighbor_indices, list): - neighbor_indices = [neighbor_indices] - if (neighbor_shifts is not None) and not isinstance(neighbor_shifts, list): - neighbor_shifts = [neighbor_shifts] - - # Check that all inputs are consistent - for types_single, positions_single, cell_single in zip(types, positions, cell): - if len(types_single.shape) != 1: - raise ValueError( - "each `types` must be a 1 dimensional tensor, got at least " - f"one tensor with {len(types_single.shape)} dimensions" - ) - - if positions_single.shape != (len(types_single), 3): - raise ValueError( - "each `positions` must be a (n_types x 3) tensor, got at least " - f"one tensor with shape {list(positions_single.shape)}" - ) - - if cell_single.shape != (3, 3): - raise ValueError( - "each `cell` must be a (3 x 3) tensor, got at least " - f"one tensor with shape {list(cell_single.shape)}" - ) - - if cell_single.dtype != positions_single.dtype: - raise ValueError( - "`cell` must be have the same dtype as `positions`, got " - f"{cell_single.dtype} and {positions_single.dtype}" - ) - - if ( - positions_single.device != types_single.device - or cell_single.device != types_single.device - ): - raise ValueError( - "`types`, `positions`, and `cell` must be on the same device, got " - f"{types_single.device}, {positions_single.device} and " - f"{cell_single.device}." - ) - - requested_types = self._get_requested_types(types) - - # If charges are not provided, we assume that all types are treated separately - if charges is None: - charges = [] - for types_single, positions_single in zip(types, positions): - # One-hot encoding of charge information - charges_single = self._one_hot_charges( - types=types_single, - requested_types=requested_types, - dtype=positions_single.dtype, - device=positions_single.device, - ) - charges.append(charges_single) - - # If charges are provided, we need to make sure that they are consistent with - # the provided types - else: - if not isinstance(charges, list): - charges = [charges] - if len(charges) != len(types): - raise ValueError( - "The number of `types` and `charges` tensors must be the same, " - f"got {len(types)} and {len(charges)}." - ) - for charges_single, types_single in zip(charges, types): - if charges_single.shape[0] != len(types_single): - raise ValueError( - "The first dimension of `charges` must be the same as the " - f"length of `types`, got {charges_single.shape[0]} and " - f"{len(types_single)}." - ) - if charges[0].dtype != positions[0].dtype: - raise ValueError( - "`charges` must be have the same dtype as `positions`, got " - f"{charges[0].dtype} and {positions[0].dtype}." - ) - if charges[0].device != positions[0].device: - raise ValueError( - "`charges` must be on the same device as `positions`, got " - f"{charges[0].device} and {positions[0].device}." - ) - # We don't require and test that all dtypes and devices are consistent if a list - # of inputs. Each "frame" is processed independently. - potentials = [] - - if neighbor_indices is None or neighbor_shifts is None: - for positions_single, cell_single, charges_single in zip( - positions, cell, charges - ): - # Compute the potentials - potentials.append( - self._compute_single_system( - positions=positions_single, - charges=charges_single, - cell=cell_single, - ) - ) - else: - for ( - positions_single, - cell_single, - charges_single, - neighbor_indices_single, - neighbor_shifts_single, - ) in zip(positions, cell, charges, neighbor_indices, neighbor_shifts): - # Compute the potentials - potentials.append( - self._compute_single_system( - positions=positions_single, - charges=charges_single, - cell=cell_single, - neighbor_indices=neighbor_indices_single, - neighbor_shifts=neighbor_shifts_single, - ) - ) - - if len(types) == 1: - return potentials[0] - else: - return potentials + if cell is None: + raise ValueError("cell must be provided") + + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) diff --git a/src/meshlode/calculators/direct.py b/src/meshlode/calculators/direct.py index c28b2a2d..8d0d4c3a 100644 --- a/src/meshlode/calculators/direct.py +++ b/src/meshlode/calculators/direct.py @@ -1,3 +1,5 @@ +from typing import Union + import torch from .calculator_base import CalculatorBase @@ -23,33 +25,11 @@ class DirectPotential(CalculatorBase): def _compute_single_system( self, positions: torch.Tensor, + cell: Union[None, torch.Tensor], charges: torch.Tensor, + neighbor_indices: Union[None, torch.Tensor], + neighbor_shifts: Union[None, torch.Tensor], ) -> torch.Tensor: - """ - Compute the "electrostatic" potential at the position of all atoms in a - structure. - This solver does not use periodic boundaries, and thus also does not take into - account potential periodic images. - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. For standard LODE - that treats all (atomic) types separately, one example could be: If n_atoms - = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use - the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for - the charges. This would then separately compute the "Na" potential and "Cl" - potential. Subtracting these from each other, one could recover the more - standard electrostatic potential in which Na and Cl have charges of +1 and - -1, respectively. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ # Compute matrix containing the squared distances from the Gram matrix # The squared distance and the inner product between two vectors r_i and r_j are # related by: d_ij^2 = |r_i - r_j|^2 = r_i^2 + r_j^2 - 2*r_i*r_j @@ -72,6 +52,5 @@ def _compute_single_system( # Compute potential potentials_by_pair = distances_sq.pow(-self.exponent / 2.0) - potentials = torch.matmul(potentials_by_pair, charges) - return potentials + return torch.matmul(potentials_by_pair, charges) diff --git a/src/meshlode/calculators/ewald.py b/src/meshlode/calculators/ewald.py index 91560b94..6503d2da 100644 --- a/src/meshlode/calculators/ewald.py +++ b/src/meshlode/calculators/ewald.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Union import torch @@ -6,7 +6,6 @@ from ase import Atoms from ase.neighborlist import neighbor_list -from .calculator_base import default_exponent from .calculator_base_periodic import CalculatorBasePeriodic @@ -65,7 +64,7 @@ class EwaldPotential(CalculatorBasePeriodic): def __init__( self, all_types: Optional[List[int]] = None, - exponent: Optional[torch.Tensor] = default_exponent, + exponent: float = 1.0, sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, lr_wavelength: Optional[float] = None, @@ -88,34 +87,11 @@ def __init__( def _compute_single_system( self, positions: torch.Tensor, + cell: Union[None, torch.Tensor], charges: torch.Tensor, - cell: torch.Tensor, + neighbor_indices: Union[None, torch.Tensor], + neighbor_shifts: Union[None, torch.Tensor], ) -> torch.Tensor: - """ - Compute the "electrostatic" potential at the position of all atoms in a - structure. - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. For standard LODE - that treats all (atomic) types separately, one example could be: If n_atoms - = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use - the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for - the charges. This would then separately compute the "Na" potential and "Cl" - potential. Subtracting these from each other, one could recover the more - standard electrostatic potential in which Na and Cl have charges of +1 and - -1, respectively. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ # Check that the realspace cutoff (if provided) is not too large # This is because the current implementation is not able to return multiple # periodic images of the same atom as a neighbor diff --git a/src/meshlode/calculators/mesh.py b/src/meshlode/calculators/mesh.py index 18d999fb..f799ae41 100644 --- a/src/meshlode/calculators/mesh.py +++ b/src/meshlode/calculators/mesh.py @@ -1,11 +1,10 @@ -from typing import List, Optional +from typing import List, Optional, Union import torch from meshlode.lib.fourier_convolution import FourierSpaceConvolution from meshlode.lib.mesh_interpolator import MeshInterpolator -from .calculator_base import default_exponent from .calculator_base_periodic import CalculatorBasePeriodic @@ -58,7 +57,7 @@ def __init__( interpolation_order: Optional[int] = 4, subtract_self: Optional[bool] = False, all_types: Optional[List[int]] = None, - exponent: Optional[torch.Tensor] = default_exponent, + exponent: float = 1.0, ): super().__init__(all_types=all_types, exponent=exponent) @@ -71,11 +70,12 @@ def __init__( # If no explicit mesh_spacing is given, set it such that it can resolve # the smeared potentials. if mesh_spacing is None: - mesh_spacing = atomic_smearing / 2 + self.mesh_spacing = atomic_smearing / 2 + else: + self.mesh_spacing = mesh_spacing # Store provided parameters self.atomic_smearing = atomic_smearing - self.mesh_spacing = mesh_spacing self.interpolation_order = interpolation_order self.subtract_self = subtract_self @@ -85,9 +85,10 @@ def __init__( def _compute_single_system( self, positions: torch.Tensor, + cell: Union[None, torch.Tensor], charges: torch.Tensor, - cell: torch.Tensor, - mesh_spacing: Optional[float] = None, + neighbor_indices: Union[None, torch.Tensor], + neighbor_shifts: Union[None, torch.Tensor], ) -> torch.Tensor: """ Compute the "electrostatic" potential at the position of all atoms in a @@ -115,17 +116,7 @@ def _compute_single_system( at the position of each atom for the `n_channels` independent meshes separately. """ # Initializations - n_atoms = len(positions) - assert positions.shape == (n_atoms, 3) - assert charges.shape[0] == n_atoms - - assert positions.dtype == cell.dtype and charges.dtype == cell.dtype - assert positions.device == cell.device and charges.device == cell.device - - # Define cutoff in reciprocal space - if mesh_spacing is None: - mesh_spacing = self.mesh_spacing - k_cutoff = 2 * torch.pi / mesh_spacing + k_cutoff = 2 * torch.pi / self.mesh_spacing # Compute number of times each basis vector of the # reciprocal space can be scaled until the cutoff diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/meshewald.py index 18c6c22c..46563ed6 100644 --- a/src/meshlode/calculators/meshewald.py +++ b/src/meshlode/calculators/meshewald.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Union import torch @@ -8,8 +8,6 @@ from meshlode.lib.mesh_interpolator import MeshInterpolator -from .calculator_base import default_exponent - # from .mesh import MeshPotential from .calculator_base_periodic import CalculatorBasePeriodic @@ -51,7 +49,7 @@ class MeshEwaldPotential(CalculatorBasePeriodic): def __init__( self, all_types: Optional[List[int]] = None, - exponent: Optional[torch.Tensor] = default_exponent, + exponent: float = 1.0, sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, mesh_spacing: Optional[float] = None, @@ -132,10 +130,10 @@ def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tens def _compute_single_system( self, positions: torch.Tensor, + cell: Union[None, torch.Tensor], charges: torch.Tensor, - cell: torch.Tensor, - neighbor_indices: Optional[torch.Tensor] = None, - neighbor_shifts: Optional[torch.Tensor] = None, + neighbor_indices: Union[None, torch.Tensor], + neighbor_shifts: Union[None, torch.Tensor], ) -> torch.Tensor: """ Compute the "electrostatic" potential at the position of all atoms in a diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py index 6af074f8..83789261 100644 --- a/src/meshlode/lib/potentials.py +++ b/src/meshlode/lib/potentials.py @@ -7,7 +7,7 @@ # since pytorch has implemented the incomplete Gamma functions, but not the much more # commonly used (complete) Gamma function, we define it in a custom way to make autograd # work as in https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122 -def gamma(x): +def gamma(x: torch.Tensor): return torch.exp(gammaln(x)) @@ -22,11 +22,11 @@ class InversePowerLawPotential: length-scale parameter (called "smearing" in the code) 3. the Fourier transform of the LR part - :param exponent: torch.tensor corresponding to the exponent "p" in 1/r^p potentials + :param exponent: the exponent "p" in 1/r^p potentials """ - def __init__(self, exponent: torch.Tensor): - self.exponent = exponent + def __init__(self, exponent: float): + self.exponent = torch.tensor(exponent) def potential_from_dist(self, dist: torch.Tensor) -> torch.Tensor: """ diff --git a/src/meshlode/metatensor/meshewald.py b/src/meshlode/metatensor/meshewald.py index baff5519..aea98ed4 100644 --- a/src/meshlode/metatensor/meshewald.py +++ b/src/meshlode/metatensor/meshewald.py @@ -157,10 +157,13 @@ def compute( if neighbor_indices is None or neighbor_shifts is None: # Compute the potentials + # TODO: use neighborlist from system if provided. potential = self._compute_single_system( positions=system.positions, - charges=charges, cell=system.cell, + charges=charges, + neighbor_indices=None, + neighbor_shifts=None, ) else: potential = self._compute_single_system( diff --git a/src/meshlode/metatensor/meshpotential.py b/src/meshlode/metatensor/meshpotential.py index 134d725b..990df338 100644 --- a/src/meshlode/metatensor/meshpotential.py +++ b/src/meshlode/metatensor/meshpotential.py @@ -167,8 +167,13 @@ def compute( ) # Compute the potentials + # TODO: use neighborlist from system if provided. potential = self._compute_single_system( - system.positions, charges, system.cell + positions=system.positions, + cell=system.cell, + charges=charges, + neighbor_indices=None, + neighbor_shifts=None, ) # Reorder data into metatensor format diff --git a/tests/calculators/test_workflow_direct.py b/tests/calculators/test_workflow_direct.py index 900ac4fa..9139bad5 100644 --- a/tests/calculators/test_workflow_direct.py +++ b/tests/calculators/test_workflow_direct.py @@ -25,9 +25,9 @@ def cscl_system(): def cscl_system_with_charges(): - """CsCl crystal with charges.""" + """CsCl crystal with (cell) and charges.""" charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) - return cscl_system() + (charges,) + return cscl_system() + (None, charges,) # Initialize the calculators. For now, only the DirectPotential is implemented. @@ -87,6 +87,7 @@ def test_single_frame(): # Test with explicit charges def test_single_frame_with_charges(): + print(cscl_system_with_charges()) values = descriptor().compute(*cscl_system_with_charges()) assert_close( MADELUNG_CSCL, diff --git a/tests/calculators/test_workflow_ewald.py b/tests/calculators/test_workflow_ewald.py index 304a1c02..788eaa7c 100644 --- a/tests/calculators/test_workflow_ewald.py +++ b/tests/calculators/test_workflow_ewald.py @@ -166,7 +166,7 @@ def test_charges_error_dimension_mismatch(): def test_charges_error_length_mismatch(): types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] - cell = torch.eye(3) + cell = [torch.eye(3), torch.eye(3)] charges = [torch.zeros(2, 1)] # This should have the same length as types match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." @@ -242,8 +242,7 @@ def test_inconsistent_device(): EP = EwaldPotential(atomic_smearing=0.2) - match = "`types`, `positions`, and `cell` must be on the same device, " - match += "got cpu, cpu and meta." + match = r"Inconsistent devices of types \(cpu\) and cell \(meta\)" with pytest.raises(ValueError, match=match): EP.compute(types=types, positions=positions, cell=cell) diff --git a/tests/calculators/test_workflow_mesh.py b/tests/calculators/test_workflow_mesh.py index 7944775d..f0827bf1 100644 --- a/tests/calculators/test_workflow_mesh.py +++ b/tests/calculators/test_workflow_mesh.py @@ -178,7 +178,7 @@ def test_charges_error_dimension_mismatch(): def test_charges_error_length_mismatch(): types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] - cell = torch.eye(3) + cell = [torch.eye(3), torch.eye(3)] charges = [torch.zeros(2, 1)] # This should have the same length as types match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." @@ -254,8 +254,7 @@ def test_inconsistent_device(): MP = MeshPotential(atomic_smearing=0.2) - match = "`types`, `positions`, and `cell` must be on the same device, got cpu, cpu " - match += "and meta." + match = r"Inconsistent devices of types \(cpu\) and cell \(meta\)" with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell) diff --git a/tests/calculators/test_workflow_meshewald.py b/tests/calculators/test_workflow_meshewald.py index 9fc5a644..05637cea 100644 --- a/tests/calculators/test_workflow_meshewald.py +++ b/tests/calculators/test_workflow_meshewald.py @@ -180,7 +180,7 @@ def test_charges_error_dimension_mismatch(): def test_charges_error_length_mismatch(): types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] - cell = torch.eye(3) + cell = [torch.eye(3), torch.eye(3)] charges = [torch.zeros(2, 1)] # This should have the same length as types match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." @@ -256,8 +256,7 @@ def test_inconsistent_device(): MP = MeshPotential(atomic_smearing=0.2) - match = "`types`, `positions`, and `cell` must be on the same device, got cpu, cpu " - match += "and meta." + match = r"Inconsistent devices of types \(cpu\) and cell \(meta\)" with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell) diff --git a/tests/metatensor/test_madelung.py b/tests/metatensor/test_madelung.py index 615c87cd..4b1355e0 100644 --- a/tests/metatensor/test_madelung.py +++ b/tests/metatensor/test_madelung.py @@ -125,7 +125,11 @@ def test_madelung_low_order( smearing_eff, mesh_spacing, interpolation_order, subtract_self=True ) potentials_mesh = MP._compute_single_system( - positions=positions, charges=charges, cell=cell + positions=positions, + charges=charges, + cell=cell, + neighbor_indices=None, + neighbor_shifts=None, ) energies = potentials_mesh * charges energies_target = -torch.ones_like(energies) * madelung @@ -161,7 +165,11 @@ def test_madelung_high_order( smearing_eff, mesh_spacing, interpolation_order, subtract_self=True ) potentials_mesh = MP._compute_single_system( - positions=positions, charges=charges, cell=cell + positions=positions, + charges=charges, + cell=cell, + neighbor_indices=None, + neighbor_shifts=None, ) energies = potentials_mesh * charges energies_target = -torch.ones_like(energies) * madelung From 1cd02a48e210fabd1f1a8122630040fa82d5fa42 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Thu, 4 Jul 2024 15:03:42 +0200 Subject: [PATCH 14/35] simplify base class even further --- src/meshlode/calculators/calculator_base.py | 85 -------------- .../calculators/calculator_base_periodic.py | 88 -------------- src/meshlode/calculators/direct.py | 59 +++++++++- src/meshlode/calculators/ewald.py | 80 ++++++++++++- src/meshlode/calculators/mesh.py | 110 +++++++++++++----- src/meshlode/calculators/meshewald.py | 84 ++++++++++++- tests/calculators/test_workflow_direct.py | 2 +- 7 files changed, 296 insertions(+), 212 deletions(-) delete mode 100644 src/meshlode/calculators/calculator_base_periodic.py diff --git a/src/meshlode/calculators/calculator_base.py b/src/meshlode/calculators/calculator_base.py index 37537a0c..2216c7e0 100644 --- a/src/meshlode/calculators/calculator_base.py +++ b/src/meshlode/calculators/calculator_base.py @@ -1,4 +1,3 @@ -import warnings from typing import List, Optional, Tuple, Union import torch @@ -190,13 +189,6 @@ def _validate_compute_parameters( f"cell ({cell_single.device})" ) - if type(neighbor_indices_single) is not type(neighbor_indices_single): - raise ValueError( - f"Inconsistent of neighbor_indices " - f"({type(neighbor_indices_single)}) and neighbor_indices " - f"({neighbor_indices_single})" - ) - if neighbor_indices_single is not None: # TODO validate shape and dtype @@ -297,83 +289,6 @@ def _compute_impl( else: return potentials - def compute( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor] = None, - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute potential for all provided "systems" stacked inside list. - - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. - - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param cell: Ignored. - :param charges: Optional single or list of 2D tensor of shape (len(types), n), - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). - """ - if cell is not None: - warnings.warn( - "`cell` parameter was proviced but will be ignored", stacklevel=2 - ) - - return self._compute_impl( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) - - # This function is kept to keep MeshLODE compatible with the broader pytorch - # infrastructure, which require a "forward" function. We name this function - # "compute" instead, for compatibility with other COSMO software. - def forward( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor] = None, - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) - def _compute_single_system( self, positions: torch.Tensor, diff --git a/src/meshlode/calculators/calculator_base_periodic.py b/src/meshlode/calculators/calculator_base_periodic.py deleted file mode 100644 index 0d64a4f4..00000000 --- a/src/meshlode/calculators/calculator_base_periodic.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import List, Optional, Union - -import torch - -from .calculator_base import CalculatorBase - - -class CalculatorBasePeriodic(CalculatorBase): - """ - Base calculator for periodic implementations - """ - - name = "CalculatorBasePeriodic" - - def forward( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor] = None, - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) - - def compute( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor] = None, - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute potential for all provided "systems" stacked inside list. - - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. - - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param cell: single or 2D tensor of shape (3, 3), describing the bounding - box/unit cell of the system. Each row should be one of the bounding box - vector; and columns should contain the x, y, and z components of these - vectors (i.e. the cell should be given in row-major order). - :param charges: Optional single or list of 2D tensor of shape (len(types), n), - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). - """ - if cell is None: - raise ValueError("cell must be provided") - - return self._compute_impl( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) diff --git a/src/meshlode/calculators/direct.py b/src/meshlode/calculators/direct.py index 8d0d4c3a..f53d5978 100644 --- a/src/meshlode/calculators/direct.py +++ b/src/meshlode/calculators/direct.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import List, Optional, Union import torch @@ -22,6 +22,63 @@ class DirectPotential(CalculatorBase): name = "DirectPotential" + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + + return self._compute_impl( + types=types, + positions=positions, + cell=None, + charges=charges, + neighbor_indices=None, + neighbor_shifts=None, + ) + + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, + positions=positions, + charges=charges, + ) + def _compute_single_system( self, positions: torch.Tensor, diff --git a/src/meshlode/calculators/ewald.py b/src/meshlode/calculators/ewald.py index 6503d2da..74105405 100644 --- a/src/meshlode/calculators/ewald.py +++ b/src/meshlode/calculators/ewald.py @@ -6,10 +6,10 @@ from ase import Atoms from ase.neighborlist import neighbor_list -from .calculator_base_periodic import CalculatorBasePeriodic +from .calculator_base import CalculatorBase -class EwaldPotential(CalculatorBasePeriodic): +class EwaldPotential(CalculatorBase): """A specie-wise long-range potential computed using the Ewald sum, scaling as O(N^2) with respect to the number of particles N used as a reference to test faster implementations. @@ -84,6 +84,82 @@ def __init__( self.subtract_self = subtract_self self.subtract_interior = subtract_interior + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: single or 2D tensor of shape (3, 3), describing the bounding + box/unit cell of the system. Each row should be one of the bounding box + vector; and columns should contain the x, y, and z components of these + vectors (i.e. the cell should be given in row-major order). + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + def _compute_single_system( self, positions: torch.Tensor, diff --git a/src/meshlode/calculators/mesh.py b/src/meshlode/calculators/mesh.py index f799ae41..c7da4811 100644 --- a/src/meshlode/calculators/mesh.py +++ b/src/meshlode/calculators/mesh.py @@ -2,13 +2,12 @@ import torch -from meshlode.lib.fourier_convolution import FourierSpaceConvolution -from meshlode.lib.mesh_interpolator import MeshInterpolator +from ..lib.fourier_convolution import FourierSpaceConvolution +from ..lib.mesh_interpolator import MeshInterpolator +from .calculator_base import CalculatorBase -from .calculator_base_periodic import CalculatorBasePeriodic - -class MeshPotential(CalculatorBasePeriodic): +class MeshPotential(CalculatorBase): """A specie-wise long-range potential, computed using the particle-mesh Ewald (PME) method scaling as O(NlogN) with respect to the number of particles N. @@ -82,6 +81,82 @@ def __init__( # Initilize auxiliary objects self.fourier_space_convolution = FourierSpaceConvolution() + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: single or 2D tensor of shape (3, 3), describing the bounding + box/unit cell of the system. Each row should be one of the bounding box + vector; and columns should contain the x, y, and z components of these + vectors (i.e. the cell should be given in row-major order). + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + def _compute_single_system( self, positions: torch.Tensor, @@ -90,31 +165,6 @@ def _compute_single_system( neighbor_indices: Union[None, torch.Tensor], neighbor_shifts: Union[None, torch.Tensor], ) -> torch.Tensor: - """ - Compute the "electrostatic" potential at the position of all atoms in a - structure. - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. For standard LODE - that treats all (atomic) types separately, one example could be: If n_atoms - = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use - the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for - the charges. This would then separately compute the "Na" potential and "Cl" - potential. Subtracting these from each other, one could recover the more - standard electrostatic potential in which Na and Cl have charges of +1 and - -1, respectively. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ # Initializations k_cutoff = 2 * torch.pi / self.mesh_spacing diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/meshewald.py index 46563ed6..74479f49 100644 --- a/src/meshlode/calculators/meshewald.py +++ b/src/meshlode/calculators/meshewald.py @@ -6,13 +6,11 @@ from ase import Atoms from ase.neighborlist import neighbor_list -from meshlode.lib.mesh_interpolator import MeshInterpolator +from ..lib.mesh_interpolator import MeshInterpolator +from .calculator_base import CalculatorBase -# from .mesh import MeshPotential -from .calculator_base_periodic import CalculatorBasePeriodic - -class MeshEwaldPotential(CalculatorBasePeriodic): +class MeshEwaldPotential(CalculatorBase): """A specie-wise long-range potential computed using a mesh-based Ewald method, scaling as O(NlogN) with respect to the number of particles N used as a reference to test faster implementations. @@ -77,6 +75,82 @@ def __init__( self.subtract_self = subtract_self self.subtract_interior = subtract_interior + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: single or 2D tensor of shape (3, 3), describing the bounding + box/unit cell of the system. Each row should be one of the bounding box + vector; and columns should contain the x, y, and z components of these + vectors (i.e. the cell should be given in row-major order). + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: """ For a given unit cell, compute all reciprocal space vectors that are used to diff --git a/tests/calculators/test_workflow_direct.py b/tests/calculators/test_workflow_direct.py index 9139bad5..5ea9d15f 100644 --- a/tests/calculators/test_workflow_direct.py +++ b/tests/calculators/test_workflow_direct.py @@ -27,7 +27,7 @@ def cscl_system(): def cscl_system_with_charges(): """CsCl crystal with (cell) and charges.""" charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) - return cscl_system() + (None, charges,) + return cscl_system() + (charges,) # Initialize the calculators. For now, only the DirectPotential is implemented. From 7b43476194c135b9ab49754f039a58627d17da6b Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Thu, 4 Jul 2024 16:17:00 +0200 Subject: [PATCH 15/35] unify workflow tests --- src/meshlode/calculators/ewald.py | 2 + .../calculators/test_calculators_workflow.py | 199 +++++++++++ tests/calculators/test_workflow_direct.py | 224 ------------- tests/calculators/test_workflow_ewald.py | 296 ----------------- tests/calculators/test_workflow_mesh.py | 308 ----------------- tests/calculators/test_workflow_meshewald.py | 310 ------------------ 6 files changed, 201 insertions(+), 1138 deletions(-) create mode 100644 tests/calculators/test_calculators_workflow.py delete mode 100644 tests/calculators/test_workflow_direct.py delete mode 100644 tests/calculators/test_workflow_ewald.py delete mode 100644 tests/calculators/test_workflow_mesh.py delete mode 100644 tests/calculators/test_workflow_meshewald.py diff --git a/src/meshlode/calculators/ewald.py b/src/meshlode/calculators/ewald.py index 74105405..9744e1b0 100644 --- a/src/meshlode/calculators/ewald.py +++ b/src/meshlode/calculators/ewald.py @@ -74,6 +74,8 @@ def __init__( super().__init__(all_types=all_types, exponent=exponent) # Store provided parameters + if atomic_smearing is not None and atomic_smearing <= 0: + raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") self.atomic_smearing = atomic_smearing self.sr_cutoff = sr_cutoff self.lr_wavelength = lr_wavelength diff --git a/tests/calculators/test_calculators_workflow.py b/tests/calculators/test_calculators_workflow.py new file mode 100644 index 00000000..4bf3f7e9 --- /dev/null +++ b/tests/calculators/test_calculators_workflow.py @@ -0,0 +1,199 @@ +"""Basic tests if the calculator works and is torch scriptable. Actual tests are done +for the metatensor calculator.""" + +import math + +import pytest +import torch +from torch.testing import assert_close + +from meshlode import DirectPotential, EwaldPotential, MeshEwaldPotential, MeshPotential + + +MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) +CHARGES_CSCL = torch.tensor([1.0, -1.0]) + + +ATOMIC_SMEARING = 0.1 +LR_WAVELENGTH = ATOMIC_SMEARING / 4 +MESH_SPACING = ATOMIC_SMEARING / 4 +INTERPOLATION_ORDER = 2 +SUBTRACT_SELF = True + + +@pytest.mark.parametrize( + "CalculatorClass, params, periodic", + [ + (DirectPotential, {}, False), + ( + EwaldPotential, + { + "atomic_smearing": ATOMIC_SMEARING, + "lr_wavelength": LR_WAVELENGTH, + "subtract_self": SUBTRACT_SELF, + }, + True, + ), + ( + MeshEwaldPotential, + { + "atomic_smearing": ATOMIC_SMEARING, + "mesh_spacing": MESH_SPACING, + "interpolation_order": INTERPOLATION_ORDER, + "subtract_self": SUBTRACT_SELF, + }, + True, + ), + ( + MeshPotential, + { + "atomic_smearing": ATOMIC_SMEARING, + "mesh_spacing": MESH_SPACING, + "interpolation_order": INTERPOLATION_ORDER, + "subtract_self": SUBTRACT_SELF, + }, + True, + ), + ], +) +class TestWorkflow: + + def cscl_system(self, periodic): + """CsCl crystal. Same as in the madelung test""" + types = torch.tensor([55, 17]) + positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + cell = torch.eye(3) + + if periodic: + return types, positions, cell + else: + return types, positions + + def cscl_system_with_charges(self, periodic): + """CsCl crystal with charges.""" + charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) + return self.cscl_system(periodic) + (charges,) + + def calculator(self, CalculatorClass, periodic, params): + if periodic: + return CalculatorClass(**params) + else: + return CalculatorClass() + + def test_forward(self, CalculatorClass, periodic, params): + calculator = self.calculator(CalculatorClass, periodic, params) + descriptor_compute = calculator.compute(*self.cscl_system(periodic)) + descriptor_forward = calculator.forward(*self.cscl_system(periodic)) + + assert type(descriptor_compute) is torch.Tensor + assert type(descriptor_forward) is torch.Tensor + assert torch.equal(descriptor_forward, descriptor_compute) + + def test_atomic_smearing_error(self, CalculatorClass, params, periodic): + if periodic: + with pytest.raises(ValueError, match="has to be positive"): + CalculatorClass(atomic_smearing=-1.0) + + def test_interpolation_order_error(self, CalculatorClass, params, periodic): + if type(CalculatorClass) in [MeshEwaldPotential, MeshPotential]: + match = "Only `interpolation_order` from 1 to 5" + with pytest.raises(ValueError, match=match): + CalculatorClass(atomic_smearing=1, interpolation_order=10) + + def test_all_types(self, CalculatorClass, params, periodic): + if periodic: + descriptor = CalculatorClass(atomic_smearing=0.1, all_types=[8, 55, 17]) + values = descriptor.compute(*self.cscl_system(periodic)) + assert values.shape == (2, 3) + assert torch.equal(values[:, 0], torch.zeros(2)) + + def test_all_types_error(self, CalculatorClass, params, periodic): + if periodic: + descriptor = CalculatorClass(atomic_smearing=0.1, all_types=[17]) + with pytest.raises(ValueError, match="Global list of types"): + descriptor.compute(*self.cscl_system(periodic)) + + def test_single_frame(self, CalculatorClass, periodic, params): + calculator = self.calculator(CalculatorClass, periodic, params) + values = calculator.compute(*self.cscl_system(periodic)) + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + def test_single_frame_with_charges(self, CalculatorClass, periodic, params): + calculator = self.calculator(CalculatorClass, periodic, params) + values = calculator.compute(*self.cscl_system_with_charges(periodic)) + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + def test_multi_frame(self, CalculatorClass, periodic, params): + calculator = self.calculator(CalculatorClass, periodic, params) + if periodic: + types, positions, cell = self.cscl_system(periodic) + l_values = calculator.compute( + types=[types, types], + positions=[positions, positions], + cell=[cell, cell], + ) + else: + types, positions = self.cscl_system(periodic) + l_values = calculator.compute( + types=[types, types], positions=[positions, positions] + ) + + for values in l_values: + assert_close( + MADELUNG_CSCL, + CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], + atol=1e4, + rtol=1e-5, + ) + + def test_dtype_device(self, CalculatorClass, periodic, params): + """Test that the output dtype and device are the same as the input.""" + device = "cpu" + dtype = torch.float64 + + types = torch.tensor([1], dtype=dtype, device=device) + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) + + calculator = self.calculator(CalculatorClass, periodic, params) + if periodic: + cell = torch.eye(3, dtype=dtype, device=device) + potential = calculator.compute(types=types, positions=positions, cell=cell) + else: + potential = calculator.compute(types=types, positions=positions) + + assert potential.dtype == dtype + assert potential.device.type == device + + # Make sure that the calculators are computing the features without raising errors, + # and returns the correct output format (TensorMap) + def check_operation(self, CalculatorClass, periodic, params): + calculator = self.calculator(CalculatorClass, periodic, params) + + if periodic: + types, positions, cell = self.cscl_system(periodic) + descriptor = calculator.compute(types=types, positions=positions, cell=cell) + else: + types, positions = self.cscl_system(periodic) + descriptor = calculator.compute(types=types, positions=positions) + + assert type(descriptor) is torch.Tensor + + # Run the above test as a normal python script + def test_operation_as_python(self, CalculatorClass, periodic, params): + self.check_operation(CalculatorClass, periodic, params) + + # Similar to the above, but also testing that the code can be compiled as a torch + # script + # def test_operation_as_torch_script(self, CalculatorClass, periodic, params): + # scripted = torch.jit.script(CalculatorClass, periodic, params) + # self.check_operation(scripted) diff --git a/tests/calculators/test_workflow_direct.py b/tests/calculators/test_workflow_direct.py deleted file mode 100644 index 5ea9d15f..00000000 --- a/tests/calculators/test_workflow_direct.py +++ /dev/null @@ -1,224 +0,0 @@ -"""Basic tests if the calculator works and is torch scriptable. Actual tests are done -for the metatensor calculator.""" - -import math - -import pytest -import torch -from torch.testing import assert_close - -from meshlode import DirectPotential -from meshlode.calculators.calculator_base import _1d_tolist, _is_subset - - -# MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) # periodic case -MADELUNG_CSCL = torch.tensor(2 * math.sqrt(3)) -CHARGES_CSCL = torch.tensor([1.0, -1.0]) - - -def cscl_system(): - """CsCl crystal. Same as in the madelung test""" - types = torch.tensor([55, 17]) - positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - - return types, positions - - -def cscl_system_with_charges(): - """CsCl crystal with (cell) and charges.""" - charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) - return cscl_system() + (charges,) - - -# Initialize the calculators. For now, only the DirectPotential is implemented. -def descriptor() -> DirectPotential: - return DirectPotential() - - -def test_forward(): - mp = descriptor() - descriptor_compute = mp.compute(*cscl_system()) - descriptor_forward = mp.forward(*cscl_system()) - - assert torch.equal(descriptor_forward, descriptor_compute) - - -def test_all_types(): - descriptor = DirectPotential(all_types=[8, 55, 17]) - values = descriptor.compute(*cscl_system()) - - assert values.shape == (2, 3) - assert torch.equal(values[:, 0], torch.zeros(2)) - - -def test_all_types_error(): - descriptor = DirectPotential(all_types=[17]) - with pytest.raises(ValueError, match="Global list of types"): - descriptor.compute(*cscl_system()) - - -# Make sure that the calculators are computing the features without raising errors, -# and returns the correct output format (TensorMap) -def check_operation(calculator): - descriptor = calculator.compute(*cscl_system()) - assert type(descriptor) is torch.Tensor - - -# Run the above test as a normal python script -def test_operation_as_python(): - check_operation(descriptor()) - - -# Similar to the above, but also testing that the code can be compiled as a torch script -def test_operation_as_torch_script(): - scripted = torch.jit.script(descriptor()) - check_operation(scripted) - - -def test_single_frame(): - values = descriptor().compute(*cscl_system()) - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -# Test with explicit charges -def test_single_frame_with_charges(): - print(cscl_system_with_charges()) - values = descriptor().compute(*cscl_system_with_charges()) - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -def test_multi_frame(): - types, positions = cscl_system() - l_values = descriptor().compute( - types=[types, types], positions=[positions, positions] - ) - for values in l_values: - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -def test_types_error(): - types = torch.tensor([[1, 2], [3, 4]]) # This is a 2D tensor, should be 1D - positions = torch.zeros((2, 3)) - - match = ( - "each `types` must be a 1 dimensional tensor, got at least one tensor with " - "2 dimensions" - ) - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions) - - -def test_positions_error(): - types = torch.tensor([1, 2]) - positions = torch.zeros( - (1, 3) - ) # This should have the same first dimension as types - - match = ( - "each `positions` must be a \\(n_types x 3\\) tensor, got at least " - "one tensor with shape \\[1, 3\\]" - ) - - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions) - - -def test_charges_error_dimension_mismatch(): - types = torch.tensor([1, 2]) - positions = torch.zeros((2, 3)) - charges = torch.zeros((1, 2)) # This should have the same first dimension as types - - match = ( - "The first dimension of `charges` must be the same as the length " - "of `types`, got 1 and 2." - ) - - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, charges=charges) - - -def test_charges_error_length_mismatch(): - types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] - positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] - charges = [torch.zeros(2, 1)] # This should have the same length as types - match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." - - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, charges=charges) - - -def test_dtype_device(): - """Test that the output dtype and device are the same as the input.""" - device = "cpu" - dtype = torch.float64 - - types = torch.tensor([1], dtype=dtype, device=device) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) - - DP = DirectPotential() - potential = DP.compute(types=types, positions=positions) - - assert potential.dtype == dtype - assert potential.device.type == device - - -def test_inconsistent_device_charges(): - """Test if the chages and positions have inconsistent device and error is raised.""" - types = torch.tensor([1], device="cpu") - positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") - charges = torch.tensor([0.0], device="meta") # different device - - DP = DirectPotential() - - match = "`charges` must be on the same device as `positions`, got meta and cpu." - with pytest.raises(ValueError, match=match): - DP.compute(types=types, positions=positions, charges=charges) - - -def test_inconsistent_dtype_charges(): - """Test if the charges and positions have inconsistent dtype and error is raised.""" - types = torch.tensor([1], dtype=torch.float32) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32) - charges = torch.tensor([0.0], dtype=torch.float64) # Different dtype - - DP = DirectPotential() - - match = ( - "`charges` must be have the same dtype as `positions`, got torch.float64 and " - "torch.float32" - ) - with pytest.raises(ValueError, match=match): - DP.compute(types=types, positions=positions, charges=charges) - - -def test_1d_tolist(): - in_list = [1, 2, 7, 3, 4, 42] - in_tensor = torch.tensor(in_list) - assert _1d_tolist(in_tensor) == in_list - - -def test_is_subset_true(): - subset_candidate = [1, 2] - superset = [1, 2, 3, 4, 5] - assert _is_subset(subset_candidate, superset) - - -def test_is_subset_false(): - subset_candidate = [1, 2, 8] - superset = [1, 2, 3, 4, 5] - assert not _is_subset(subset_candidate, superset) diff --git a/tests/calculators/test_workflow_ewald.py b/tests/calculators/test_workflow_ewald.py deleted file mode 100644 index 788eaa7c..00000000 --- a/tests/calculators/test_workflow_ewald.py +++ /dev/null @@ -1,296 +0,0 @@ -"""Basic tests if the calculator works and is torch scriptable. Actual tests are done -for the metatensor calculator.""" - -import math - -import pytest -import torch -from torch.testing import assert_close - -from meshlode import EwaldPotential -from meshlode.calculators.calculator_base import _1d_tolist, _is_subset - - -MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) -CHARGES_CSCL = torch.tensor([1.0, -1.0]) - - -def cscl_system(): - """CsCl crystal. Same as in the madelung test""" - types = torch.tensor([55, 17]) - positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - cell = torch.eye(3) - - return types, positions, cell - - -def cscl_system_with_charges(): - """CsCl crystal with charges.""" - charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) - return cscl_system() + (charges,) - - -# Initialize the calculators. For now, only the EwaldPotential is implemented. -def descriptor() -> EwaldPotential: - atomic_smearing = 0.1 - return EwaldPotential( - atomic_smearing=atomic_smearing, - lr_wavelength=atomic_smearing / 4, - subtract_self=True, - ) - - -def test_forward(): - mp = descriptor() - descriptor_compute = mp.compute(*cscl_system()) - descriptor_forward = mp.forward(*cscl_system()) - - assert torch.equal(descriptor_forward, descriptor_compute) - - -def test_all_types(): - descriptor = EwaldPotential(atomic_smearing=0.1, all_types=[8, 55, 17]) - values = descriptor.compute(*cscl_system()) - - assert values.shape == (2, 3) - assert torch.equal(values[:, 0], torch.zeros(2)) - - -def test_all_types_error(): - descriptor = EwaldPotential(atomic_smearing=0.1, all_types=[17]) - with pytest.raises(ValueError, match="Global list of types"): - descriptor.compute(*cscl_system()) - - -# Make sure that the calculators are computing the features without raising errors, -# and returns the correct output format (TensorMap) -def check_operation(calculator): - descriptor = calculator.compute(*cscl_system()) - assert type(descriptor) is torch.Tensor - - -# Run the above test as a normal python script -def test_operation_as_python(): - check_operation(descriptor()) - - -""" -# Similar to the above, but also testing that the code can be compiled as a torch script -def test_operation_as_torch_script(): - scripted = torch.jit.script(descriptor()) - check_operation(scripted) -""" - - -def test_single_frame(): - values = descriptor().compute(*cscl_system()) - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -# Test with explicit charges -def test_single_frame_with_charges(): - values = descriptor().compute(*cscl_system_with_charges()) - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -def test_multi_frame(): - types, positions, cell = cscl_system() - l_values = descriptor().compute( - types=[types, types], positions=[positions, positions], cell=[cell, cell] - ) - for values in l_values: - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -def test_types_error(): - types = torch.tensor([[1, 2], [3, 4]]) # This is a 2D tensor, should be 1D - positions = torch.zeros((2, 3)) - cell = torch.eye(3) - - match = ( - "each `types` must be a 1 dimensional tensor, got at least one tensor with " - "2 dimensions" - ) - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_positions_error(): - types = torch.tensor([1, 2]) - positions = torch.zeros( - (1, 3) - ) # This should have the same first dimension as types - cell = torch.eye(3) - - match = ( - "each `positions` must be a \\(n_types x 3\\) tensor, got at least " - "one tensor with shape \\[1, 3\\]" - ) - - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_charges_error_dimension_mismatch(): - types = torch.tensor([1, 2]) - positions = torch.zeros((2, 3)) - cell = torch.eye(3) - charges = torch.zeros((1, 2)) # This should have the same first dimension as types - - match = ( - "The first dimension of `charges` must be the same as the length " - "of `types`, got 1 and 2." - ) - - with pytest.raises(ValueError, match=match): - descriptor().compute( - types=types, positions=positions, cell=cell, charges=charges - ) - - -def test_charges_error_length_mismatch(): - types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] - positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] - cell = [torch.eye(3), torch.eye(3)] - charges = [torch.zeros(2, 1)] # This should have the same length as types - match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." - - with pytest.raises(ValueError, match=match): - descriptor().compute( - types=types, positions=positions, cell=cell, charges=charges - ) - - -def test_cell_error(): - types = torch.tensor([1, 2, 3]) - positions = torch.zeros((3, 3)) - cell = torch.eye(2) # This is a 2x2 tensor, should be 3x3 - - match = ( - "each `cell` must be a \\(3 x 3\\) tensor, got at least one tensor " - "with shape \\[2, 2\\]" - ) - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_positions_cell_dtype_error(): - types = torch.tensor([1, 2, 3]) - positions = torch.zeros((3, 3), dtype=torch.float32) - cell = torch.eye(3, dtype=torch.float64) - - match = ( - "`cell` must be have the same dtype as `positions`, got torch.float64 " - "and torch.float32" - ) - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_dtype_device(): - """Test that the output dtype and device are the same as the input.""" - device = "cpu" - dtype = torch.float64 - - types = torch.tensor([1], dtype=dtype, device=device) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) - cell = torch.eye(3, dtype=dtype, device=device) - - EP = EwaldPotential(atomic_smearing=0.2) - potential = EP.compute(types=types, positions=positions, cell=cell) - - assert potential.dtype == dtype - assert potential.device.type == device - - -def test_inconsistent_dtype(): - """Test if the cell and positions have inconsistent dtype and error is raised.""" - types = torch.tensor([1], dtype=torch.float32) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64) # Different dtype - cell = torch.eye(3, dtype=torch.float32) - - EP = EwaldPotential(atomic_smearing=0.2) - - match = ( - "`cell` must be have the same dtype as `positions`, got torch.float32 and " - "torch.float64" - ) - with pytest.raises(ValueError, match=match): - EP.compute(types=types, positions=positions, cell=cell) - - -def test_inconsistent_device(): - """Test if the cell and positions have inconsistent device and error is raised.""" - types = torch.tensor([1], device="cpu") - positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") - cell = torch.eye(3, device="meta") # different device - - EP = EwaldPotential(atomic_smearing=0.2) - - match = r"Inconsistent devices of types \(cpu\) and cell \(meta\)" - with pytest.raises(ValueError, match=match): - EP.compute(types=types, positions=positions, cell=cell) - - -def test_inconsistent_device_charges(): - """Test if the cell and positions have inconsistent device and error is raised.""" - types = torch.tensor([1], device="cpu") - positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") - cell = torch.eye(3, device="cpu") - charges = torch.tensor([0.0], device="meta") # different device - - EP = EwaldPotential(atomic_smearing=0.2) - - match = "`charges` must be on the same device as `positions`, got meta and cpu." - with pytest.raises(ValueError, match=match): - EP.compute(types=types, positions=positions, cell=cell, charges=charges) - - -def test_inconsistent_dtype_charges(): - """Test if the cell and positions have inconsistent dtype and error is raised.""" - types = torch.tensor([1], dtype=torch.float32) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32) - cell = torch.eye(3, dtype=torch.float32) - charges = torch.tensor([0.0], dtype=torch.float64) # Different dtype - - EP = EwaldPotential(atomic_smearing=0.2) - - match = ( - "`charges` must be have the same dtype as `positions`, got torch.float64 and " - "torch.float32" - ) - with pytest.raises(ValueError, match=match): - EP.compute(types=types, positions=positions, cell=cell, charges=charges) - - -def test_1d_tolist(): - in_list = [1, 2, 7, 3, 4, 42] - in_tensor = torch.tensor(in_list) - assert _1d_tolist(in_tensor) == in_list - - -def test_is_subset_true(): - subset_candidate = [1, 2] - superset = [1, 2, 3, 4, 5] - assert _is_subset(subset_candidate, superset) - - -def test_is_subset_false(): - subset_candidate = [1, 2, 8] - superset = [1, 2, 3, 4, 5] - assert not _is_subset(subset_candidate, superset) diff --git a/tests/calculators/test_workflow_mesh.py b/tests/calculators/test_workflow_mesh.py deleted file mode 100644 index f0827bf1..00000000 --- a/tests/calculators/test_workflow_mesh.py +++ /dev/null @@ -1,308 +0,0 @@ -"""Basic tests if the calculator works and is torch scriptable. Actual tests are done -for the metatensor calculator.""" - -import math - -import pytest -import torch -from torch.testing import assert_close - -from meshlode import MeshPotential -from meshlode.calculators.calculator_base import _1d_tolist, _is_subset - - -MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) -CHARGES_CSCL = torch.tensor([1.0, -1.0]) - - -def cscl_system(): - """CsCl crystal. Same as in the madelung test""" - types = torch.tensor([55, 17]) - positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - cell = torch.eye(3) - - return types, positions, cell - - -def cscl_system_with_charges(): - """CsCl crystal with charges.""" - charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) - return cscl_system() + (charges,) - - -# Initialize the calculators. For now, only the MeshPotential is implemented. -def descriptor() -> MeshPotential: - atomic_smearing = 0.1 - return MeshPotential( - atomic_smearing=atomic_smearing, - mesh_spacing=atomic_smearing / 4, - interpolation_order=2, - subtract_self=True, - ) - - -def test_forward(): - mp = descriptor() - descriptor_compute = mp.compute(*cscl_system()) - descriptor_forward = mp.forward(*cscl_system()) - - assert torch.equal(descriptor_forward, descriptor_compute) - - -def test_atomic_smearing_error(): - with pytest.raises(ValueError, match="has to be positive"): - MeshPotential(atomic_smearing=-1.0) - - -def test_interpolation_order_error(): - with pytest.raises(ValueError, match="Only `interpolation_order` from 1 to 5"): - MeshPotential(atomic_smearing=1, interpolation_order=10) - - -def test_all_types(): - descriptor = MeshPotential(atomic_smearing=0.1, all_types=[8, 55, 17]) - values = descriptor.compute(*cscl_system()) - assert values.shape == (2, 3) - assert torch.equal(values[:, 0], torch.zeros(2)) - - -def test_all_types_error(): - descriptor = MeshPotential(atomic_smearing=0.1, all_types=[17]) - with pytest.raises(ValueError, match="Global list of types"): - descriptor.compute(*cscl_system()) - - -# Make sure that the calculators are computing the features without raising errors, -# and returns the correct output format (TensorMap) -def check_operation(calculator): - types, pos, cell = cscl_system() - print(cell) - descriptor = calculator.compute(types=types, positions=pos, cell=cell) - assert type(descriptor) is torch.Tensor - - -# Run the above test as a normal python script -def test_operation_as_python(): - check_operation(descriptor()) - - -# Similar to the above, but also testing that the code can be compiled as a torch script - - -# def test_operation_as_torch_script(): -# scripted = torch.jit.script(descriptor()) -# check_operation(scripted) - - -def test_single_frame(): - values = descriptor().compute(*cscl_system()) - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -# Test with explicit charges -def test_single_frame_with_charges(): - values = descriptor().compute(*cscl_system_with_charges()) - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -def test_multi_frame(): - types, positions, cell = cscl_system() - l_values = descriptor().compute( - types=[types, types], positions=[positions, positions], cell=[cell, cell] - ) - for values in l_values: - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -def test_types_error(): - types = torch.tensor([[1, 2], [3, 4]]) # This is a 2D tensor, should be 1D - positions = torch.zeros((2, 3)) - cell = torch.eye(3) - - match = ( - "each `types` must be a 1 dimensional tensor, got at least one tensor with " - "2 dimensions" - ) - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_positions_error(): - types = torch.tensor([1, 2]) - positions = torch.zeros( - (1, 3) - ) # This should have the same first dimension as types - cell = torch.eye(3) - - match = ( - "each `positions` must be a \\(n_types x 3\\) tensor, got at least " - "one tensor with shape \\[1, 3\\]" - ) - - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_charges_error_dimension_mismatch(): - types = torch.tensor([1, 2]) - positions = torch.zeros((2, 3)) - cell = torch.eye(3) - charges = torch.zeros((1, 2)) # This should have the same first dimension as types - - match = ( - "The first dimension of `charges` must be the same as the length " - "of `types`, got 1 and 2." - ) - - with pytest.raises(ValueError, match=match): - descriptor().compute( - types=types, positions=positions, cell=cell, charges=charges - ) - - -def test_charges_error_length_mismatch(): - types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] - positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] - cell = [torch.eye(3), torch.eye(3)] - charges = [torch.zeros(2, 1)] # This should have the same length as types - match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." - - with pytest.raises(ValueError, match=match): - descriptor().compute( - types=types, positions=positions, cell=cell, charges=charges - ) - - -def test_cell_error(): - types = torch.tensor([1, 2, 3]) - positions = torch.zeros((3, 3)) - cell = torch.eye(2) # This is a 2x2 tensor, should be 3x3 - - match = ( - "each `cell` must be a \\(3 x 3\\) tensor, got at least one tensor " - "with shape \\[2, 2\\]" - ) - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_positions_cell_dtype_error(): - types = torch.tensor([1, 2, 3]) - positions = torch.zeros((3, 3), dtype=torch.float32) - cell = torch.eye(3, dtype=torch.float64) - - match = ( - "`cell` must be have the same dtype as `positions`, got torch.float64 " - "and torch.float32" - ) - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_dtype_device(): - """Test that the output dtype and device are the same as the input.""" - device = "cpu" - dtype = torch.float64 - - types = torch.tensor([1], dtype=dtype, device=device) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) - cell = torch.eye(3, dtype=dtype, device=device) - - MP = MeshPotential(atomic_smearing=0.2) - potential = MP.compute(types=types, positions=positions, cell=cell) - - assert potential.dtype == dtype - assert potential.device.type == device - - -def test_inconsistent_dtype(): - """Test if the cell and positions have inconsistent dtype and error is raised.""" - types = torch.tensor([1], dtype=torch.float32) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64) # Different dtype - cell = torch.eye(3, dtype=torch.float32) - - MP = MeshPotential(atomic_smearing=0.2) - - match = ( - "`cell` must be have the same dtype as `positions`, got torch.float32 and " - "torch.float64" - ) - with pytest.raises(ValueError, match=match): - MP.compute(types=types, positions=positions, cell=cell) - - -def test_inconsistent_device(): - """Test if the cell and positions have inconsistent device and error is raised.""" - types = torch.tensor([1], device="cpu") - positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") - cell = torch.eye(3, device="meta") # different device - - MP = MeshPotential(atomic_smearing=0.2) - - match = r"Inconsistent devices of types \(cpu\) and cell \(meta\)" - with pytest.raises(ValueError, match=match): - MP.compute(types=types, positions=positions, cell=cell) - - -def test_inconsistent_device_charges(): - """Test if the cell and positions have inconsistent device and error is raised.""" - types = torch.tensor([1], device="cpu") - positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") - cell = torch.eye(3, device="cpu") - charges = torch.tensor([0.0], device="meta") # different device - - MP = MeshPotential(atomic_smearing=0.2) - - match = "`charges` must be on the same device as `positions`, got meta and cpu." - with pytest.raises(ValueError, match=match): - MP.compute(types=types, positions=positions, cell=cell, charges=charges) - - -def test_inconsistent_dtype_charges(): - """Test if the cell and positions have inconsistent dtype and error is raised.""" - types = torch.tensor([1], dtype=torch.float32) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32) - cell = torch.eye(3, dtype=torch.float32) - charges = torch.tensor([0.0], dtype=torch.float64) # Different dtype - - MP = MeshPotential(atomic_smearing=0.2) - - match = ( - "`charges` must be have the same dtype as `positions`, got torch.float64 and " - "torch.float32" - ) - with pytest.raises(ValueError, match=match): - MP.compute(types=types, positions=positions, cell=cell, charges=charges) - - -def test_1d_tolist(): - in_list = [1, 2, 7, 3, 4, 42] - in_tensor = torch.tensor(in_list) - assert _1d_tolist(in_tensor) == in_list - - -def test_is_subset_true(): - subset_candidate = [1, 2] - superset = [1, 2, 3, 4, 5] - assert _is_subset(subset_candidate, superset) - - -def test_is_subset_false(): - subset_candidate = [1, 2, 8] - superset = [1, 2, 3, 4, 5] - assert not _is_subset(subset_candidate, superset) diff --git a/tests/calculators/test_workflow_meshewald.py b/tests/calculators/test_workflow_meshewald.py deleted file mode 100644 index 05637cea..00000000 --- a/tests/calculators/test_workflow_meshewald.py +++ /dev/null @@ -1,310 +0,0 @@ -"""Basic tests if the calculator works and is torch scriptable. Actual tests are done -for the metatensor calculator.""" - -import math - -import pytest -import torch -from torch.testing import assert_close - -from meshlode import MeshEwaldPotential, MeshPotential -from meshlode.calculators.calculator_base import _1d_tolist, _is_subset - - -MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) -CHARGES_CSCL = torch.tensor([1.0, -1.0]) - - -def cscl_system(): - """CsCl crystal. Same as in the madelung test""" - types = torch.tensor([55, 17]) - positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - cell = torch.eye(3) - - return types, positions, cell - - -def cscl_system_with_charges(): - """CsCl crystal with charges.""" - charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) - return cscl_system() + (charges,) - - -# Initialize the calculators. For now, only the MeshPotential is implemented. -def descriptor() -> MeshEwaldPotential: - atomic_smearing = 0.1 - return MeshEwaldPotential( - atomic_smearing=atomic_smearing, - mesh_spacing=atomic_smearing / 4, - interpolation_order=2, - subtract_self=True, - ) - - -def test_forward(): - mp = descriptor() - descriptor_compute = mp.compute(*cscl_system()) - descriptor_forward = mp.forward(*cscl_system()) - - assert torch.equal(descriptor_forward, descriptor_compute) - - -def test_atomic_smearing_error(): - with pytest.raises(ValueError, match="has to be positive"): - MeshEwaldPotential(atomic_smearing=-1.0) - - -def test_interpolation_order_error(): - with pytest.raises(ValueError, match="Only `interpolation_order` from 1 to 5"): - MeshEwaldPotential(atomic_smearing=1, interpolation_order=10) - - -def test_all_types(): - descriptor = MeshPotential(atomic_smearing=0.1, all_types=[8, 55, 17]) - values = descriptor.compute(*cscl_system()) - assert values.shape == (2, 3) - assert torch.equal(values[:, 0], torch.zeros(2)) - - -def test_all_types_error(): - descriptor = MeshPotential(atomic_smearing=0.1, all_types=[17]) - with pytest.raises(ValueError, match="Global list of types"): - descriptor.compute(*cscl_system()) - - -# Make sure that the calculators are computing the features without raising errors, -# and returns the correct output format (TensorMap) -def check_operation(calculator): - types, pos, cell = cscl_system() - print(cell) - descriptor = calculator.compute(types=types, positions=pos, cell=cell) - assert type(descriptor) is torch.Tensor - - -# Run the above test as a normal python script -def test_operation_as_python(): - check_operation(descriptor()) - - -""" -# Similar to the above, but also testing that the code can be compiled as a torch script -# Disabled for now since (1) the ASE neighbor list and (2) the use of the potential -# class are clashing with the torch script capabilities. -def test_operation_as_torch_script(): - scripted = torch.jit.script(descriptor()) - check_operation(scripted) -""" - - -def test_single_frame(): - values = descriptor().compute(*cscl_system()) - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -# Test with explicit charges -def test_single_frame_with_charges(): - values = descriptor().compute(*cscl_system_with_charges()) - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -def test_multi_frame(): - types, positions, cell = cscl_system() - l_values = descriptor().compute( - types=[types, types], positions=[positions, positions], cell=[cell, cell] - ) - for values in l_values: - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - -def test_types_error(): - types = torch.tensor([[1, 2], [3, 4]]) # This is a 2D tensor, should be 1D - positions = torch.zeros((2, 3)) - cell = torch.eye(3) - - match = ( - "each `types` must be a 1 dimensional tensor, got at least one tensor with " - "2 dimensions" - ) - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_positions_error(): - types = torch.tensor([1, 2]) - positions = torch.zeros( - (1, 3) - ) # This should have the same first dimension as types - cell = torch.eye(3) - - match = ( - "each `positions` must be a \\(n_types x 3\\) tensor, got at least " - "one tensor with shape \\[1, 3\\]" - ) - - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_charges_error_dimension_mismatch(): - types = torch.tensor([1, 2]) - positions = torch.zeros((2, 3)) - cell = torch.eye(3) - charges = torch.zeros((1, 2)) # This should have the same first dimension as types - - match = ( - "The first dimension of `charges` must be the same as the length " - "of `types`, got 1 and 2." - ) - - with pytest.raises(ValueError, match=match): - descriptor().compute( - types=types, positions=positions, cell=cell, charges=charges - ) - - -def test_charges_error_length_mismatch(): - types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])] - positions = [torch.zeros((2, 3)), torch.zeros((3, 3))] - cell = [torch.eye(3), torch.eye(3)] - charges = [torch.zeros(2, 1)] # This should have the same length as types - match = "The number of `types` and `charges` tensors must be the same, got 2 and 1." - - with pytest.raises(ValueError, match=match): - descriptor().compute( - types=types, positions=positions, cell=cell, charges=charges - ) - - -def test_cell_error(): - types = torch.tensor([1, 2, 3]) - positions = torch.zeros((3, 3)) - cell = torch.eye(2) # This is a 2x2 tensor, should be 3x3 - - match = ( - "each `cell` must be a \\(3 x 3\\) tensor, got at least one tensor " - "with shape \\[2, 2\\]" - ) - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_positions_cell_dtype_error(): - types = torch.tensor([1, 2, 3]) - positions = torch.zeros((3, 3), dtype=torch.float32) - cell = torch.eye(3, dtype=torch.float64) - - match = ( - "`cell` must be have the same dtype as `positions`, got torch.float64 " - "and torch.float32" - ) - with pytest.raises(ValueError, match=match): - descriptor().compute(types=types, positions=positions, cell=cell) - - -def test_dtype_device(): - """Test that the output dtype and device are the same as the input.""" - device = "cpu" - dtype = torch.float64 - - types = torch.tensor([1], dtype=dtype, device=device) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) - cell = torch.eye(3, dtype=dtype, device=device) - - MP = MeshPotential(atomic_smearing=0.2) - potential = MP.compute(types=types, positions=positions, cell=cell) - - assert potential.dtype == dtype - assert potential.device.type == device - - -def test_inconsistent_dtype(): - """Test if the cell and positions have inconsistent dtype and error is raised.""" - types = torch.tensor([1], dtype=torch.float32) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64) # Different dtype - cell = torch.eye(3, dtype=torch.float32) - - MP = MeshPotential(atomic_smearing=0.2) - - match = ( - "`cell` must be have the same dtype as `positions`, got torch.float32 and " - "torch.float64" - ) - with pytest.raises(ValueError, match=match): - MP.compute(types=types, positions=positions, cell=cell) - - -def test_inconsistent_device(): - """Test if the cell and positions have inconsistent device and error is raised.""" - types = torch.tensor([1], device="cpu") - positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") - cell = torch.eye(3, device="meta") # different device - - MP = MeshPotential(atomic_smearing=0.2) - - match = r"Inconsistent devices of types \(cpu\) and cell \(meta\)" - with pytest.raises(ValueError, match=match): - MP.compute(types=types, positions=positions, cell=cell) - - -def test_inconsistent_device_charges(): - """Test if the cell and positions have inconsistent device and error is raised.""" - types = torch.tensor([1], device="cpu") - positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu") - cell = torch.eye(3, device="cpu") - charges = torch.tensor([0.0], device="meta") # different device - - MP = MeshPotential(atomic_smearing=0.2) - - match = "`charges` must be on the same device as `positions`, got meta and cpu." - with pytest.raises(ValueError, match=match): - MP.compute(types=types, positions=positions, cell=cell, charges=charges) - - -def test_inconsistent_dtype_charges(): - """Test if the cell and positions have inconsistent dtype and error is raised.""" - types = torch.tensor([1], dtype=torch.float32) - positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32) - cell = torch.eye(3, dtype=torch.float32) - charges = torch.tensor([0.0], dtype=torch.float64) # Different dtype - - MP = MeshPotential(atomic_smearing=0.2) - - match = ( - "`charges` must be have the same dtype as `positions`, got torch.float64 and " - "torch.float32" - ) - with pytest.raises(ValueError, match=match): - MP.compute(types=types, positions=positions, cell=cell, charges=charges) - - -def test_1d_tolist(): - in_list = [1, 2, 7, 3, 4, 42] - in_tensor = torch.tensor(in_list) - assert _1d_tolist(in_tensor) == in_list - - -def test_is_subset_true(): - subset_candidate = [1, 2] - superset = [1, 2, 3, 4, 5] - assert _is_subset(subset_candidate, superset) - - -def test_is_subset_false(): - subset_candidate = [1, 2, 8] - superset = [1, 2, 3, 4, 5] - assert not _is_subset(subset_candidate, superset) From ef714a0308c2a2e6e79f19b041ca690336ba348e Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Thu, 4 Jul 2024 22:24:26 +0200 Subject: [PATCH 16/35] cleanup workflow tests --- src/meshlode/calculators/calculator_base.py | 18 +- tests/calculators/test_calculator_base.py | 224 ++++++++++++++++++++ 2 files changed, 240 insertions(+), 2 deletions(-) create mode 100644 tests/calculators/test_calculator_base.py diff --git a/src/meshlode/calculators/calculator_base.py b/src/meshlode/calculators/calculator_base.py index 2216c7e0..851a5bd5 100644 --- a/src/meshlode/calculators/calculator_base.py +++ b/src/meshlode/calculators/calculator_base.py @@ -190,7 +190,14 @@ def _validate_compute_parameters( ) if neighbor_indices_single is not None: - # TODO validate shape and dtype + # TODO test dtype + + if neighbor_indices_single.shape != (2, len(types_single)): + raise ValueError( + "Expected shape of neighbor_indices is " + f"{2, len(types_single)}, but got " + f"{list(neighbor_indices_single.shape)}" + ) if types_single.device != neighbor_indices_single.device: raise ValueError( @@ -199,7 +206,14 @@ def _validate_compute_parameters( ) if neighbor_shifts_single is not None: - # TODO validate shape and dtype + # TODO test dtype + + if neighbor_shifts_single.shape != (3, len(types_single)): + raise ValueError( + "Expected shape of neighbor_shifts is " + f"{3, len(types_single)}, but got " + f"{list(neighbor_shifts_single.shape)}" + ) if types_single.device != neighbor_shifts_single.device: raise ValueError( diff --git a/tests/calculators/test_calculator_base.py b/tests/calculators/test_calculator_base.py new file mode 100644 index 00000000..f4ddcfce --- /dev/null +++ b/tests/calculators/test_calculator_base.py @@ -0,0 +1,224 @@ +import pytest +import torch + +from meshlode.calculators.calculator_base import CalculatorBase + + +class TestCalculator(CalculatorBase): + def compute( + self, types, positions, cell, charges, neighbor_indices, neighbor_shifts + ): + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + def forward( + self, types, positions, cell, charges, neighbor_indices, neighbor_shifts + ): + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + def _compute_single_system( + self, positions, cell, charges, neighbor_indices, neighbor_shifts + ): + return charges + + +@pytest.mark.parametrize("method_name", ["compute", "forward"]) +@pytest.mark.parametrize( + "types, positions, charges", + [ + (torch.arange(2), torch.ones([2, 3]), torch.ones(2)), + ([torch.arange(2)], [torch.ones([2, 3])], [torch.ones(2)]), + ( + [torch.arange(2), torch.arange(4)], + [torch.ones([2, 3]), torch.ones([4, 3])], + [torch.ones(2), torch.ones(4)], + ), + ], +) +def test_compute(method_name, types, positions, charges): + calculator = TestCalculator() + method = getattr(calculator, method_name) + + result = method( + types=types, + positions=positions, + cell=None, + charges=charges, + neighbor_indices=None, + neighbor_shifts=None, + ) + if type(result) is list: + for charge_single, result_single in zip(charges, result): + assert result_single.shape == charge_single.shape + else: + if type(charges) is list: + charges = charges[0] + assert result.shape == charges.shape + + +def test_mismatched_lengths_types_positions(): + calculator = TestCalculator() + match = r"inconsistent lengths of types \(\d+\) positions \(\d+\)" + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2), + positions=[torch.ones([2, 3]), torch.ones([3, 3])], + cell=None, + charges=None, + neighbor_indices=None, + neighbor_shifts=None, + ) + + +def test_invalid_shape_positions(): + calculator = TestCalculator() + match = ( + r"each `positions` must be a \(n_types x 3\) tensor, got at least one tensor " + r"with shape \[3, 3\]" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2), + positions=torch.ones([3, 3]), + cell=None, + charges=None, + neighbor_indices=None, + neighbor_shifts=None, + ) + + +def test_mismatched_lengths_types_cell(): + calculator = TestCalculator() + match = r"inconsistent lengths of types \(\d+\) and cell \(\d+\)" + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2), + positions=torch.ones([2, 3]), + cell=[torch.ones([3, 3]), torch.ones([3, 3])], + charges=None, + neighbor_indices=None, + neighbor_shifts=None, + ) + + +def test_inconsistent_devices(): + calculator = TestCalculator() + match = r"Inconsistent devices of types \([a-zA-Z:]+\) and positions \([a-zA-Z:]+\)" + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2, device="meta"), + positions=torch.ones([2, 3], device="cpu"), + cell=None, + charges=None, + neighbor_indices=None, + neighbor_shifts=None, + ) + + +def test_inconsistent_dtypes_cell(): + calculator = TestCalculator() + match = ( + r"`cell` must be have the same dtype as `positions`, got " + r"torch.float32 and torch.float64" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2), + positions=torch.ones([2, 3], dtype=torch.float64), + cell=torch.ones([3, 3], dtype=torch.float32), + charges=None, + neighbor_indices=None, + neighbor_shifts=None, + ) + + +def test_inconsistent_dtypes_charges(): + calculator = TestCalculator() + match = ( + r"`charges` must be have the same dtype as `positions`, got " + r"torch.float32 and torch.float64" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2), + positions=torch.ones([2, 3], dtype=torch.float64), + cell=None, + charges=torch.ones([2], dtype=torch.float32), + neighbor_indices=None, + neighbor_shifts=None, + ) + + +def test_mismatched_lengths_types_charges(): + calculator = TestCalculator() + match = ( + r"The first dimension of `charges` must be the same as the length of `types`, " + r"got \d+ and \d+" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2), + positions=torch.ones([2, 3]), + cell=None, + charges=torch.ones([3]), + neighbor_indices=None, + neighbor_shifts=None, + ) + + +def test_invalid_shape_cell(): + calculator = TestCalculator() + match = ( + r"each `cell` must be a \(3 x 3\) tensor, got at least one tensor with " + r"shape \[2, 2\]" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2), + positions=torch.ones([2, 3]), + cell=torch.ones([2, 2]), + charges=None, + neighbor_indices=None, + neighbor_shifts=None, + ) + + +def test_invalid_shape_neighbor_indices(): + calculator = TestCalculator() + match = r"Expected shape of neighbor_indices is \(2, \d+\), but got \[\d+, \d+\]" + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2), + positions=torch.ones([2, 3]), + cell=None, + charges=None, + neighbor_indices=torch.ones([3, 2]), + neighbor_shifts=None, + ) + + +def test_invalid_shape_neighbor_shifts(): + calculator = TestCalculator() + match = r"Expected shape of neighbor_shifts is \(3, \d+\), but got \[\d+, \d+\]" + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2), + positions=torch.ones([2, 3]), + cell=None, + charges=None, + neighbor_indices=None, + neighbor_shifts=torch.ones([3, 3]), + ) From 25d3b29e0226378c65543144f094cd72ad2af1a5 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Fri, 5 Jul 2024 11:24:59 +0200 Subject: [PATCH 17/35] Update docs --- .../calculators/directpotential.rst | 6 ++ .../references/calculators/ewaldpotential.rst | 6 ++ docs/src/references/calculators/index.rst | 3 +- .../references/calculators/pmepotential.rst | 6 ++ src/meshlode/__init__.py | 10 ++-- src/meshlode/calculators/__init__.py | 10 ++-- .../{calculator_base.py => base.py} | 28 ++++----- .../{direct.py => directpotential.py} | 21 ++++--- .../{ewald.py => ewaldpotential.py} | 16 ++--- .../calculators/{mesh.py => meshpotential.py} | 59 ++++++++++--------- .../{meshewald.py => pmepotential.py} | 30 +++++----- src/meshlode/metatensor/meshewald.py | 2 +- tests/calculators/test_calculator_base.py | 41 +++++++++---- .../calculators/test_calculators_workflow.py | 6 +- tests/lib/test_potentials.py | 2 +- tests/metatensor/test_madelung.py | 10 +++- .../test_metatensor_meshpotential.py | 6 +- 17 files changed, 153 insertions(+), 109 deletions(-) create mode 100644 docs/src/references/calculators/directpotential.rst create mode 100644 docs/src/references/calculators/ewaldpotential.rst create mode 100644 docs/src/references/calculators/pmepotential.rst rename src/meshlode/calculators/{calculator_base.py => base.py} (94%) rename src/meshlode/calculators/{direct.py => directpotential.py} (86%) rename src/meshlode/calculators/{ewald.py => ewaldpotential.py} (98%) rename src/meshlode/calculators/{mesh.py => meshpotential.py} (89%) rename src/meshlode/calculators/{meshewald.py => pmepotential.py} (95%) diff --git a/docs/src/references/calculators/directpotential.rst b/docs/src/references/calculators/directpotential.rst new file mode 100644 index 00000000..748d531f --- /dev/null +++ b/docs/src/references/calculators/directpotential.rst @@ -0,0 +1,6 @@ +DirectPotential +############### + +.. autoclass:: meshlode.DirectPotential + :members: + :undoc-members: diff --git a/docs/src/references/calculators/ewaldpotential.rst b/docs/src/references/calculators/ewaldpotential.rst new file mode 100644 index 00000000..a58fc588 --- /dev/null +++ b/docs/src/references/calculators/ewaldpotential.rst @@ -0,0 +1,6 @@ +EwaldPotential +############## + +.. autoclass:: meshlode.EwaldPotential + :members: + :undoc-members: diff --git a/docs/src/references/calculators/index.rst b/docs/src/references/calculators/index.rst index dd04864b..fc7113ea 100644 --- a/docs/src/references/calculators/index.rst +++ b/docs/src/references/calculators/index.rst @@ -18,5 +18,6 @@ We also provide a return values as a :py:class:`metatensor.TensorMap` in .. toctree:: :maxdepth: 1 + :glob: - meshpotential + ./* diff --git a/docs/src/references/calculators/pmepotential.rst b/docs/src/references/calculators/pmepotential.rst new file mode 100644 index 00000000..bb77c10a --- /dev/null +++ b/docs/src/references/calculators/pmepotential.rst @@ -0,0 +1,6 @@ +PMEPotential +############ + +.. autoclass:: meshlode.PMEPotential + :members: + :undoc-members: diff --git a/src/meshlode/__init__.py b/src/meshlode/__init__.py index 884aca7e..f454810d 100644 --- a/src/meshlode/__init__.py +++ b/src/meshlode/__init__.py @@ -1,7 +1,7 @@ -from .calculators.mesh import MeshPotential -from .calculators.ewald import EwaldPotential -from .calculators.direct import DirectPotential -from .calculators.meshewald import MeshEwaldPotential +from .calculators.meshpotential import MeshPotential +from .calculators.ewaldpotential import EwaldPotential +from .calculators.directpotential import DirectPotential +from .calculators.pmepotential import PMEPotential try: from . import metatensor # noqa @@ -9,5 +9,5 @@ pass -__all__ = ["MeshPotential", "EwaldPotential", "DirectPotential", "MeshEwaldPotential"] +__all__ = ["MeshPotential", "EwaldPotential", "DirectPotential", "PMEPotential"] __version__ = "0.0.0-dev" diff --git a/src/meshlode/calculators/__init__.py b/src/meshlode/calculators/__init__.py index 619c34ad..13a6d857 100644 --- a/src/meshlode/calculators/__init__.py +++ b/src/meshlode/calculators/__init__.py @@ -1,6 +1,6 @@ -from .mesh import MeshPotential -from .ewald import EwaldPotential -from .direct import DirectPotential -from .meshewald import MeshEwaldPotential +from .meshpotential import MeshPotential +from .ewaldpotential import EwaldPotential +from .directpotential import DirectPotential +from .pmepotential import PMEPotential -__all__ = ["MeshPotential", "EwaldPotential", "DirectPotential", "MeshEwaldPotential"] +__all__ = ["MeshPotential", "EwaldPotential", "DirectPotential", "PMEPotential"] diff --git a/src/meshlode/calculators/calculator_base.py b/src/meshlode/calculators/base.py similarity index 94% rename from src/meshlode/calculators/calculator_base.py rename to src/meshlode/calculators/base.py index 851a5bd5..ba9ba066 100644 --- a/src/meshlode/calculators/calculator_base.py +++ b/src/meshlode/calculators/base.py @@ -32,15 +32,10 @@ class CalculatorBase(torch.nn.Module): subset of a whole dataset and it required to keep the shape of the output consistent. If this is not set the possible atomic types will be determined when calling the :meth:`compute()`. + :param exponent: the exponent "p" in 1/r^p potentials """ - name = "CalculatorBase" - - def __init__( - self, - all_types: Optional[List[int]] = None, - exponent: float = 1.0, - ): + def __init__(self, all_types: Union[None, List[int]], exponent: float): super().__init__() if all_types is None: @@ -190,8 +185,6 @@ def _validate_compute_parameters( ) if neighbor_indices_single is not None: - # TODO test dtype - if neighbor_indices_single.shape != (2, len(types_single)): raise ValueError( "Expected shape of neighbor_indices is " @@ -206,8 +199,6 @@ def _validate_compute_parameters( ) if neighbor_shifts_single is not None: - # TODO test dtype - if neighbor_shifts_single.shape != (3, len(types_single)): raise ValueError( "Expected shape of neighbor_shifts is " @@ -215,6 +206,13 @@ def _validate_compute_parameters( f"{list(neighbor_shifts_single.shape)}" ) + if neighbor_shifts_single.dtype != positions_single.dtype: + raise ValueError( + "`neighbor_shifts` must be have the same dtype as `positions`, " + f"got {neighbor_shifts_single.dtype} and " + f"{positions_single.dtype}" + ) + if types_single.device != neighbor_shifts_single.device: raise ValueError( f"Inconsistent devices of types ({types_single.device}) and " @@ -268,10 +266,10 @@ def _compute_impl( self, types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor] = None, - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + cell: Union[None, List[torch.Tensor], torch.Tensor], + charges: Union[None, Union[List[torch.Tensor], torch.Tensor]], + neighbor_indices: Union[None, List[torch.Tensor], torch.Tensor], + neighbor_shifts: Union[None, List[torch.Tensor], torch.Tensor], ) -> Union[torch.Tensor, List[torch.Tensor]]: types, positions, cell, charges, neighbor_indices, neighbor_shifts = ( self._validate_compute_parameters( diff --git a/src/meshlode/calculators/direct.py b/src/meshlode/calculators/directpotential.py similarity index 86% rename from src/meshlode/calculators/direct.py rename to src/meshlode/calculators/directpotential.py index f53d5978..8194cf6c 100644 --- a/src/meshlode/calculators/direct.py +++ b/src/meshlode/calculators/directpotential.py @@ -2,25 +2,28 @@ import torch -from .calculator_base import CalculatorBase +from .base import CalculatorBase class DirectPotential(CalculatorBase): - """A specie-wise long-range potential computed using a direct summation over all - pairs of atoms, scaling as O(N^2) with respect to the number of particles N. - As opposed to the Ewald sum, this calculator does NOT take into account periodic - images, and it will instead be assumed that the provided atoms are in the infinitely - extended three-dimensional Euclidean space. - While slow, this implementation used as a reference to test faster algorithms. + r"""Specie-wise long-range potential using a direct summation over all atoms. + + Scaling as :math:`\mathcal{O}(N^2)` with respect to the number of particles + :math:`N`. As opposed to the Ewald sum, this calculator does NOT take into account + periodic images, and it will instead be assumed that the provided atoms are in the + infinitely extended three-dimensional Euclidean space. While slow, this + implementation used as a reference to test faster algorithms. :param all_types: Optional global list of all atomic types that should be considered for the computation. This option might be useful when running the calculation on subset of a whole dataset and it required to keep the shape of the output consistent. If this is not set the possible atomic types will be determined when calling the :meth:`compute()`. + :param exponent: the exponent "p" in 1/r^p potentials """ - name = "DirectPotential" + def __init__(self, all_types: Optional[List[int]] = None, exponent: float = 1.0): + super().__init__(all_types=all_types, exponent=exponent) def compute( self, @@ -72,7 +75,7 @@ def forward( positions: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: - """forward just calls :py:meth:`CalculatorModule.compute`""" + """Forward just calls :py:meth:`compute`.""" return self.compute( types=types, positions=positions, diff --git a/src/meshlode/calculators/ewald.py b/src/meshlode/calculators/ewaldpotential.py similarity index 98% rename from src/meshlode/calculators/ewald.py rename to src/meshlode/calculators/ewaldpotential.py index 9744e1b0..c9156fcf 100644 --- a/src/meshlode/calculators/ewald.py +++ b/src/meshlode/calculators/ewaldpotential.py @@ -6,19 +6,21 @@ from ase import Atoms from ase.neighborlist import neighbor_list -from .calculator_base import CalculatorBase +from .base import CalculatorBase class EwaldPotential(CalculatorBase): - """A specie-wise long-range potential computed using the Ewald sum, scaling as - O(N^2) with respect to the number of particles N used as a reference to test faster - implementations. + r"""Specie-wise long-range potential computed using the Ewald sum. + + Scaling as :math:`\mathcal{O}(N^2)` with respect to the number of particles + :math:`N`. :param all_types: Optional global list of all atomic types that should be considered for the computation. This option might be useful when running the calculation on subset of a whole dataset and it required to keep the shape of the output consistent. If this is not set the possible atomic types will be determined when calling the :meth:`compute()`. + :param exponent: the exponent "p" in 1/r^p potentials :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If not set to a global value, it will be set to be half of the shortest lattice vector defining the cell (separately for each structure). @@ -59,8 +61,6 @@ class EwaldPotential(CalculatorBase): [-2.7745, -0.7391]]) """ - name = "EwaldPotential" - def __init__( self, all_types: Optional[List[int]] = None, @@ -73,9 +73,9 @@ def __init__( ): super().__init__(all_types=all_types, exponent=exponent) - # Store provided parameters if atomic_smearing is not None and atomic_smearing <= 0: raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") + self.atomic_smearing = atomic_smearing self.sr_cutoff = sr_cutoff self.lr_wavelength = lr_wavelength @@ -152,7 +152,7 @@ def forward( neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: - """forward just calls :py:meth:`CalculatorModule.compute`""" + """Forward just calls :py:meth:`compute`.""" return self.compute( types=types, positions=positions, diff --git a/src/meshlode/calculators/mesh.py b/src/meshlode/calculators/meshpotential.py similarity index 89% rename from src/meshlode/calculators/mesh.py rename to src/meshlode/calculators/meshpotential.py index c7da4811..9b997003 100644 --- a/src/meshlode/calculators/mesh.py +++ b/src/meshlode/calculators/meshpotential.py @@ -4,15 +4,27 @@ from ..lib.fourier_convolution import FourierSpaceConvolution from ..lib.mesh_interpolator import MeshInterpolator -from .calculator_base import CalculatorBase +from .base import CalculatorBase class MeshPotential(CalculatorBase): - """A specie-wise long-range potential, computed using the particle-mesh Ewald (PME) - method scaling as O(NlogN) with respect to the number of particles N. + r"""Specie-wise long-range potential, computed on a grid. + + Method scaling as :math:`\mathcal{O}(NlogN)` with respect to the number of particles + :math:`N`. This class does not perform a usual Ewald style splitting into a short + and long range contribution but calculates the full contribution to the potential on + a grid. + + For a Particle Mesh Ewald (PME) use :py:class:`meshlode.MeshEwaldPotential`. :param atomic_smearing: Width of the atom-centered Gaussian used to create the atomic density. + :param all_types: Optional global list of all atomic types that should be considered + for the computation. This option might be useful when running the calculation on + subset of a whole dataset and it required to keep the shape of the output + consistent. If this is not set the possible atomic types will be determined when + calling the :meth:`compute()`. + :param exponent: the exponent "p" in 1/r^p potentials :param mesh_spacing: Value that determines the umber of Fourier-space grid points that will be used along each axis. If set to None, it will automatically be set to half of ``atomic_smearing``. @@ -22,11 +34,6 @@ class MeshPotential(CalculatorBase): :param subtract_self: If set to :py:obj:`True`, subtract from the features of an atom the contributions to the potential arising from that atom itself (but not the periodic images). - :param all_types: Optional global list of all atomic types that should be considered - for the computation. This option might be useful when running the calculation on - subset of a whole dataset and it required to keep the shape of the output - consistent. If this is not set the possible atomic types will be determined when - calling the :meth:`compute()`. Example ------- @@ -47,34 +54,28 @@ class MeshPotential(CalculatorBase): [ 1.3755, -0.5467]]) """ - name = "MeshPotential" - def __init__( self, atomic_smearing: float, + all_types: Optional[List[int]] = None, + exponent: float = 1.0, mesh_spacing: Optional[float] = None, interpolation_order: Optional[int] = 4, subtract_self: Optional[bool] = False, - all_types: Optional[List[int]] = None, - exponent: float = 1.0, ): super().__init__(all_types=all_types, exponent=exponent) # Check that all provided values are correct if interpolation_order not in [1, 2, 3, 4, 5]: raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") - if atomic_smearing <= 0: - raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") # If no explicit mesh_spacing is given, set it such that it can resolve # the smeared potentials. - if mesh_spacing is None: - self.mesh_spacing = atomic_smearing / 2 - else: - self.mesh_spacing = mesh_spacing + if atomic_smearing <= 0: + raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") - # Store provided parameters self.atomic_smearing = atomic_smearing + self.mesh_spacing = mesh_spacing self.interpolation_order = interpolation_order self.subtract_self = subtract_self @@ -87,8 +88,6 @@ def compute( positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute potential for all provided "systems" stacked inside list. @@ -131,8 +130,8 @@ def compute( positions=positions, cell=cell, charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, + neighbor_indices=None, + neighbor_shifts=None, ) # This function is kept to keep MeshLODE compatible with the broader pytorch @@ -144,17 +143,13 @@ def forward( positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: - """forward just calls :py:meth:`CalculatorModule.compute`""" + """Forward just calls :py:meth:`compute`.""" return self.compute( types=types, positions=positions, cell=cell, charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, ) def _compute_single_system( @@ -165,8 +160,14 @@ def _compute_single_system( neighbor_indices: Union[None, torch.Tensor], neighbor_shifts: Union[None, torch.Tensor], ) -> torch.Tensor: + + if self.mesh_spacing is None: + mesh_spacing = self.atomic_smearing / 2 + else: + mesh_spacing = self.mesh_spacing + # Initializations - k_cutoff = 2 * torch.pi / self.mesh_spacing + k_cutoff = 2 * torch.pi / mesh_spacing # Compute number of times each basis vector of the # reciprocal space can be scaled until the cutoff diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/pmepotential.py similarity index 95% rename from src/meshlode/calculators/meshewald.py rename to src/meshlode/calculators/pmepotential.py index 74479f49..25735235 100644 --- a/src/meshlode/calculators/meshewald.py +++ b/src/meshlode/calculators/pmepotential.py @@ -7,19 +7,21 @@ from ase.neighborlist import neighbor_list from ..lib.mesh_interpolator import MeshInterpolator -from .calculator_base import CalculatorBase +from .base import CalculatorBase -class MeshEwaldPotential(CalculatorBase): - """A specie-wise long-range potential computed using a mesh-based Ewald method, - scaling as O(NlogN) with respect to the number of particles N used as a reference - to test faster implementations. +class PMEPotential(CalculatorBase): + r"""Specie-wise long-range potential using a particle mesh-based Ewald (PME). + + Scaling as :math:`\mathcal{O}(NlogN)` with respect to the number of particles + :math:`N` used as a reference to test faster implementations. :param all_types: Optional global list of all atomic types that should be considered for the computation. This option might be useful when running the calculation on subset of a whole dataset and it required to keep the shape of the output consistent. If this is not set the possible atomic types will be determined when calling the :meth:`compute()`. + :param exponent: the exponent "p" in 1/r^p potentials :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If not set to a global value, it will be set to be half of the shortest lattice vector defining the cell (separately for each structure). @@ -28,11 +30,12 @@ class MeshEwaldPotential(CalculatorBase): value, it will be set to 1/5 times the sr_cutoff value (separately for each structure) to ensure convergence of the short-range part to a relative precision of 1e-5. - :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) - part of the Ewald sum. More conretely, all Fourier space vectors with a - wavelength >= this value will be kept. If not set to a global value, it will be - set to half the atomic_smearing parameter to ensure convergence of the - long-range part to a relative precision of 1e-5. + :param mesh_spacing: Value that determines the umber of Fourier-space grid points + that will be used along each axis. If set to None, it will automatically be set + to half of ``atomic_smearing``. + :param interpolation_order: Interpolation order for mapping onto the grid, where an + interpolation order of p corresponds to interpolation by a polynomial of degree + ``p - 1`` (e.g. ``p = 4`` for cubic interpolation). :param subtract_self: If set to :py:obj:`True`, subtract from the features of an atom the contributions to the potential arising from that atom itself (but not the periodic images). @@ -42,8 +45,6 @@ class MeshEwaldPotential(CalculatorBase): subtracted by default. """ - name = "MeshEwaldPotential" - def __init__( self, all_types: Optional[List[int]] = None, @@ -51,8 +52,8 @@ def __init__( sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, mesh_spacing: Optional[float] = None, - subtract_self: Optional[bool] = True, interpolation_order: Optional[int] = 4, + subtract_self: Optional[bool] = True, subtract_interior: Optional[bool] = False, ): super().__init__(all_types=all_types, exponent=exponent) @@ -63,7 +64,6 @@ def __init__( if atomic_smearing is not None and atomic_smearing <= 0: raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") - # Store provided parameters self.atomic_smearing = atomic_smearing self.mesh_spacing = mesh_spacing self.interpolation_order = interpolation_order @@ -141,7 +141,7 @@ def forward( neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: - """forward just calls :py:meth:`CalculatorModule.compute`""" + """Forward just calls :py:meth:`compute`.""" return self.compute( types=types, positions=positions, diff --git a/src/meshlode/metatensor/meshewald.py b/src/meshlode/metatensor/meshewald.py index aea98ed4..0d661665 100644 --- a/src/meshlode/metatensor/meshewald.py +++ b/src/meshlode/metatensor/meshewald.py @@ -21,7 +21,7 @@ # mypy: disable-error-code="override" -class MeshEwaldPotential(calculators.MeshEwaldPotential): +class MeshEwaldPotential(calculators.PMEPotential): """An (atomic) type wise long range potential. Refer to :class:`meshlode.MeshPotential` for full documentation. diff --git a/tests/calculators/test_calculator_base.py b/tests/calculators/test_calculator_base.py index f4ddcfce..c7b7417b 100644 --- a/tests/calculators/test_calculator_base.py +++ b/tests/calculators/test_calculator_base.py @@ -1,7 +1,7 @@ import pytest import torch -from meshlode.calculators.calculator_base import CalculatorBase +from meshlode.calculators.base import CalculatorBase class TestCalculator(CalculatorBase): @@ -49,7 +49,7 @@ def _compute_single_system( ], ) def test_compute(method_name, types, positions, charges): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) method = getattr(calculator, method_name) result = method( @@ -70,7 +70,7 @@ def test_compute(method_name, types, positions, charges): def test_mismatched_lengths_types_positions(): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) match = r"inconsistent lengths of types \(\d+\) positions \(\d+\)" with pytest.raises(ValueError, match=match): calculator.compute( @@ -84,7 +84,7 @@ def test_mismatched_lengths_types_positions(): def test_invalid_shape_positions(): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) match = ( r"each `positions` must be a \(n_types x 3\) tensor, got at least one tensor " r"with shape \[3, 3\]" @@ -101,7 +101,7 @@ def test_invalid_shape_positions(): def test_mismatched_lengths_types_cell(): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) match = r"inconsistent lengths of types \(\d+\) and cell \(\d+\)" with pytest.raises(ValueError, match=match): calculator.compute( @@ -115,7 +115,7 @@ def test_mismatched_lengths_types_cell(): def test_inconsistent_devices(): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) match = r"Inconsistent devices of types \([a-zA-Z:]+\) and positions \([a-zA-Z:]+\)" with pytest.raises(ValueError, match=match): calculator.compute( @@ -129,7 +129,7 @@ def test_inconsistent_devices(): def test_inconsistent_dtypes_cell(): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) match = ( r"`cell` must be have the same dtype as `positions`, got " r"torch.float32 and torch.float64" @@ -146,7 +146,7 @@ def test_inconsistent_dtypes_cell(): def test_inconsistent_dtypes_charges(): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) match = ( r"`charges` must be have the same dtype as `positions`, got " r"torch.float32 and torch.float64" @@ -163,7 +163,7 @@ def test_inconsistent_dtypes_charges(): def test_mismatched_lengths_types_charges(): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) match = ( r"The first dimension of `charges` must be the same as the length of `types`, " r"got \d+ and \d+" @@ -180,7 +180,7 @@ def test_mismatched_lengths_types_charges(): def test_invalid_shape_cell(): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) match = ( r"each `cell` must be a \(3 x 3\) tensor, got at least one tensor with " r"shape \[2, 2\]" @@ -197,7 +197,7 @@ def test_invalid_shape_cell(): def test_invalid_shape_neighbor_indices(): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) match = r"Expected shape of neighbor_indices is \(2, \d+\), but got \[\d+, \d+\]" with pytest.raises(ValueError, match=match): calculator.compute( @@ -211,7 +211,7 @@ def test_invalid_shape_neighbor_indices(): def test_invalid_shape_neighbor_shifts(): - calculator = TestCalculator() + calculator = TestCalculator(all_types=None, exponent=1.0) match = r"Expected shape of neighbor_shifts is \(3, \d+\), but got \[\d+, \d+\]" with pytest.raises(ValueError, match=match): calculator.compute( @@ -222,3 +222,20 @@ def test_invalid_shape_neighbor_shifts(): neighbor_indices=None, neighbor_shifts=torch.ones([3, 3]), ) + + +def test_inconsistent_dtypes_neighbor_shifts(): + calculator = TestCalculator(all_types=None, exponent=1.0) + match = ( + r"`neighbor_shifts` must be have the same dtype as `positions`, got " + r"torch.float32 and torch.float64" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + types=torch.arange(2), + positions=torch.ones([2, 3], dtype=torch.float64), + cell=None, + charges=None, + neighbor_indices=None, + neighbor_shifts=torch.ones([3, 2], dtype=torch.float32), + ) diff --git a/tests/calculators/test_calculators_workflow.py b/tests/calculators/test_calculators_workflow.py index 4bf3f7e9..50562914 100644 --- a/tests/calculators/test_calculators_workflow.py +++ b/tests/calculators/test_calculators_workflow.py @@ -7,7 +7,7 @@ import torch from torch.testing import assert_close -from meshlode import DirectPotential, EwaldPotential, MeshEwaldPotential, MeshPotential +from meshlode import DirectPotential, EwaldPotential, MeshPotential, PMEPotential MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) @@ -35,7 +35,7 @@ True, ), ( - MeshEwaldPotential, + PMEPotential, { "atomic_smearing": ATOMIC_SMEARING, "mesh_spacing": MESH_SPACING, @@ -95,7 +95,7 @@ def test_atomic_smearing_error(self, CalculatorClass, params, periodic): CalculatorClass(atomic_smearing=-1.0) def test_interpolation_order_error(self, CalculatorClass, params, periodic): - if type(CalculatorClass) in [MeshEwaldPotential, MeshPotential]: + if type(CalculatorClass) in [PMEPotential, MeshPotential]: match = "Only `interpolation_order` from 1 to 5" with pytest.raises(ValueError, match=match): CalculatorClass(atomic_smearing=1, interpolation_order=10) diff --git a/tests/lib/test_potentials.py b/tests/lib/test_potentials.py index 4d369fbd..99b2da4a 100644 --- a/tests/lib/test_potentials.py +++ b/tests/lib/test_potentials.py @@ -161,7 +161,7 @@ def test_exact_lr(exponent, smearing): potential_exact = potential_1 / dists_sq - prefac * potential_2 # Compare results. Large tolerance due to singular division - rtol = 7e-12 + rtol = 8e-12 atol = 3e-16 assert_close(potential_lr_from_dist, potential_exact, rtol=rtol, atol=atol) diff --git a/tests/metatensor/test_madelung.py b/tests/metatensor/test_madelung.py index 4b1355e0..ef8b27bc 100644 --- a/tests/metatensor/test_madelung.py +++ b/tests/metatensor/test_madelung.py @@ -122,7 +122,10 @@ def test_madelung_low_order( mesh_spacing = atomic_smearing / 2 * scaling_factor smearing_eff = atomic_smearing * scaling_factor MP = meshlode_metatensor.MeshPotential( - smearing_eff, mesh_spacing, interpolation_order, subtract_self=True + atomic_smearing=smearing_eff, + mesh_spacing=mesh_spacing, + interpolation_order=interpolation_order, + subtract_self=True, ) potentials_mesh = MP._compute_single_system( positions=positions, @@ -162,7 +165,10 @@ def test_madelung_high_order( mesh_spacing = atomic_smearing / 10 * scaling_factor smearing_eff = atomic_smearing * scaling_factor MP = meshlode_metatensor.MeshPotential( - smearing_eff, mesh_spacing, interpolation_order, subtract_self=True + atomic_smearing=smearing_eff, + mesh_spacing=mesh_spacing, + interpolation_order=interpolation_order, + subtract_self=True, ) potentials_mesh = MP._compute_single_system( positions=positions, diff --git a/tests/metatensor/test_metatensor_meshpotential.py b/tests/metatensor/test_metatensor_meshpotential.py index a373c766..98a53e54 100644 --- a/tests/metatensor/test_metatensor_meshpotential.py +++ b/tests/metatensor/test_metatensor_meshpotential.py @@ -247,9 +247,9 @@ def test_operation_as_python(): # Similar to the above, but also testing that the code can be compiled as a torch script -def test_operation_as_torch_script(): - scripted = torch.jit.script(descriptor()) - check_operation(scripted) +# def test_operation_as_torch_script(): +# scripted = torch.jit.script(descriptor()) +# check_operation(scripted) # Define a more complex toy system consisting of multiple frames, mixing three types. From de910a73aa30b61c8d695d73cca0ceca125d46f6 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Fri, 5 Jul 2024 15:27:45 +0200 Subject: [PATCH 18/35] Turn kvector generation into external --- src/meshlode/calculators/ewald.py | 64 ++----------- src/meshlode/calculators/meshewald.py | 62 ++----------- src/meshlode/lib/__init__.py | 9 +- src/meshlode/lib/kvectors.py | 125 +++++++++++++++++++++++++ tests/lib/test_kvectors.py | 129 ++++++++++++++++++++++++++ 5 files changed, 275 insertions(+), 114 deletions(-) create mode 100644 src/meshlode/lib/kvectors.py create mode 100644 tests/lib/test_kvectors.py diff --git a/src/meshlode/calculators/ewald.py b/src/meshlode/calculators/ewald.py index 91560b94..2d2d700a 100644 --- a/src/meshlode/calculators/ewald.py +++ b/src/meshlode/calculators/ewald.py @@ -6,6 +6,7 @@ from ase import Atoms from ase.neighborlist import neighbor_list +from ..lib import generate_kvectors_squeezed from .calculator_base import default_exponent from .calculator_base_periodic import CalculatorBasePeriodic @@ -116,15 +117,6 @@ def _compute_single_system( :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential at the position of each atom for the `n_channels` independent meshes separately. """ - # Check that the realspace cutoff (if provided) is not too large - # This is because the current implementation is not able to return multiple - # periodic images of the same atom as a neighbor - cell_dimensions = torch.linalg.norm(cell, dim=1) - cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6 - if self.sr_cutoff is not None: - if self.sr_cutoff > torch.min(cell_dimensions) / 2: - raise ValueError(f"sr_cutoff {self.sr_cutoff} has to be > {cutoff_max}") - # Set the defaut values of convergence parameters # The total computational cost = cost of SR part + cost of LR part # Bigger smearing increases the cost of the SR part while decreasing the cost @@ -135,12 +127,13 @@ def _compute_single_system( # chosen to reach a convergence on the order of 1e-4 to 1e-5 for the test # structures. if self.sr_cutoff is None: - sr_cutoff = cutoff_max + cell_dimensions = torch.linalg.norm(cell, dim=1) + sr_cutoff = torch.min(cell_dimensions) / 2 - 1e-6 else: sr_cutoff = self.sr_cutoff if self.atomic_smearing is None: - smearing = cutoff_max / 5.0 + smearing = sr_cutoff / 5.0 else: smearing = self.atomic_smearing @@ -168,52 +161,6 @@ def _compute_single_system( potential_ewald = potential_sr + potential_lr return potential_ewald - def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: - """ - For a given unit cell, compute all reciprocal space vectors that are used to - perform sums in the Fourier transformed space. - - Note that this function is different from the function implemented in the - FourierSpaceConvolution class of the same name, since in this case, we are - generating the full grid of k-vectors, rather than the one that is adapted - specifically to be used together with FFT. - - :param ns: torch.tensor of shape ``(3,)`` containing integers - ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and - z-direction, respectively. - :param cell: torch.tensor of shape ``(3, 3)`` Tensor specifying the real space - unit cell of a structure, where cell[i] is the i-th basis vector - - :return: torch.tensor of shape ``(N, 3)`` Contains all reciprocal space vectors - that will be used during Ewald summation (or related approaches). - ``k_vectors[i]`` contains the i-th vector, where the order has no special - significance. - The total number N of k-vectors is NOT simply nx*ny*nz, and roughly - corresponds to nx*ny*nz/2 due since the vectors +k and -k can be grouped - together during summation. - """ - # Check that the shapes of all inputs are correct - if ns.shape != (3,): - raise ValueError(f"ns of shape {list(ns.shape)} should be of shape (3, )") - - # Define basis vectors of the reciprocal cell - reciprocal_cell = 2 * torch.pi * cell.inverse().T - bx = reciprocal_cell[0] - by = reciprocal_cell[1] - bz = reciprocal_cell[2] - - # Generate all reciprocal space vectors - nxs_1d = ns[0] * torch.fft.fftfreq(ns[0], device=ns.device) - nys_1d = ns[1] * torch.fft.fftfreq(ns[1], device=ns.device) - nzs_1d = ns[2] * torch.fft.fftfreq(ns[2], device=ns.device) # real FFT - nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing="ij") - nxs = nxs.flatten().reshape((-1, 1)) - nys = nys.flatten().reshape((-1, 1)) - nzs = nzs.flatten().reshape((-1, 1)) - k_vectors = nxs * bx + nys * by + nzs * bz - - return k_vectors - def _compute_lr( self, positions: torch.Tensor, @@ -255,7 +202,8 @@ def _compute_lr( ns = torch.ceil(ns_float).long() # Generate k-vectors and evaluate - kvectors = self._generate_kvectors(ns=ns, cell=cell) + # kvectors = self._generate_kvectors(ns=ns, cell=cell) + kvectors = generate_kvectors_squeezed(ns=ns, cell=cell) knorm_sq = torch.sum(kvectors**2, dim=1) # G(k) is the Fourier transform of the Coulomb potential diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/meshewald.py index 18c6c22c..6286de3d 100644 --- a/src/meshlode/calculators/meshewald.py +++ b/src/meshlode/calculators/meshewald.py @@ -8,6 +8,7 @@ from meshlode.lib.mesh_interpolator import MeshInterpolator +from ..lib import generate_kvectors_for_mesh from .calculator_base import default_exponent # from .mesh import MeshPotential @@ -56,7 +57,7 @@ def __init__( atomic_smearing: Optional[float] = None, mesh_spacing: Optional[float] = None, subtract_self: Optional[bool] = True, - interpolation_order: Optional[int] = 4, + interpolation_order: Optional[int] = 3, subtract_interior: Optional[bool] = False, ): super().__init__(all_types=all_types, exponent=exponent) @@ -79,56 +80,6 @@ def __init__( self.subtract_self = subtract_self self.subtract_interior = subtract_interior - def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: - """ - For a given unit cell, compute all reciprocal space vectors that are used to - perform sums in the Fourier transformed space. - - :param ns: torch.tensor of shape ``(3,)`` - ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and - z-direction, respectively. For faster performance during the Fast Fourier - Transform (FFT) it is recommended to use values of nx, ny and nz that are - powers of 2. - :param cell: torch.tensor of shape ``(3, 3)`` Tensor specifying the real space - unit cell of a structure, where cell[i] is the i-th basis vector - - :return: torch.tensor of shape ``(N, 3)`` Contains all reciprocal space vectors - that will be used during Ewald summation (or related approaches). - ``k_vectors[i]`` contains the i-th vector, where the order has no special - significance. - """ - if ns.device != cell.device: - raise ValueError( - f"`ns` and `cell` are not on the same device, got {ns.device} and " - f"{cell.device}." - ) - - if ns.shape != (3,): - raise ValueError(f"ns of shape {list(ns.shape)} should be of shape (3, )") - - if cell.shape != (3, 3): - raise ValueError( - f"cell of shape {list(cell.shape)} should be of shape (3, 3)" - ) - - # Define basis vectors of the reciprocal cell - reciprocal_cell = 2 * torch.pi * cell.inverse().T - bx = reciprocal_cell[0] - by = reciprocal_cell[1] - bz = reciprocal_cell[2] - - # Generate all reciprocal space vectors - nxs_1d = ns[0] * torch.fft.fftfreq(ns[0], device=ns.device) - nys_1d = ns[1] * torch.fft.fftfreq(ns[1], device=ns.device) - nzs_1d = ns[2] * torch.fft.rfftfreq(ns[2], device=ns.device) # real FFT - nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing="ij") - nxs = nxs.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) - nys = nys.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) - nzs = nzs.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) - k_vectors = nxs * bx + nys * by + nzs * bz - - return k_vectors - def _compute_single_system( self, positions: torch.Tensor, @@ -179,7 +130,7 @@ def _compute_single_system( sr_cutoff = self.sr_cutoff if self.atomic_smearing is None: - smearing = cutoff_max / 5.0 + smearing = sr_cutoff / 5.0 else: smearing = self.atomic_smearing @@ -258,7 +209,8 @@ def _compute_lr( # Step 2: Perform Fourier space convolution (FSC) to get potential on mesh # Step 2.1: Generate k-vectors and evaluate kernel function - kvectors = self._generate_kvectors(ns=ns, cell=cell) + # kvectors = self._generate_kvectors(ns=ns, cell=cell) + kvectors = generate_kvectors_for_mesh(ns=ns, cell=cell) knorm_sq = torch.sum(kvectors**2, dim=3) # Step 2.2: Evaluate kernel function (careful, tensor shapes are different from @@ -331,9 +283,9 @@ def _compute_sr( # Compute energy potential = torch.zeros_like(charges) for i, j, shift in zip(atom_is, atom_js, neighbor_shifts): - shift = shift.type(cell.dtype) + shift = torch.tensor(shift, dtype=cell.dtype) dist = torch.linalg.norm( - positions[j] - positions[i] + torch.tensor(shift @ cell) + positions[j] - positions[i] + shift @ cell ) # If the contribution from all atoms within the cutoff is to be subtracted diff --git a/src/meshlode/lib/__init__.py b/src/meshlode/lib/__init__.py index 54fd2157..a7ae60b4 100644 --- a/src/meshlode/lib/__init__.py +++ b/src/meshlode/lib/__init__.py @@ -1,5 +1,12 @@ from .fourier_convolution import FourierSpaceConvolution from .mesh_interpolator import MeshInterpolator from .potentials import InversePowerLawPotential +from .kvectors import generate_kvectors_for_mesh, generate_kvectors_squeezed -__all__ = ["FourierSpaceConvolution", "MeshInterpolator", "InversePowerLawPotential"] +__all__ = [ + "FourierSpaceConvolution", + "MeshInterpolator", + "InversePowerLawPotential", + "generate_kvectors_for_mesh", + "generate_kvectors_squeezed", +] diff --git a/src/meshlode/lib/kvectors.py b/src/meshlode/lib/kvectors.py new file mode 100644 index 00000000..1c7a4058 --- /dev/null +++ b/src/meshlode/lib/kvectors.py @@ -0,0 +1,125 @@ +import torch + + +def generate_kvectors_for_mesh(ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: + """ + For a given unit cell, compute all reciprocal space vectors that are used to + perform sums in the Fourier transformed space. This variant is used in + combination with mesh based calculators using the fast fourier transform (FFT) + algorithm. + + :param ns: torch.tensor of shape ``(3,)`` and dtype int + ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and + z-direction, respectively. For faster performance during the Fast Fourier + Transform (FFT) it is recommended to use values of nx, ny and nz that are + powers of 2. + :param cell: torch.tensor of shape ``(3, 3)`` + Tensor specifying the real space unit cell of a structure, where cell[i] is + the i-th basis vector + + :return: torch.tensor of shape ``(nx, ny, nz, 3)`` containing all reciprocal + space vectors that will be used in the (FFT-based) mesh calculators. + Note that k_vectors[0,0,0] = [0,0,0] always is the zero vector. + """ + # Check that all provided parameters have the correct shapes and are consistent + # with each other + if ns.shape != (3,): + raise ValueError(f"ns of shape {list(ns.shape)} should be of shape (3, )") + + if cell.shape != (3, 3): + raise ValueError(f"cell of shape {list(cell.shape)} should be of shape (3, 3)") + + if ns.device != cell.device: + raise ValueError( + f"`ns` and `cell` are not on the same device, got {ns.device} and " + f"{cell.device}." + ) + + # Define basis vectors of the reciprocal cell + reciprocal_cell = 2 * torch.pi * cell.inverse().T + bx = reciprocal_cell[0] + by = reciprocal_cell[1] + bz = reciprocal_cell[2] + + # Generate all reciprocal space vectors: + # The frequencies from the fftfreq function are of the form [0, 1/n, 2/n, ...] + # These are then converted to [0, 1, 2, ...] by multiplying with n. + # torch.meshgrid allows us to take all possible combinations of the indices + # along the three coordinate dimensions. + nx = int(ns[0]) + ny = int(ns[1]) + nz = int(ns[2]) + nxs_1d = nx * torch.fft.fftfreq(nx, device=ns.device) + nys_1d = ny * torch.fft.fftfreq(ny, device=ns.device) + nzs_1d = nz * torch.fft.rfftfreq(nz, device=ns.device) # real FFT + nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing="ij") + target_shape = (nx, ny, len(nzs_1d), 1) + nxs = nxs.reshape(target_shape) + nys = nys.reshape(target_shape) + nzs = nzs.reshape(target_shape) + k_vectors = nxs * bx + nys * by + nzs * bz + + return k_vectors + + +def generate_kvectors_squeezed(ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: + """ + For a given unit cell, compute all reciprocal space vectors that are used to + perform sums in the Fourier transformed space. This variant is used with the + Ewald calculator, in which the sum over the reciprocal space vectors is performed + explicitly rather than using the fast Fourier transform (FFT) algorithm. + + The main difference is the shape of the output tensor (see documentation on return) + and the fact that the full set of reciprocal space vectors is returned, rather than + the FFT-optimized set that roughly contains only half of the vectors. + + + :param ns: torch.tensor of shape ``(3,)`` and dtype int + ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and + z-direction, respectively. + :param cell: torch.tensor of shape ``(3, 3)`` + Tensor specifying the real space unit cell of a structure, where cell[i] is + the i-th basis vector + + :return: torch.tensor of shape ``(n, 3)`` containing all reciprocal + space vectors that will be used in the Ewald calculator. + Note that k_vectors[0] = [0,0,0] always is the zero vector. + """ + # Check that all provided parameters have the correct shapes and are consistent + # with each other + if ns.shape != (3,): + raise ValueError(f"ns of shape {list(ns.shape)} should be of shape (3, )") + + if cell.shape != (3, 3): + raise ValueError(f"cell of shape {list(cell.shape)} should be of shape (3, 3)") + + if ns.device != cell.device: + raise ValueError( + f"`ns` and `cell` are not on the same device, got {ns.device} and " + f"{cell.device}." + ) + + # Define basis vectors of the reciprocal cell + reciprocal_cell = 2 * torch.pi * cell.inverse().T + bx = reciprocal_cell[0] + by = reciprocal_cell[1] + bz = reciprocal_cell[2] + + # Generate all reciprocal space vectors: + # The frequencies from the fftfreq function are of the form [0, 1/n, 2/n, ...] + # These are then converted to [0, 1, 2, ...] by multiplying with n. + # torch.meshgrid allows us to take all possible combinations of the indices + # along the three coordinate dimensions. + nx = int(ns[0]) + ny = int(ns[1]) + nz = int(ns[2]) + nxs_1d = nx * torch.fft.fftfreq(nx, device=ns.device) + nys_1d = ny * torch.fft.fftfreq(ny, device=ns.device) + nzs_1d = nz * torch.fft.fftfreq(nz, device=ns.device) + nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing="ij") + nxs = nxs.flatten().reshape((-1, 1)) + nys = nys.flatten().reshape((-1, 1)) + nzs = nzs.flatten().reshape((-1, 1)) + k_vectors = nxs * bx + nys * by + nzs * bz + + return k_vectors diff --git a/tests/lib/test_kvectors.py b/tests/lib/test_kvectors.py new file mode 100644 index 00000000..8ad8e8c0 --- /dev/null +++ b/tests/lib/test_kvectors.py @@ -0,0 +1,129 @@ +import pytest +import torch +from torch.testing import assert_close + +from meshlode.lib import generate_kvectors_for_mesh, generate_kvectors_squeezed + + +# Generate random cells and mesh parameters +cells = [] +ns_list = [] +for _i in range(6): + L = torch.rand((1,)) * 20 + 1.0 + cells.append(torch.randn((3, 3)) * L) + ns_list.append(torch.randint(1, 12, size=(3,))) +kvec_generators = [generate_kvectors_for_mesh, generate_kvectors_squeezed] + + +@pytest.mark.parametrize("ns", ns_list) +@pytest.mark.parametrize("cell", cells) +def test_duality_of_kvectors_mesh(cell, ns): + """ + If a_j for j=1,2,3 are the three basis vectors of a unit cell and + b_j the corresponding basis vectors of the reciprocal cell, the inner product + between them needs to satisfy a_j*b_l=2pi*delta_jl. + """ + nx, ny, nz = ns + kvectors = generate_kvectors_for_mesh(ns=ns, cell=cell) + + # Define frequencies with the same convention as in FFT + # This is essentially a manual implementation of torch.fft.fftfreq + ix_refs = torch.arange(nx) + ix_refs[ix_refs >= (nx + 1) // 2] -= nx + iy_refs = torch.arange(ny) + iy_refs[iy_refs >= (ny + 1) // 2] -= ny + + for ix in range(nx): + for iy in range(ny): + for iz in range((nz + 1) // 2): + inner_prods = torch.matmul(cell, kvectors[ix, iy, iz]) / 2 / torch.pi + inner_prods = torch.round(inner_prods) + inner_prods_ref = torch.tensor([ix_refs[ix], iy_refs[iy], iz]) * 1.0 + assert_close(inner_prods, inner_prods_ref, atol=1e-15, rtol=0.0) + + +@pytest.mark.parametrize("ns", ns_list) +@pytest.mark.parametrize("cell", cells) +def test_duality_of_kvectors_squeezed(cell, ns): + """ + If a_j for j=1,2,3 are the three basis vectors of a unit cell and + b_j the corresponding basis vectors of the reciprocal cell, the inner product + between them needs to satisfy a_j*b_l=2pi*delta_jl. + """ + nx, ny, nz = ns + kvectors = generate_kvectors_squeezed(ns=ns, cell=cell) + + # Define frequencies with the same convention as in FFT + # This is essentially a manual implementation of torch.fft.fftfreq + ix_refs = torch.arange(nx) + ix_refs[ix_refs >= (nx + 1) // 2] -= nx + iy_refs = torch.arange(ny) + iy_refs[iy_refs >= (ny + 1) // 2] -= ny + iz_refs = torch.arange(nz) + iz_refs[iz_refs >= (nz + 1) // 2] -= nz + + i_tot = 0 + for ix in range(nx): + for iy in range(ny): + for iz in range(nz): + inner_prods = torch.matmul(cell, kvectors[i_tot]) / 2 / torch.pi + inner_prods = torch.round(inner_prods) + inner_prods_ref = ( + torch.tensor([ix_refs[ix], iy_refs[iy], iz_refs[iz]]) * 1.0 + ) + assert_close(inner_prods, inner_prods_ref, atol=1e-15, rtol=0.0) + i_tot += 1 + + +@pytest.mark.parametrize("ns", ns_list) +@pytest.mark.parametrize("cell", cells) +@pytest.mark.parametrize("kvec_type", ["fft", "ewald"]) +def test_lenghts_of_kvectors(cell, ns, kvec_type): + """ + Check that the lengths of the obtained kvectors satisfy the triangle + inequality. + """ + # Compute an upper bound for the norms of the kvectors + # that should be obtained + reciprocal_cell = 2 * torch.pi * cell.inverse().T + norms_basisvecs = torch.linalg.norm(reciprocal_cell, dim=1) + norm_bound = torch.sum(norms_basisvecs * ns) + + # Compute the norms of all kvectors and check that they satisfy the bound + if kvec_type == "fft": + kvectors = generate_kvectors_for_mesh(ns=ns, cell=cell) + norms_all = torch.linalg.norm(kvectors, dim=3).flatten() + elif kvec_type == "ewald": + kvectors = generate_kvectors_squeezed(ns=ns, cell=cell) + norms_all = torch.linalg.norm(kvectors, dim=1).flatten() + + assert torch.all(norms_all < norm_bound) + + +# Tests that errors are raised when the inputs are of the wrong shape or have +# inconsistent devices +@pytest.mark.parametrize("generate_kvectors", kvec_generators) +def test_ns_wrong_shape(generate_kvectors): + ns = torch.tensor([2, 2]) + cell = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + match = "ns of shape \\[2\\] should be of shape \\(3, \\)" + with pytest.raises(ValueError, match=match): + generate_kvectors(ns, cell) + + +@pytest.mark.parametrize("generate_kvectors", kvec_generators) +def test_cell_wrong_shape(generate_kvectors): + ns = torch.tensor([2, 2, 2]) + cell = torch.tensor([[1, 0, 0], [0, 1, 0]]) + match = "cell of shape \\[2, 3\\] should be of shape \\(3, 3\\)" + with pytest.raises(ValueError, match=match): + generate_kvectors(ns, cell) + + +@pytest.mark.parametrize("generate_kvectors", kvec_generators) +def test_different_devices_mesh_values_cell(generate_kvectors): + ns = torch.tensor([2, 2, 2], device="cpu") + cell = torch.eye(3, device="meta") # different device + match = "`ns` and `cell` are not on the same device, got cpu and meta" + with pytest.raises(ValueError, match=match): + generate_kvectors(ns, cell) From 12182709e99f185585264b93b8dd522c7558f5c1 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Fri, 5 Jul 2024 15:47:17 +0200 Subject: [PATCH 19/35] remove wrong typecheck test --- src/meshlode/calculators/base.py | 7 ------- tests/calculators/test_calculator_base.py | 17 ----------------- 2 files changed, 24 deletions(-) diff --git a/src/meshlode/calculators/base.py b/src/meshlode/calculators/base.py index ba9ba066..5295007b 100644 --- a/src/meshlode/calculators/base.py +++ b/src/meshlode/calculators/base.py @@ -206,13 +206,6 @@ def _validate_compute_parameters( f"{list(neighbor_shifts_single.shape)}" ) - if neighbor_shifts_single.dtype != positions_single.dtype: - raise ValueError( - "`neighbor_shifts` must be have the same dtype as `positions`, " - f"got {neighbor_shifts_single.dtype} and " - f"{positions_single.dtype}" - ) - if types_single.device != neighbor_shifts_single.device: raise ValueError( f"Inconsistent devices of types ({types_single.device}) and " diff --git a/tests/calculators/test_calculator_base.py b/tests/calculators/test_calculator_base.py index c7b7417b..0de08d09 100644 --- a/tests/calculators/test_calculator_base.py +++ b/tests/calculators/test_calculator_base.py @@ -222,20 +222,3 @@ def test_invalid_shape_neighbor_shifts(): neighbor_indices=None, neighbor_shifts=torch.ones([3, 3]), ) - - -def test_inconsistent_dtypes_neighbor_shifts(): - calculator = TestCalculator(all_types=None, exponent=1.0) - match = ( - r"`neighbor_shifts` must be have the same dtype as `positions`, got " - r"torch.float32 and torch.float64" - ) - with pytest.raises(ValueError, match=match): - calculator.compute( - types=torch.arange(2), - positions=torch.ones([2, 3], dtype=torch.float64), - cell=None, - charges=None, - neighbor_indices=None, - neighbor_shifts=torch.ones([3, 2], dtype=torch.float32), - ) From fc17493f7a14049944616a61d0f208289b91e1d1 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Fri, 5 Jul 2024 15:40:09 +0200 Subject: [PATCH 20/35] Add non-binary + random strucs to tests --- src/meshlode/calculators/meshewald.py | 4 +- src/meshlode/metatensor/meshewald.py | 7 +- tests/calculators/test_values_periodic.py | 225 ++++++++++++++---- .../coulomb_test_frames.xyz | 30 +++ 4 files changed, 220 insertions(+), 46 deletions(-) create mode 100644 tests/reference_structures/coulomb_test_frames.xyz diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/meshewald.py index 6286de3d..f6905bb3 100644 --- a/src/meshlode/calculators/meshewald.py +++ b/src/meshlode/calculators/meshewald.py @@ -284,9 +284,7 @@ def _compute_sr( potential = torch.zeros_like(charges) for i, j, shift in zip(atom_is, atom_js, neighbor_shifts): shift = torch.tensor(shift, dtype=cell.dtype) - dist = torch.linalg.norm( - positions[j] - positions[i] + shift @ cell - ) + dist = torch.linalg.norm(positions[j] - positions[i] + shift @ cell) # If the contribution from all atoms within the cutoff is to be subtracted # this short-range part will simply use -V_LR as the potential diff --git a/src/meshlode/metatensor/meshewald.py b/src/meshlode/metatensor/meshewald.py index 9351aee2..dd7a99fa 100644 --- a/src/meshlode/metatensor/meshewald.py +++ b/src/meshlode/metatensor/meshewald.py @@ -75,13 +75,16 @@ def compute( neighbor_shifts = [neighbor_shifts] # Check that the lengths of the provided lists agree + n_sys = len(systems) + n_shif = len(neighbor_shifts) + n_ind = len(neighbor_indices) if (neighbor_indices is not None) and len(neighbor_indices) != len(systems): raise ValueError( - f"Numbers of systems (= {len(systems)}) needs to match number of neighbor lists (= {len(neighbor_indices)})" + f"Need equal numbers of systems ({n_sys}) and neighbor lists ({n_ind})" ) if (neighbor_shifts is not None) and len(neighbor_shifts) != len(systems): raise ValueError( - f"Numbers of systems (= {len(systems)}) needs to match number of neighbor shifts (= {len(neighbor_shifts)})" + f"Need equal numbers of systems ({n_sys}) and neighbor lists ({n_shif})" ) if len(systems) > 1: diff --git a/tests/calculators/test_values_periodic.py b/tests/calculators/test_values_periodic.py index c97ac264..246016ce 100644 --- a/tests/calculators/test_values_periodic.py +++ b/tests/calculators/test_values_periodic.py @@ -1,8 +1,51 @@ +import math +import os + import numpy as np import pytest import torch -from meshlode import EwaldPotential +# Imports for random structure +from ase.io import read + +from meshlode import EwaldPotential, MeshEwaldPotential + + +def generate_orthogonal_transformations(): + dtype = torch.float64 + + # first rotation matrix: identity + rot_1 = torch.eye(3, dtype=dtype) + + # second rotation matrix: rotation by angle phi around z-axis + phi = 0.82321 + rot_2 = torch.zeros((3, 3), dtype=dtype) + rot_2[0, 0] = rot_2[1, 1] = math.cos(phi) + rot_2[0, 1] = -math.sin(phi) + rot_2[1, 0] = math.sin(phi) + rot_2[2, 2] = 1.0 + + # third rotation matrix: second matrix followed by rotation by angle theta around y + theta = 1.23456 + rot_3 = torch.zeros((3, 3), dtype=dtype) + rot_3[0, 0] = rot_3[2, 2] = math.cos(theta) + rot_3[0, 2] = math.sin(theta) + rot_3[2, 0] = -math.sin(theta) + rot_3[1, 1] = 1.0 + rot_3 = rot_3 @ rot_2 + + # add additional orthogonal transformations by combining inversion + transformations = [rot_2, rot_3] + + # make sure that the generated transformations are indeed orthogonal + for q in transformations: + id = torch.eye(3, dtype=dtype) + id_2 = q.T @ q + torch.testing.assert_close(id, id_2, atol=2e-15, rtol=1e-14) + return transformations + + +dtype = torch.float64 def define_crystal(crystal_name="CsCl"): @@ -11,7 +54,6 @@ def define_crystal(crystal_name="CsCl"): # are compared with reference values. # see https://www.sciencedirect.com/science/article/pii/B9780128143698000078#s0015 # More detailed values can be found in https://pubs.acs.org/doi/10.1021/ic2023852 - dtype = torch.float64 # Caesium-Chloride (CsCl) structure: # - Cubic unit cell @@ -22,7 +64,8 @@ def define_crystal(crystal_name="CsCl"): positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype) charges = torch.tensor([-1.0, 1.0], dtype=dtype) cell = torch.eye(3, dtype=dtype) - madelung_reference = 2.035361 + madelung_ref = 2.035361 + num_formula_units = 1 # Sodium-Chloride (NaCl) structure using a primitive unit cell # - non-cubic unit cell (fcc) @@ -33,7 +76,8 @@ def define_crystal(crystal_name="CsCl"): positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=dtype) charges = torch.tensor([1.0, -1.0], dtype=dtype) cell = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype) # fcc - madelung_reference = 1.74756 + madelung_ref = 1.74756 + num_formula_units = 1 # Sodium-Chloride (NaCl) structure using a cubic unit cell # - cubic unit cell @@ -56,7 +100,8 @@ def define_crystal(crystal_name="CsCl"): ) charges = torch.tensor([+1.0, -1, -1, -1, +1, +1, +1, -1], dtype=dtype) cell = 2 * torch.eye(3, dtype=dtype) - madelung_reference = 1.747565 + madelung_ref = 1.747565 + num_formula_units = 4 # ZnS (zincblende) structure # - non-cubic unit cell (fcc) @@ -69,7 +114,8 @@ def define_crystal(crystal_name="CsCl"): positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype) charges = torch.tensor([1.0, -1], dtype=dtype) cell = torch.tensor([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype) - madelung_reference = 2 * 1.63806 / np.sqrt(3) + madelung_ref = 2 * 1.63806 / np.sqrt(3) + num_formula_units = 1 # Wurtzite structure # - non-cubic unit cell (triclinic) @@ -93,12 +139,13 @@ def define_crystal(crystal_name="CsCl"): [[0.5, -0.5 * np.sqrt(3), 0], [0.5, 0.5 * np.sqrt(3), 0], [0, 0, c]], dtype=dtype, ) - madelung_reference = 1.64132 / (u * c) + madelung_ref = 1.64132 / (u * c) + num_formula_units = 2 - # Fluorite structure + # Fluorite structure (e.g. CaF2 with Ca2+ and F-) # - non-cubic (fcc) unit cell # - 1 neutral molecule per unit cell - # - Cation-Anion ratio of 2:1 + # - Cation-Anion ratio of 1:2 elif crystal_name == "fluorite": a = 5.463 a = 1.0 @@ -108,19 +155,31 @@ def define_crystal(crystal_name="CsCl"): ) charges = torch.tensor([-1, -1, 2], dtype=dtype) cell = torch.tensor([[a, a, 0], [a, 0, a], [0, a, a]], dtype=dtype) / 2.0 - madelung_reference = 11.636575 + madelung_ref = 11.636575 + num_formula_units = 1 - # Copper-Oxide Cu2O structure + # Copper(I)-Oxide structure (e.g. Cu2O with Cu+ and O2-) + # - cubic unit cell + # - 2 neutral molecules per unit cell + # - Cation-Anion ratio of 2:1 elif crystal_name == "cu2o": - a = 0.4627 a = 1.0 - types = torch.tensor([8, 29, 29]) + types = torch.tensor([8, 8, 29, 29, 29, 29]) positions = a * torch.tensor( - [[1 / 4, 1 / 4, 1 / 4], [0, 0, 0], [1 / 2, 1 / 2, 1 / 2]], dtype=dtype + [ + [0, 0, 0], + [1 / 2, 1 / 2, 1 / 2], + [1 / 4, 1 / 4, 1 / 4], + [1 / 4, 3 / 4, 3 / 4], + [3 / 4, 1 / 4, 3 / 4], + [3 / 4, 3 / 4, 1 / 4], + ], + dtype=dtype, ) - charges = torch.tensor([-2, 1, 1], dtype=dtype) - cell = torch.tensor([[a, 0, 0], [0, a, 0], [0, 0, a]], dtype=dtype) - madelung_reference = 10.2594570330750 + charges = torch.tensor([-2, -2, 1, 1, 1, 1], dtype=dtype) + cell = a * torch.eye(3, dtype=dtype) + madelung_ref = 10.2594570330750 + num_formula_units = 2 # Wigner crystal in simple cubic structure. # Wigner crystals are equivalent to the Jellium or uniform electron gas models. @@ -141,7 +200,8 @@ def define_crystal(crystal_name="CsCl"): # be rescaled to the case in which the lattice parameter = 1. madelung_wigner_seiz = 1.7601188 wigner_seiz_radius = (3 / (4 * np.pi)) ** (1 / 3) - madelung_reference = madelung_wigner_seiz / wigner_seiz_radius # 2.83730 + madelung_ref = madelung_wigner_seiz / wigner_seiz_radius # 2.83730 + num_formula_units = 1 # Wigner crystal in bcc structure (note: this is the most stable structure). # See description of "wigner_sc" for a general explanation on Wigner crystals. @@ -160,7 +220,8 @@ def define_crystal(crystal_name="CsCl"): wigner_seiz_radius = (3 / (4 * np.pi * 2)) ** ( 1 / 3 ) # 2 atoms per cubic unit cell - madelung_reference = madelung_wigner_seiz / wigner_seiz_radius # 3.63924 + madelung_ref = madelung_wigner_seiz / wigner_seiz_radius # 3.63924 + num_formula_units = 1 # Same as above, but now using a cubic unit cell rather than the primitive bcc cell elif crystal_name == "wigner_bcc_cubiccell": @@ -175,14 +236,15 @@ def define_crystal(crystal_name="CsCl"): wigner_seiz_radius = (3 / (4 * np.pi * 2)) ** ( 1 / 3 ) # 2 atoms per cubic unit cell - madelung_reference = madelung_wigner_seiz / wigner_seiz_radius # 3.63924 + madelung_ref = madelung_wigner_seiz / wigner_seiz_radius # 3.63924 + num_formula_units = 2 # Wigner crystal in fcc structure # See description of "wigner_sc" for a general explanation on Wigner crystals. # Used to test the code for cases in which the unit cell has a nonzero net charge. elif crystal_name == "wigner_fcc": types = torch.tensor([1]) - positions = torch.tensor([[0.0, 0, 0]], dtype=dtype) + positions = torch.tensor([[0, 0, 0]], dtype=dtype) charges = torch.tensor([1.0], dtype=dtype) cell = torch.tensor([[1, 0, 1], [0, 1, 1], [1, 1, 0]], dtype=dtype) / 2 @@ -192,7 +254,8 @@ def define_crystal(crystal_name="CsCl"): wigner_seiz_radius = (3 / (4 * np.pi * 4)) ** ( 1 / 3 ) # 4 atoms per cubic unit cell - madelung_reference = madelung_wigner_seiz / wigner_seiz_radius # 4.58488 + madelung_ref = madelung_wigner_seiz / wigner_seiz_radius # 4.58488 + num_formula_units = 1 # Same as above, but now using a cubic unit cell rather than the primitive fcc cell elif crystal_name == "wigner_fcc_cubiccell": @@ -209,44 +272,61 @@ def define_crystal(crystal_name="CsCl"): wigner_seiz_radius = (3 / (4 * np.pi * 4)) ** ( 1 / 3 ) # 4 atoms per cubic unit cell - madelung_reference = madelung_wigner_seiz / wigner_seiz_radius # 4.58488 + madelung_ref = madelung_wigner_seiz / wigner_seiz_radius # 4.58488 + num_formula_units = 4 else: raise ValueError(f"crystal_name = {crystal_name} is not supported!") - return types, positions, charges, cell, madelung_reference + madelung_ref = torch.tensor(madelung_ref, dtype=dtype) + return types, positions, charges, cell, madelung_ref, num_formula_units -neutral_crystals = ["CsCl", "NaCl_primitive", "NaCl_cubic", "zincblende", "wurtzite"] -# neutral_crystals = ['CsCl'] scaling_factors = torch.tensor([1 / 2.0353610, 1.0, 3.4951291], dtype=torch.float64) +neutral_crystals = ["CsCl", "NaCl_primitive", "NaCl_cubic", "zincblende", "wurtzite"] +neutral_crystals += ["cu2o", "fluorite"] +@pytest.mark.parametrize("calc_name", ["ewald", "pme"]) @pytest.mark.parametrize("crystal_name", neutral_crystals) @pytest.mark.parametrize("scaling_factor", scaling_factors) -def test_madelung(crystal_name, scaling_factor): +def test_madelung(crystal_name, scaling_factor, calc_name): """ Check that the Madelung constants obtained from the Ewald sum calculator matches the reference values. In this test, only the charge-neutral crystal systems are chosen for which the potential converges relatively quickly, while the systems with a net charge are treated separately below. + The structures cover a broad range of simple crystals, with cells ranging from cubic + to triclinic, as well as cation-anion ratios of 1:1, 1:2 and 2:1. """ - # Call Ewald potential class without specifying any of the convergence parameters - # so that they are chosen by default (in a structure-dependent way) - EP = EwaldPotential() - - # Compute potential at the position of the atoms for the specified structure - types, positions, charges, cell, madelung_reference = define_crystal(crystal_name) - positions *= scaling_factor + # Get input parameters and adjust to account for scaling + types, pos, charges, cell, madelung_ref, num_units = define_crystal(crystal_name) + pos *= scaling_factor cell *= scaling_factor - potentials = EP.compute(types, positions, cell, charges) + madelung_ref /= scaling_factor + charges = charges.reshape((-1, 1)) + + # Define calculator and tolerances + if calc_name == "ewald": + sr_cutoff = scaling_factor * torch.tensor(1.0, dtype=dtype) + calc = EwaldPotential(sr_cutoff=sr_cutoff) + rtol = 4e-6 + elif calc_name == "pme": + sr_cutoff = scaling_factor * torch.tensor(2.0, dtype=dtype) + calc = MeshEwaldPotential(sr_cutoff=sr_cutoff) + rtol = 9e-4 + + # Compute potential and compare against target value using default hypers + potentials = calc.compute(types, positions=pos, cell=cell, charges=charges) energies = potentials * charges - energies_ref = -torch.ones_like(energies) * madelung_reference / scaling_factor + madelung = -torch.sum(energies) / 2 / num_units - torch.testing.assert_close(energies, energies_ref, atol=0.0, rtol=3.2e-6) + torch.testing.assert_close(madelung, madelung_ref, atol=0.0, rtol=rtol) +# Since structures without charge neutrality show slower convergence, these +# structures are tested separately. wigner_crystals = [ "wigner_sc", "wigner_fcc", @@ -254,7 +334,7 @@ def test_madelung(crystal_name, scaling_factor): "wigner_bcc", "wigner_bcc_cubiccell", ] -wigner_crystal = ["wigner_sc"] + scaling_factors = torch.tensor([0.4325, 1.0, 2.0353610], dtype=torch.float64) @@ -273,10 +353,10 @@ def test_wigner(crystal_name, scaling_factor): to numerically slower convergence of the relevant sums. """ # Get parameters defining atomic positions, cell and charges - types, positions, charges, cell, madelung_reference = define_crystal(crystal_name) + types, positions, charges, cell, madelung_ref, num = define_crystal(crystal_name) positions *= scaling_factor cell *= scaling_factor - madelung_reference /= scaling_factor + madelung_ref /= scaling_factor # Due to the slow convergence, we do not use the default values of the smearing, # but provide a range instead. The first value of 0.1 corresponds to what would be @@ -297,5 +377,68 @@ def test_wigner(crystal_name, scaling_factor): EP = EwaldPotential(atomic_smearing=smeareff) potentials = EP.compute(types, positions, cell, charges) energies = potentials * charges - energies_ref = -torch.ones_like(energies) * madelung_reference + energies_ref = -torch.ones_like(energies) * madelung_ref torch.testing.assert_close(energies, energies_ref, atol=0.0, rtol=rtol) + + +orthogonal_transformations = generate_orthogonal_transformations() +scaling_factors = torch.tensor([0.4325, 2.0353610], dtype=dtype) + + +@pytest.mark.parametrize("sr_cutoff", [2.01, 5.5]) +@pytest.mark.parametrize("frame_index", [0, 1, 2]) +@pytest.mark.parametrize("scaling_factor", scaling_factors) +@pytest.mark.parametrize("ortho", orthogonal_transformations) +@pytest.mark.parametrize("calc_name", ["ewald", "pme"]) +def test_random_structure(sr_cutoff, frame_index, scaling_factor, ortho, calc_name): + """ + Check that the potentials obtained from the main code agree with the ones computed + using an external library (GROMACS) for more complicated structures consisting of + 8 atoms placed randomly in cubic cells of varying sizes. + """ + # Get the predefined frames with the + # Coulomb energy and forces computed by GROMACS using PME + # using parameters as defined in the GROMACS manual + # https://manual.gromacs.org/documentation/current/user-guide/mdp-options.html#ewald + # + # coulombtype = PME + # fourierspacing = 0.01 ; 1/nm + # pme_order = 8 + # rcoulomb = 0.3 ; nm + struc_path = "tests/reference_structures/" + frame = read(os.path.join(struc_path, "coulomb_test_frames.xyz"), frame_index) + + # Energies in Gaussian units (without e²/[4 π ɛ_0] prefactor) + energy_target = torch.tensor(frame.info["energy"], dtype=dtype) / scaling_factor + # Forces in Gaussian units per Å + forces_target = ( + torch.tensor(frame.arrays["forces"], dtype=dtype) / scaling_factor**2 + ) + + # Convert into input format suitable for MeshLODE + positions = scaling_factor * (torch.tensor(frame.positions, dtype=dtype) @ ortho) + positions.requires_grad = True + cell = scaling_factor * torch.tensor(np.array(frame.cell), dtype=dtype) @ ortho + charges = torch.tensor([1, 1, 1, 1, -1, -1, -1, -1], dtype=dtype).reshape((-1, 1)) + types = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2]) + + # Compute potential using MeshLODE and compare against reference values + sr_cutoff = scaling_factor * torch.tensor(sr_cutoff, dtype=dtype) + if calc_name == "ewald": + calc = EwaldPotential(sr_cutoff=sr_cutoff) + rtol_e = 2e-5 + rtol_f = 3.6e-3 + elif calc_name == "pme": + calc = MeshEwaldPotential(sr_cutoff=sr_cutoff) + rtol_e = 4.5e-3 # 1.5e-3 + rtol_f = 2.5e-3 # 6e-3 + potentials = calc.compute(types, positions=positions, cell=cell, charges=charges) + + # Compute energy, taking into account the double counting of each pair + energy = torch.sum(potentials * charges) / 2 + torch.testing.assert_close(energy, energy_target, atol=0.0, rtol=rtol_e) + + # Compute forces + energy.backward() + forces = -positions.grad + torch.testing.assert_close(forces, forces_target @ ortho, atol=0.0, rtol=rtol_f) diff --git a/tests/reference_structures/coulomb_test_frames.xyz b/tests/reference_structures/coulomb_test_frames.xyz new file mode 100644 index 00000000..c7b09c0b --- /dev/null +++ b/tests/reference_structures/coulomb_test_frames.xyz @@ -0,0 +1,30 @@ +8 +Lattice="8.459999999999999 0.0 0.0 0.0 8.459999999999999 0.0 0.0 0.0 8.459999999999999" Properties=species:S:1:pos:R:3:forces:R:3 energy=-1.711935801217823 pbc="T T T" +Na 6.41395283 1.91990662 7.67332792 -0.14070982 -0.11474542 -0.04825055 +Na 4.18068695 4.95752287 8.03095055 0.02304201 -0.00337703 -0.05851965 +Na 4.31246758 7.90075493 3.69300699 0.01032462 -0.00711682 0.04352940 +Na 8.23982906 2.97728539 3.90395117 -0.00588548 0.01980903 -0.00262495 +Cl 4.58017778 0.87936962 7.42740345 0.15206258 0.07818233 0.04745008 +Cl 7.16421080 5.12644577 1.27403510 -0.03842140 -0.06048855 0.00404734 +Cl 8.27988529 8.03175545 5.62085533 -0.01563965 0.05616960 -0.02410075 +Cl 4.18750286 3.83302188 5.23034573 0.01522815 0.03156665 0.03846963 +8 +Lattice="8.0 0.0 0.0 0.0 8.0 0.0 0.0 0.0 8.0" Properties=species:S:1:pos:R:3:forces:R:3 energy=-1.3774848936680977 pbc="T T T" +Na 7.49958754 1.50817406 0.04903742 0.18842268 0.22339685 0.26501864 +Na 6.36871147 6.44243002 5.29353428 -0.00565849 -0.17990391 -0.18518641 +Na 1.15779471 2.55996442 3.89204693 0.02542196 0.05150340 -0.13545713 +Na 5.46410370 7.88015556 6.58486128 -0.22116421 0.07723974 0.20727965 +Cl 3.62183785 5.90626240 3.56432629 0.00277152 0.03026245 0.14060318 +Cl 6.16898727 4.03323746 2.27290463 0.06484856 -0.02595718 0.03754553 +Cl 4.37587976 0.11005501 0.82337153 0.11867952 0.02015366 -0.19643391 +Cl 0.70531571 2.49052644 0.94032115 -0.17332105 -0.19669629 -0.13336943 +8 +Lattice="10.0 0.0 0.0 0.0 10.0 0.0 0.0 0.0 10.0" Properties=species:S:1:pos:R:3:forces:R:3 energy=-1.1020011291278247 pbc="T T T" +Na 9.37448406 1.88521755 0.06129678 0.12058989 0.14297387 0.16961185 +Na 7.96088934 8.05303764 6.61691761 -0.00362286 -0.11513921 -0.11852075 +Na 1.44724345 3.19995522 4.86505842 0.01627035 0.03296255 -0.08669289 +Na 6.83012962 9.85019398 8.23107624 -0.14154177 0.04943282 0.13265023 +Cl 4.52729702 7.38282776 4.45540762 0.00177510 0.01936905 0.08998748 +Cl 7.71123362 5.04154682 2.84113073 0.04150165 -0.01661189 0.02402992 +Cl 5.46985054 0.13756876 1.02921438 0.07595183 0.01289930 -0.12570915 +Cl 0.88164455 3.11315799 1.17540145 -0.11092404 -0.12588664 -0.08535721 From 5d60dc5602468a8b2ccf15a2b1a88da865a8c247 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Fri, 5 Jul 2024 16:45:01 +0200 Subject: [PATCH 21/35] use systems neighborlist in metatensor branch --- examples/neighborlist_example.ipynb | 62 ++++++++----- src/meshlode/calculators/meshpotential.py | 2 +- src/meshlode/calculators/pmepotential.py | 14 ++- src/meshlode/metatensor/__init__.py | 4 +- .../{meshewald.py => pmepotential.py} | 86 +++++++------------ 5 files changed, 86 insertions(+), 82 deletions(-) rename src/meshlode/metatensor/{meshewald.py => pmepotential.py} (73%) diff --git a/examples/neighborlist_example.ipynb b/examples/neighborlist_example.ipynb index f5df529e..f4f0f240 100644 --- a/examples/neighborlist_example.ipynb +++ b/examples/neighborlist_example.ipynb @@ -10,7 +10,8 @@ "import torch\n", "import numpy as np\n", "import math\n", - "from metatensor.torch.atomistic import System\n", + "from metatensor.torch.atomistic import System, NeighborListOptions\n", + "from metatensor.torch import Labels, TensorBlock\n", "\n", "from ase import Atoms\n", "from ase.neighborlist import neighbor_list" @@ -61,14 +62,14 @@ "source": [ "sr_cutoff = np.sqrt(3) * 0.8\n", "struc = Atoms(positions=positions, cell=cell, pbc=True)\n", - "atom_is, atom_js, neighbor_shifts = neighbor_list(\"ijS\", struc, sr_cutoff, self_interaction=False)" + "nl_i, nl_j, nl_S, nl_D = neighbor_list(\"ijSD\", struc, sr_cutoff)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Convert neighbor list from ASE to desired format (torch tensor of dtype int)" + "Convert ASE neighbor list into suitable format for a Metatensor system " ] }, { @@ -77,37 +78,56 @@ "metadata": {}, "outputs": [], "source": [ - "atom_is = atom_is.reshape((-1,1))\n", - "atom_js = atom_js.reshape((-1,1))\n", - "neighbor_indices = torch.tensor(np.hstack([atom_is, atom_js]))\n", - "neighbor_shifts = torch.tensor(neighbor_shifts)" + "neighbors = TensorBlock(\n", + " values=torch.from_numpy(nl_D.astype(np.float32).reshape(-1, 3, 1)),\n", + " samples=Labels(\n", + " names=[\n", + " \"first_atom\",\n", + " \"second_atom\",\n", + " \"cell_shift_a\",\n", + " \"cell_shift_b\",\n", + " \"cell_shift_c\",\n", + " ],\n", + " values=torch.from_numpy(np.vstack([nl_i, nl_j, nl_S.T]).T),\n", + " ),\n", + " components=[Labels.range(\"xyz\", 3)],\n", + " properties=Labels.range(\"distance\", 1),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Attach neighbor list to system object" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/khugueni/code/MeshLODE/src/meshlode/calculators/meshewald.py:336: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " positions[j] - positions[i] + torch.tensor(shift @ cell)\n" - ] - } - ], + "outputs": [], "source": [ "system = System(types=types, positions=positions, cell=cell)\n", "\n", - "MP = meshlode.metatensor.MeshEwaldPotential(\n", + "nl_options = NeighborListOptions(cutoff=sr_cutoff, full_list=True)\n", + "system.add_neighbor_list(options=nl_options, neighbors=neighbors)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "MP = meshlode.metatensor.PMEPotential(\n", " atomic_smearing=atomic_smearing,\n", " mesh_spacing=mesh_spacing,\n", " interpolation_order=interpolation_order,\n", " subtract_self=True,\n", " sr_cutoff=sr_cutoff,\n", ")\n", - "potential_metatensor = MP.compute(system, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts)" + "potential_metatensor = MP.compute(system)" ] }, { @@ -119,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": {}, "outputs": [ { diff --git a/src/meshlode/calculators/meshpotential.py b/src/meshlode/calculators/meshpotential.py index 9b997003..2d9bb7c4 100644 --- a/src/meshlode/calculators/meshpotential.py +++ b/src/meshlode/calculators/meshpotential.py @@ -15,7 +15,7 @@ class MeshPotential(CalculatorBase): and long range contribution but calculates the full contribution to the potential on a grid. - For a Particle Mesh Ewald (PME) use :py:class:`meshlode.MeshEwaldPotential`. + For a Particle Mesh Ewald (PME) use :py:class:`meshlode.PMEPotential`. :param atomic_smearing: Width of the atom-centered Gaussian used to create the atomic density. diff --git a/src/meshlode/calculators/pmepotential.py b/src/meshlode/calculators/pmepotential.py index 25735235..10195dc8 100644 --- a/src/meshlode/calculators/pmepotential.py +++ b/src/meshlode/calculators/pmepotential.py @@ -392,10 +392,18 @@ def _compute_sr( """ if neighbor_indices is None or neighbor_shifts is None: # Get list of neighbors - struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) + struc = Atoms( + positions=positions.detach().numpy(), + cell=cell.detach().numpy(), + pbc=True, + ) atom_is, atom_js, neighbor_shifts = neighbor_list( "ijS", struc, sr_cutoff.item(), self_interaction=False ) + + atom_is = torch.from_numpy(atom_is) + atom_js = torch.from_numpy(atom_js) + neighbor_shifts = torch.from_numpy(neighbor_shifts) else: atom_is = neighbor_indices[0] atom_js = neighbor_indices[1] @@ -404,9 +412,7 @@ def _compute_sr( potential = torch.zeros_like(charges) for i, j, shift in zip(atom_is, atom_js, neighbor_shifts): shift = shift.type(cell.dtype) - dist = torch.linalg.norm( - positions[j] - positions[i] + torch.tensor(shift @ cell) - ) + dist = torch.linalg.norm(positions[j] - positions[i] + shift @ cell) # If the contribution from all atoms within the cutoff is to be subtracted # this short-range part will simply use -V_LR as the potential diff --git a/src/meshlode/metatensor/__init__.py b/src/meshlode/metatensor/__init__.py index 8afbae3a..b52ec3ef 100644 --- a/src/meshlode/metatensor/__init__.py +++ b/src/meshlode/metatensor/__init__.py @@ -1,4 +1,4 @@ from .meshpotential import MeshPotential -from .meshewald import MeshEwaldPotential +from .meshewald import PMEPotential -__all__ = ["MeshPotential", "EwaldPotential", "MeshEwaldPotential"] +__all__ = ["MeshPotential", "EwaldPotential", "PMEPotential"] diff --git a/src/meshlode/metatensor/meshewald.py b/src/meshlode/metatensor/pmepotential.py similarity index 73% rename from src/meshlode/metatensor/meshewald.py rename to src/meshlode/metatensor/pmepotential.py index 0d661665..72dd4f6c 100644 --- a/src/meshlode/metatensor/meshewald.py +++ b/src/meshlode/metatensor/pmepotential.py @@ -21,31 +21,17 @@ # mypy: disable-error-code="override" -class MeshEwaldPotential(calculators.PMEPotential): - """An (atomic) type wise long range potential. +class PMEPotential(calculators.PMEPotential): + """Specie-wise long-range potential using a particle mesh-based Ewald (PME). Refer to :class:`meshlode.MeshPotential` for full documentation. """ - def forward( - self, - systems: Union[List[System], System], - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> TensorMap: - """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute( - systems=systems, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) + def forward(self, systems: Union[List[System], System]) -> TensorMap: + """forward just calls :py:meth:`compute()`""" + return self.compute(systems=systems) - def compute( - self, - systems: Union[List[System], System], - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> TensorMap: + def compute(self, systems: Union[List[System], System]) -> TensorMap: """Compute potential for all provided ``systems``. All ``systems`` must have the same ``dtype`` and the same ``device``. If each @@ -69,36 +55,19 @@ def compute( # provided as input (for convenience of users testing out the code) if not isinstance(systems, list): systems = [systems] - if (neighbor_indices is not None) and not isinstance(neighbor_indices, list): - neighbor_indices = [neighbor_indices] - if (neighbor_shifts is not None) and not isinstance(neighbor_shifts, list): - neighbor_shifts = [neighbor_shifts] - - # Check that the lengths of the provided lists agree - if (neighbor_indices is not None) and len(neighbor_indices) != len(systems): - raise ValueError( - f"Numbers of systems (= {len(systems)}) needs to match number of " - f"neighbor lists (= {len(neighbor_indices)})" - ) - if (neighbor_shifts is not None) and len(neighbor_shifts) != len(systems): - raise ValueError( - f"Numbers of systems (= {len(systems)}) needs to match number of " - f"neighbor shifts (= {len(neighbor_shifts)})" - ) - if len(systems) > 1: - for system in systems[1:]: - if system.dtype != systems[0].dtype: - raise ValueError( - "`dtype` of all systems must be the same, got " - f"{system.dtype} and {systems[0].dtype}`" - ) + for system in systems: + if system.dtype != systems[0].dtype: + raise ValueError( + "`dtype` of all systems must be the same, got " + f"{system.dtype} and {systems[0].dtype}`" + ) - if system.device != systems[0].device: - raise ValueError( - "`device` of all systems must be the same, got " - f"{system.device} and {systems[0].device}`" - ) + if system.device != systems[0].device: + raise ValueError( + "`device` of all systems must be the same, got " + f"{system.device} and {systems[0].device}`" + ) dtype = systems[0].positions.dtype device = systems[0].positions.device @@ -146,7 +115,7 @@ def compute( n_blocks = n_types * n_charges_channels feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_blocks)} - for i, system in enumerate(systems): + for system in systems: if use_explicit_charges: charges = system.get_data("charges").values else: @@ -155,9 +124,18 @@ def compute( system.types, requested_types, dtype, device ) - if neighbor_indices is None or neighbor_shifts is None: - # Compute the potentials - # TODO: use neighborlist from system if provided. + # try to extract neighbor list from system object + neighbor_indices = None + for neighbor_list_options in system.known_neighbor_lists(): + if neighbor_list_options.cutoff == self.sr_cutoff: + neighbor_list = system.get_neighbor_list(neighbor_list_options) + + neighbor_indices = neighbor_list.samples.values[:, :2] + neighbor_shifts = neighbor_list.samples.values[:, 2:] + + break + + if neighbor_indices is None: potential = self._compute_single_system( positions=system.positions, cell=system.cell, @@ -170,8 +148,8 @@ def compute( positions=system.positions, charges=charges, cell=system.cell, - neighbor_indices=neighbor_indices[i], - neighbor_shifts=neighbor_shifts[i], + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, ) # Reorder data into metatensor format From e73031ed6c79f2ca30816b6ad95897def37488f4 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Fri, 5 Jul 2024 16:46:23 +0200 Subject: [PATCH 22/35] update notebook --- examples/neighborlist_example.ipynb | 31 ++++++++--------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/examples/neighborlist_example.ipynb b/examples/neighborlist_example.ipynb index f4f0f240..d41b6bdd 100644 --- a/examples/neighborlist_example.ipynb +++ b/examples/neighborlist_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -104,7 +104,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -116,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -139,24 +139,9 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(17) tensor(17) tensor(1.) tensor(-2.7745)\n", - "tensor(17) tensor(55) tensor(-1.) tensor(-0.7391)\n", - "tensor(55) tensor(17) tensor(-1.) tensor(-0.7391)\n", - "tensor(55) tensor(55) tensor(1.) tensor(-2.7745)\n", - "Using the metatensor version\n", - "Computed energies on each atom = [[-2.035360813140869], [-2.035360813140869]]\n", - "Reference Madelung constant = 2.035\n", - "Total energy = -4.071\n" - ] - } - ], + "outputs": [], "source": [ "atomic_energies_metatensor = torch.zeros((n_atoms, 1))\n", "for idx_c, c in enumerate(types):\n", From db6e06599da581baf589fd614cbfac5edcc3949f Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Fri, 5 Jul 2024 17:23:08 +0200 Subject: [PATCH 23/35] convert notebook to propper example --- .readthedocs.yaml | 1 + docs/src/conf.py | 1 + examples/neighborlist_example.ipynb | 204 ---------------------------- examples/neighborlist_example.py | 128 +++++++++++++++++ src/meshlode/metatensor/__init__.py | 4 +- 5 files changed, 132 insertions(+), 206 deletions(-) delete mode 100644 examples/neighborlist_example.ipynb create mode 100644 examples/neighborlist_example.py diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 7d876318..2625fe9f 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -25,4 +25,5 @@ python: - method: pip path: . extra_requirements: + - examples - metatensor diff --git a/docs/src/conf.py b/docs/src/conf.py index c957e320..2e7071cf 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -55,6 +55,7 @@ autodoc_typehints_format = "short" intersphinx_mapping = { + "ase": ("https://wiki.fysik.dtu.dk/ase/", None), "python": ("https://docs.python.org/3", None), "numpy": ("https://numpy.org/doc/stable/", None), "torch": ("https://pytorch.org/docs/stable/", None), diff --git a/examples/neighborlist_example.ipynb b/examples/neighborlist_example.ipynb deleted file mode 100644 index d41b6bdd..00000000 --- a/examples/neighborlist_example.ipynb +++ /dev/null @@ -1,204 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import meshlode\n", - "import torch\n", - "import numpy as np\n", - "import math\n", - "from metatensor.torch.atomistic import System, NeighborListOptions\n", - "from metatensor.torch import Labels, TensorBlock\n", - "\n", - "from ase import Atoms\n", - "from ase.neighborlist import neighbor_list" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Define simple example structure having the CsCl structure and compute the reference\n", - "# values. MeshPotential by default outputs the types sorted according to the atomic\n", - "# number. Thus, we input the compound \"CsCl\" and \"ClCs\" since Cl and Cs have atomic\n", - "# numbers 17 and 55, respectively.\n", - "types = torch.tensor([17, 55]) # Cl and Cs\n", - "positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])\n", - "charges = torch.tensor([-1.0, 1.0])\n", - "cell = torch.eye(3)\n", - "\n", - "# %%\n", - "# Define the expected values of the energy\n", - "n_atoms = len(types)\n", - "madelung = 2 * 1.7626 / math.sqrt(3)\n", - "energies_ref = -madelung * torch.ones((n_atoms, 1))\n", - "\n", - "# %%\n", - "# We first define general parameters for our calculation MeshLODE\n", - "\n", - "atomic_smearing = 0.1\n", - "cell = torch.eye(3)\n", - "mesh_spacing = atomic_smearing / 4\n", - "interpolation_order = 2" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Generate neighbor list using ASE" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sr_cutoff = np.sqrt(3) * 0.8\n", - "struc = Atoms(positions=positions, cell=cell, pbc=True)\n", - "nl_i, nl_j, nl_S, nl_D = neighbor_list(\"ijSD\", struc, sr_cutoff)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Convert ASE neighbor list into suitable format for a Metatensor system " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "neighbors = TensorBlock(\n", - " values=torch.from_numpy(nl_D.astype(np.float32).reshape(-1, 3, 1)),\n", - " samples=Labels(\n", - " names=[\n", - " \"first_atom\",\n", - " \"second_atom\",\n", - " \"cell_shift_a\",\n", - " \"cell_shift_b\",\n", - " \"cell_shift_c\",\n", - " ],\n", - " values=torch.from_numpy(np.vstack([nl_i, nl_j, nl_S.T]).T),\n", - " ),\n", - " components=[Labels.range(\"xyz\", 3)],\n", - " properties=Labels.range(\"distance\", 1),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Attach neighbor list to system object" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "system = System(types=types, positions=positions, cell=cell)\n", - "\n", - "nl_options = NeighborListOptions(cutoff=sr_cutoff, full_list=True)\n", - "system.add_neighbor_list(options=nl_options, neighbors=neighbors)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "MP = meshlode.metatensor.PMEPotential(\n", - " atomic_smearing=atomic_smearing,\n", - " mesh_spacing=mesh_spacing,\n", - " interpolation_order=interpolation_order,\n", - " subtract_self=True,\n", - " sr_cutoff=sr_cutoff,\n", - ")\n", - "potential_metatensor = MP.compute(system)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Convert to Madelung constant and check that the value is correct" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "atomic_energies_metatensor = torch.zeros((n_atoms, 1))\n", - "for idx_c, c in enumerate(types):\n", - " for idx_n, n in enumerate(types):\n", - " # Take the coefficients with the correct center atom and neighbor atom types\n", - " block = potential_metatensor.block(\n", - " {\"center_type\": int(c), \"neighbor_type\": int(n)}\n", - " )\n", - "\n", - " # The coulomb potential between atoms i and j is charge_i * charge_j / d_ij\n", - " # The features are simply computing a pure 1/r potential with no prefactors.\n", - " # Thus, to compute the energy between atoms of types i and j, we need to\n", - " # multiply by the charges of i and j.\n", - " print(c, n, charges[idx_c] * charges[idx_n], block.values[0, 0])\n", - " atomic_energies_metatensor[idx_c] += (\n", - " charges[idx_c] * charges[idx_n] * block.values[0, 0]\n", - " )\n", - "\n", - "# %%\n", - "# The total energy is just the sum of all atomic energies\n", - "total_energy_metatensor = torch.sum(atomic_energies_metatensor)\n", - "\n", - "# %%\n", - "# Compare against reference Madelung constant and reference energy:\n", - "print(\"Using the metatensor version\")\n", - "print(f\"Computed energies on each atom = {atomic_energies_metatensor.tolist()}\")\n", - "print(f\"Reference Madelung constant = {madelung:.3f}\")\n", - "print(f\"Total energy = {total_energy_metatensor:.3f}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/neighborlist_example.py b/examples/neighborlist_example.py new file mode 100644 index 00000000..99e67f55 --- /dev/null +++ b/examples/neighborlist_example.py @@ -0,0 +1,128 @@ +""" +Computations with explicit Neighbor Lists +========================================= + +This example will explain how to use the metatensor branch of Meshlode with an attached +neighborlist to a :py:class:`metatensor.torch.atomistic.System` object. +""" + +# %% + +import math + +import numpy as np +import torch +from ase import Atoms +from ase.neighborlist import neighbor_list +from metatensor.torch import Labels, TensorBlock +from metatensor.torch.atomistic import NeighborListOptions, System + +import meshlode + + +# %% +# Define simple example structure having the CsCl structure and compute the reference +# values. MeshPotential by default outputs the types sorted according to the atomic +# number. Thus, we input the compound "CsCl" and "ClCs" since Cl and Cs have atomic +# numbers 17 and 55, respectively. + +types = torch.tensor([17, 55]) # Cl and Cs +positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) +charges = torch.tensor([-1.0, 1.0]) +cell = torch.eye(3) + +# %% +# Define the expected values of the energy + +n_atoms = len(types) +madelung = 2 * 1.7626 / math.sqrt(3) +energies_ref = -madelung * torch.ones((n_atoms, 1)) + +# %% +# We first define general parameters for our calculation MeshLODE. + +atomic_smearing = 0.1 +cell = torch.eye(3) +mesh_spacing = atomic_smearing / 4 +interpolation_order = 2 + + +# %% +# Generate neighbor list using ASE's :py:func:`neighbor_list() +# ` function. + +sr_cutoff = np.sqrt(3) * 0.8 +struc = Atoms(positions=positions, cell=cell, pbc=True) +nl_i, nl_j, nl_S, nl_D = neighbor_list("ijSD", struc, sr_cutoff) + + +# %% +# Convert ASE neighbor list into suitable format for a Metatensor system. + +neighbors = TensorBlock( + values=torch.from_numpy(nl_D.astype(np.float32).reshape(-1, 3, 1)), + samples=Labels( + names=[ + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + values=torch.from_numpy(np.vstack([nl_i, nl_j, nl_S.T]).T), + ), + components=[Labels.range("xyz", 3)], + properties=Labels.range("distance", 1), +) + + +# %% +# Attach ``neighbors`` to ``system`` object. + +system = System(types=types, positions=positions, cell=cell) + +nl_options = NeighborListOptions(cutoff=sr_cutoff, full_list=True) +system.add_neighbor_list(options=nl_options, neighbors=neighbors) + +MP = meshlode.metatensor.PMEPotential( + atomic_smearing=atomic_smearing, + mesh_spacing=mesh_spacing, + interpolation_order=interpolation_order, + subtract_self=True, + sr_cutoff=sr_cutoff, +) +potential_metatensor = MP.compute(system) + + +# %% +# Convert to Madelung constant and check that the value is correct + +atomic_energies_metatensor = torch.zeros((n_atoms, 1)) +for idx_c, c in enumerate(types): + for idx_n, n in enumerate(types): + # Take the coefficients with the correct center atom and neighbor atom types + block = potential_metatensor.block( + {"center_type": int(c), "neighbor_type": int(n)} + ) + + # The coulomb potential between atoms i and j is charge_i * charge_j / d_ij + # The features are simply computing a pure 1/r potential with no prefactors. + # Thus, to compute the energy between atoms of types i and j, we need to + # multiply by the charges of i and j. + print(c, n, charges[idx_c] * charges[idx_n], block.values[0, 0]) + atomic_energies_metatensor[idx_c] += ( + charges[idx_c] * charges[idx_n] * block.values[0, 0] + ) + +# %% +# The total energy is just the sum of all atomic energies + +total_energy_metatensor = torch.sum(atomic_energies_metatensor) + +# %% +# Compare against reference Madelung constant and reference energy: + +print("Using the metatensor version") +print(f"Computed energies on each atom = {atomic_energies_metatensor.tolist()}") +print(f"Reference Madelung constant = {madelung:.3f}") +print(f"Total energy = {total_energy_metatensor:.3f}") diff --git a/src/meshlode/metatensor/__init__.py b/src/meshlode/metatensor/__init__.py index b52ec3ef..d1230883 100644 --- a/src/meshlode/metatensor/__init__.py +++ b/src/meshlode/metatensor/__init__.py @@ -1,4 +1,4 @@ from .meshpotential import MeshPotential -from .meshewald import PMEPotential +from .pmepotential import PMEPotential -__all__ = ["MeshPotential", "EwaldPotential", "PMEPotential"] +__all__ = ["MeshPotential", "PMEPotential"] From dee626872efef73ad8b9fdb40b0fa9820cb813b9 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Fri, 5 Jul 2024 17:21:12 +0200 Subject: [PATCH 24/35] Remove outdated meshpotential calculator --- src/meshlode/__init__.py | 1 - src/meshlode/calculators/__init__.py | 1 - src/meshlode/calculators/meshpotential.py | 208 ---------- src/meshlode/metatensor/__init__.py | 3 +- src/meshlode/metatensor/meshewald.py | 5 +- src/meshlode/metatensor/meshpotential.py | 238 ----------- .../calculators/test_calculators_workflow.py | 15 +- tests/metatensor/test_madelung.py | 8 +- .../test_metatensor_meshpotential.py | 376 ------------------ 9 files changed, 8 insertions(+), 847 deletions(-) delete mode 100644 src/meshlode/calculators/meshpotential.py delete mode 100644 src/meshlode/metatensor/meshpotential.py delete mode 100644 tests/metatensor/test_metatensor_meshpotential.py diff --git a/src/meshlode/__init__.py b/src/meshlode/__init__.py index f454810d..96cea434 100644 --- a/src/meshlode/__init__.py +++ b/src/meshlode/__init__.py @@ -1,4 +1,3 @@ -from .calculators.meshpotential import MeshPotential from .calculators.ewaldpotential import EwaldPotential from .calculators.directpotential import DirectPotential from .calculators.pmepotential import PMEPotential diff --git a/src/meshlode/calculators/__init__.py b/src/meshlode/calculators/__init__.py index 13a6d857..91dfbb23 100644 --- a/src/meshlode/calculators/__init__.py +++ b/src/meshlode/calculators/__init__.py @@ -1,4 +1,3 @@ -from .meshpotential import MeshPotential from .ewaldpotential import EwaldPotential from .directpotential import DirectPotential from .pmepotential import PMEPotential diff --git a/src/meshlode/calculators/meshpotential.py b/src/meshlode/calculators/meshpotential.py deleted file mode 100644 index 9b997003..00000000 --- a/src/meshlode/calculators/meshpotential.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import List, Optional, Union - -import torch - -from ..lib.fourier_convolution import FourierSpaceConvolution -from ..lib.mesh_interpolator import MeshInterpolator -from .base import CalculatorBase - - -class MeshPotential(CalculatorBase): - r"""Specie-wise long-range potential, computed on a grid. - - Method scaling as :math:`\mathcal{O}(NlogN)` with respect to the number of particles - :math:`N`. This class does not perform a usual Ewald style splitting into a short - and long range contribution but calculates the full contribution to the potential on - a grid. - - For a Particle Mesh Ewald (PME) use :py:class:`meshlode.MeshEwaldPotential`. - - :param atomic_smearing: Width of the atom-centered Gaussian used to create the - atomic density. - :param all_types: Optional global list of all atomic types that should be considered - for the computation. This option might be useful when running the calculation on - subset of a whole dataset and it required to keep the shape of the output - consistent. If this is not set the possible atomic types will be determined when - calling the :meth:`compute()`. - :param exponent: the exponent "p" in 1/r^p potentials - :param mesh_spacing: Value that determines the umber of Fourier-space grid points - that will be used along each axis. If set to None, it will automatically be set - to half of ``atomic_smearing``. - :param interpolation_order: Interpolation order for mapping onto the grid, where an - interpolation order of p corresponds to interpolation by a polynomial of degree - ``p - 1`` (e.g. ``p = 4`` for cubic interpolation). - :param subtract_self: If set to :py:obj:`True`, subtract from the features of an - atom the contributions to the potential arising from that atom itself (but not - the periodic images). - - Example - ------- - >>> import torch - >>> from meshlode import MeshPotential - - Define simple example structure having the CsCl (Cesium Chloride) structure - - >>> types = torch.tensor([55, 17]) # Cs and Cl - >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - >>> cell = torch.eye(3) - - Compute features - - >>> MP = MeshPotential(atomic_smearing=0.2, mesh_spacing=0.1, interpolation_order=4) - >>> MP.compute(types=types, positions=positions, cell=cell) - tensor([[-0.5467, 1.3755], - [ 1.3755, -0.5467]]) - """ - - def __init__( - self, - atomic_smearing: float, - all_types: Optional[List[int]] = None, - exponent: float = 1.0, - mesh_spacing: Optional[float] = None, - interpolation_order: Optional[int] = 4, - subtract_self: Optional[bool] = False, - ): - super().__init__(all_types=all_types, exponent=exponent) - - # Check that all provided values are correct - if interpolation_order not in [1, 2, 3, 4, 5]: - raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") - - # If no explicit mesh_spacing is given, set it such that it can resolve - # the smeared potentials. - if atomic_smearing <= 0: - raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") - - self.atomic_smearing = atomic_smearing - self.mesh_spacing = mesh_spacing - self.interpolation_order = interpolation_order - self.subtract_self = subtract_self - - # Initilize auxiliary objects - self.fourier_space_convolution = FourierSpaceConvolution() - - def compute( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute potential for all provided "systems" stacked inside list. - - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. - - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param cell: single or 2D tensor of shape (3, 3), describing the bounding - box/unit cell of the system. Each row should be one of the bounding box - vector; and columns should contain the x, y, and z components of these - vectors (i.e. the cell should be given in row-major order). - :param charges: Optional single or list of 2D tensor of shape (len(types), n), - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). - """ - - return self._compute_impl( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=None, - neighbor_shifts=None, - ) - - # This function is kept to keep MeshLODE compatible with the broader pytorch - # infrastructure, which require a "forward" function. We name this function - # "compute" instead, for compatibility with other COSMO software. - def forward( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Forward just calls :py:meth:`compute`.""" - return self.compute( - types=types, - positions=positions, - cell=cell, - charges=charges, - ) - - def _compute_single_system( - self, - positions: torch.Tensor, - cell: Union[None, torch.Tensor], - charges: torch.Tensor, - neighbor_indices: Union[None, torch.Tensor], - neighbor_shifts: Union[None, torch.Tensor], - ) -> torch.Tensor: - - if self.mesh_spacing is None: - mesh_spacing = self.atomic_smearing / 2 - else: - mesh_spacing = self.mesh_spacing - - # Initializations - k_cutoff = 2 * torch.pi / mesh_spacing - - # Compute number of times each basis vector of the - # reciprocal space can be scaled until the cutoff - # is reached - basis_norms = torch.linalg.norm(cell, dim=1) - ns_approx = k_cutoff * basis_norms / 2 / torch.pi - ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points - ns = 2 ** torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz] - - # Step 1: Smear particles onto mesh - MI = MeshInterpolator(cell, ns, interpolation_order=self.interpolation_order) - MI.compute_interpolation_weights(positions) - rho_mesh = MI.points_to_mesh(particle_weights=charges) - - # Step 2: Perform Fourier space convolution (FSC) - potential_mesh = self.fourier_space_convolution.compute( - mesh_values=rho_mesh, - cell=cell, - potential_exponent=1, - atomic_smearing=self.atomic_smearing, - ) - - # Step 3: Back interpolation - interpolated_potential = MI.mesh_to_points(potential_mesh) - - # Remove self contribution - if self.subtract_self: - self_contrib = ( - torch.sqrt( - torch.tensor( - 2.0 / torch.pi, dtype=positions.dtype, device=positions.device - ), - ) - / self.atomic_smearing - ) - interpolated_potential -= charges * self_contrib - - return interpolated_potential diff --git a/src/meshlode/metatensor/__init__.py b/src/meshlode/metatensor/__init__.py index 8afbae3a..b6fc9713 100644 --- a/src/meshlode/metatensor/__init__.py +++ b/src/meshlode/metatensor/__init__.py @@ -1,4 +1,3 @@ -from .meshpotential import MeshPotential -from .meshewald import MeshEwaldPotential +from .meshewald import PMEPotential __all__ = ["MeshPotential", "EwaldPotential", "MeshEwaldPotential"] diff --git a/src/meshlode/metatensor/meshewald.py b/src/meshlode/metatensor/meshewald.py index adc055ea..b6ed8b6e 100644 --- a/src/meshlode/metatensor/meshewald.py +++ b/src/meshlode/metatensor/meshewald.py @@ -21,7 +21,7 @@ # mypy: disable-error-code="override" -class MeshEwaldPotential(calculators.PMEPotential): +class PMEPotential(calculators.PMEPotential): """An (atomic) type wise long range potential. Refer to :class:`meshlode.MeshPotential` for full documentation. @@ -75,9 +75,6 @@ def compute( neighbor_shifts = [neighbor_shifts] # Check that the lengths of the provided lists agree - n_sys = len(systems) - n_shif = len(neighbor_shifts) - n_ind = len(neighbor_indices) if (neighbor_indices is not None) and len(neighbor_indices) != len(systems): raise ValueError( f"Numbers of systems (= {len(systems)}) needs to match number of " diff --git a/src/meshlode/metatensor/meshpotential.py b/src/meshlode/metatensor/meshpotential.py deleted file mode 100644 index 990df338..00000000 --- a/src/meshlode/metatensor/meshpotential.py +++ /dev/null @@ -1,238 +0,0 @@ -from typing import Dict, List, Union - -import torch - - -try: - from metatensor.torch import Labels, TensorBlock, TensorMap - from metatensor.torch.atomistic import System -except ImportError: - raise ImportError( - "metatensor.torch is required for meshlode.metatensor but is not installed. " - "Try installing it with:\npip install metatensor[torch]" - ) - - -from .. import calculators - - -# We are breaking the Liskov substitution principle here by changing the signature of -# "compute" compated to the supertype of "MeshPotential". -# mypy: disable-error-code="override" - - -class MeshPotential(calculators.MeshPotential): - """An (atomic) type wise long range potential. - - Refer to :class:`meshlode.MeshPotential` for full documentation. - - Example - ------- - >>> import torch - >>> from metatensor.torch.atomistic import System - >>> from meshlode.metatensor import MeshPotential - - Define simple example structure having the CsCl (Cesium Chloride) structure - - >>> types = torch.tensor([55, 17]) # Cs and Cl - >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - >>> cell = torch.eye(3) - >>> system = System(types=types, positions=positions, cell=cell) - - Compute features - - >>> MP = MeshPotential(atomic_smearing=0.2, mesh_spacing=0.1, interpolation_order=4) - >>> features = MP.compute(system) - - All (atomic) type combinations - - >>> features.keys - Labels( - center_type neighbor_type - 17 17 - 17 55 - 55 17 - 55 55 - ) - - The Cl-potential at the position of the Cl atom - - >>> block_ClCl = features.block({"center_type": 17, "neighbor_type": 17}) - >>> block_ClCl.values - tensor([[1.3755]]) - """ - - def forward( - self, - systems: Union[List[System], System], - ) -> TensorMap: - """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute(systems=systems) - - def compute( - self, - systems: Union[List[System], System], - ) -> TensorMap: - """Compute potential for all provided ``systems``. - - All ``systems`` must have the same ``dtype`` and the same ``device``. If each - system contains a custom data field `charges` the potential will be calculated - for each "charges-channel". The number of `charges-channels` must be same in all - ``systems``. If no "explicit" charges are set the potential will be calculated - for each "types-channels". - - Refer to :meth:`meshlode.MeshPotential.compute()` for additional details on how - "charges-channel" and "types-channels" are computed. - - :param systems: single System or list of - :py:class:`metatensor.torch.atomisic.System` on which to run the - calculations. - - :return: TensorMap containing the potential of all types. The keys of the - TensorMap are "center_type" and "neighbor_type" if no charges are asociated - with - """ - # Make sure that the compute function also works if only a single frame is - # provided as input (for convenience of users testing out the code) - if not isinstance(systems, list): - systems = [systems] - - if len(systems) > 1: - for system in systems[1:]: - if system.dtype != systems[0].dtype: - raise ValueError( - "`dtype` of all systems must be the same, got " - f"{system.dtype} and {systems[0].dtype}`" - ) - - if system.device != systems[0].device: - raise ValueError( - "`device` of all systems must be the same, got " - f"{system.device} and {systems[0].device}`" - ) - - dtype = systems[0].positions.dtype - device = systems[0].positions.device - - requested_types = self._get_requested_types( - [system.types for system in systems] - ) - n_types = len(requested_types) - - has_charges = torch.tensor(["charges" in s.known_data() for s in systems]) - all_charges = torch.all(has_charges) - any_charges = torch.any(has_charges) - - if any_charges and not all_charges: - raise ValueError("`systems` do not consistently contain `charges` data") - if all_charges: - use_explicit_charges = True - n_charges_channels = systems[0].get_data("charges").values.shape[1] - spec_channels = list(range(n_charges_channels)) - key_names = ["center_type", "charges_channel"] - - for i_system, system in enumerate(systems): - n_channels = system.get_data("charges").values.shape[1] - if n_channels != n_charges_channels: - raise ValueError( - f"number of charges-channels in system index {i_system} " - f"({n_channels}) is inconsistent with first system " - f"({n_charges_channels})" - ) - else: - # Use one hot encoded type channel per species for charges channel - use_explicit_charges = False - n_charges_channels = n_types - spec_channels = requested_types - key_names = ["center_type", "neighbor_type"] - - # Initialize dictionary for TensorBlock storage. - # - # If `use_explicit_charges=False`, the blocks are sorted according to the - # (integer) center_type and neighbor_type. Blocks are assigned the array indices - # 0, 1, 2,... Example: for H2O: `H` is mapped to `0` and `O` is mapped to `1`. - # - # For `use_explicit_charges=True` the blocks are stored according to the - # center_type and charge_channel - n_blocks = n_types * n_charges_channels - feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_blocks)} - - for system in systems: - if use_explicit_charges: - charges = system.get_data("charges").values - else: - # One-hot encoding of charge information - charges = self._one_hot_charges( - system.types, requested_types, dtype, device - ) - - # Compute the potentials - # TODO: use neighborlist from system if provided. - potential = self._compute_single_system( - positions=system.positions, - cell=system.cell, - charges=charges, - neighbor_indices=None, - neighbor_shifts=None, - ) - - # Reorder data into metatensor format - for spec_center, at_num_center in enumerate(requested_types): - for spec_channel in range(len(spec_channels)): - a_pair = spec_center * n_charges_channels + spec_channel - feat_dic[a_pair] += [ - potential[system.types == at_num_center, spec_channel] - ] - - # Assemble all computed potential values into TensorBlocks for each combination - # of center_type and neighbor_type/charge_channel - blocks: List[TensorBlock] = [] - for keys, values in feat_dic.items(): - spec_center = requested_types[keys // n_charges_channels] - - # Generate the Labels objects for the samples and properties of the - # TensorBlock. - values_samples: List[List[int]] = [] - for i_frame, system in enumerate(systems): - for i_atom in range(len(system)): - if system.types[i_atom] == spec_center: - values_samples.append([i_frame, i_atom]) - - samples_vals_tensor = torch.tensor( - values_samples, dtype=torch.int32, device=device - ) - - # If no atoms are found that match the types pair `samples_vals_tensor` - # will be empty. We have to reshape the empty tensor to be a valid input for - # `Labels`. - if len(samples_vals_tensor) == 0: - samples_vals_tensor = samples_vals_tensor.reshape(-1, 2) - - labels_samples = Labels(["system", "atom"], samples_vals_tensor) - labels_properties = Labels( - ["potential"], torch.tensor([[0]], device=device) - ) - - block = TensorBlock( - samples=labels_samples, - components=[], - properties=labels_properties, - values=torch.hstack(values).reshape((-1, 1)), - ) - - blocks.append(block) - - assert len(blocks) == n_blocks - - # Generate TensorMap from TensorBlocks by defining suitable keys - key_values: List[torch.Tensor] = [] - for spec_center in requested_types: - for spec_channel in spec_channels: - key_values.append( - torch.tensor([spec_center, spec_channel], device=device) - ) - key_values = torch.vstack(key_values) - - labels_keys = Labels(key_names, key_values) - - return TensorMap(keys=labels_keys, blocks=blocks) diff --git a/tests/calculators/test_calculators_workflow.py b/tests/calculators/test_calculators_workflow.py index 50562914..770d0901 100644 --- a/tests/calculators/test_calculators_workflow.py +++ b/tests/calculators/test_calculators_workflow.py @@ -7,7 +7,7 @@ import torch from torch.testing import assert_close -from meshlode import DirectPotential, EwaldPotential, MeshPotential, PMEPotential +from meshlode import DirectPotential, EwaldPotential, PMEPotential MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) @@ -44,20 +44,9 @@ }, True, ), - ( - MeshPotential, - { - "atomic_smearing": ATOMIC_SMEARING, - "mesh_spacing": MESH_SPACING, - "interpolation_order": INTERPOLATION_ORDER, - "subtract_self": SUBTRACT_SELF, - }, - True, - ), ], ) class TestWorkflow: - def cscl_system(self, periodic): """CsCl crystal. Same as in the madelung test""" types = torch.tensor([55, 17]) @@ -95,7 +84,7 @@ def test_atomic_smearing_error(self, CalculatorClass, params, periodic): CalculatorClass(atomic_smearing=-1.0) def test_interpolation_order_error(self, CalculatorClass, params, periodic): - if type(CalculatorClass) in [PMEPotential, MeshPotential]: + if type(CalculatorClass) in [PMEPotential]: match = "Only `interpolation_order` from 1 to 5" with pytest.raises(ValueError, match=match): CalculatorClass(atomic_smearing=1, interpolation_order=10) diff --git a/tests/metatensor/test_madelung.py b/tests/metatensor/test_madelung.py index ef8b27bc..cb855f68 100644 --- a/tests/metatensor/test_madelung.py +++ b/tests/metatensor/test_madelung.py @@ -13,7 +13,7 @@ class TestMadelung: """ - Test features computed in MeshPotential correspond to the "electrostatic" potential + Test features computed in PMEPotential correspond to the "electrostatic" potential of the structures. We thus compare the computed potential against the known exact values for some simple crystal structures. """ @@ -121,7 +121,7 @@ def test_madelung_low_order( madelung = dic["madelung"] / scaling_factor mesh_spacing = atomic_smearing / 2 * scaling_factor smearing_eff = atomic_smearing * scaling_factor - MP = meshlode_metatensor.MeshPotential( + MP = meshlode_metatensor.PMEPotential( atomic_smearing=smearing_eff, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, @@ -164,7 +164,7 @@ def test_madelung_high_order( madelung = dic["madelung"] / scaling_factor mesh_spacing = atomic_smearing / 10 * scaling_factor smearing_eff = atomic_smearing * scaling_factor - MP = meshlode_metatensor.MeshPotential( + MP = meshlode_metatensor.PMEPotential( atomic_smearing=smearing_eff, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, @@ -207,7 +207,7 @@ def test_madelung_low_order_metatensor( smearing_eff = atomic_smearing * scaling_factor n_atoms = len(positions) system = mts_atomistic.System(types=types, positions=positions, cell=cell) - MP = meshlode_metatensor.MeshPotential( + MP = meshlode_metatensor.PMEPotential( atomic_smearing=smearing_eff, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, diff --git a/tests/metatensor/test_metatensor_meshpotential.py b/tests/metatensor/test_metatensor_meshpotential.py deleted file mode 100644 index 98a53e54..00000000 --- a/tests/metatensor/test_metatensor_meshpotential.py +++ /dev/null @@ -1,376 +0,0 @@ -from typing import List - -import pytest -import torch -from packaging import version - - -metatensor_torch = pytest.importorskip("metatensor.torch") -meshlode_metatensor = pytest.importorskip("meshlode.metatensor") - - -# Define toy system consisting of a single structure for testing -def toy_system_single_frame( - dtype=None, device=None -) -> metatensor_torch.atomistic.System: - return metatensor_torch.atomistic.System( - types=torch.tensor([1, 1, 8, 8], device=device), - positions=torch.tensor( - [[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]], - dtype=dtype, - device=device, - ), - cell=torch.tensor( - [[10.0, 0, 0], [0, 10, 0], [0, 0, 10]], - dtype=dtype, - device=device, - ), - ) - - -def toy_system_single_frame_charges(): - system = toy_system_single_frame() - - # Create system with "hand" written one hot charges - charges = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]) - - # create a metatensor.TensorBlock wich and to add it to the system - samples = metatensor_torch.Labels("atom", torch.arange(len(system)).reshape(-1, 1)) - properties = metatensor_torch.Labels( - "charge", torch.arange(charges.shape[1]).reshape(-1, 1) - ) - - charges_block = metatensor_torch.TensorBlock( - samples=samples, - components=[], - properties=properties, - values=charges, - ) - - system.add_data("charges", charges_block) - - return system - - -def toy_system_single_frame_charges_arbitrary_charges(): - system = toy_system_single_frame() - - # Create system with "hand" written random charges with 4 samples and 5 channels - charges = torch.rand(4, 5) - - # create a metatensor.TensorBlock wich and to add it to the system - samples = metatensor_torch.Labels("atom", torch.arange(len(system)).reshape(-1, 1)) - properties = metatensor_torch.Labels( - "charge", torch.arange(charges.shape[1]).reshape(-1, 1) - ) - - charges_block = metatensor_torch.TensorBlock( - samples=samples, - components=[], - properties=properties, - values=charges, - ) - - system.add_data("charges", charges_block) - - return system - - -# Initialize the calculators. For now, only the meshlode_metatensor.MeshPotential is -# implemented. -def descriptor() -> meshlode_metatensor.MeshPotential: - return meshlode_metatensor.MeshPotential( - atomic_smearing=1.0, - ) - - -def test_forward(): - mp = descriptor() - descriptor_compute = mp.compute(toy_system_single_frame()) - descriptor_forward = mp.forward(toy_system_single_frame()) - - metatensor_torch.equal_raise(descriptor_forward, descriptor_compute) - - -# Test correct filling of zero and empty blocks when setting global atomic numbers -def test_all_types(): - all_types = [9, 1, 8] - descriptor = meshlode_metatensor.MeshPotential( - atomic_smearing=1, all_types=all_types - ) - values = descriptor.compute(toy_system_single_frame()) - - for n in all_types: - assert len(values.block({"center_type": 9, "neighbor_type": n}).values) == 0 - - for n in [1, 8]: - assert torch.equal( - values.block({"center_type": n, "neighbor_type": 9}).values, - torch.tensor([[0], [0]]), - ) - - -def test_dtype_device(): - """Test that the output dtype and device are the same as the input.""" - device = "cpu" - dtype = torch.float64 - - mp = descriptor() - potential = mp.compute(toy_system_single_frame(dtype=torch.float64, device=device)) - - assert potential[0].values.dtype == dtype - assert potential[0].values.device.type == device - - -def test_wrong_dtype_between_systems(): - match = "`dtype` of all systems must be the same, got 7 and 6" - with pytest.raises(ValueError, match=match): - descriptor().compute( - [ - toy_system_single_frame(dtype=torch.float32), - toy_system_single_frame(dtype=torch.float64), - ] - ) - - -def test_wrong_device_between_systems(): - match = "`device` of all systems must be the same, got meta and cpu" - with pytest.raises(ValueError, match=match): - descriptor().compute( - [ - toy_system_single_frame(device="cpu"), - toy_system_single_frame(device="meta"), - ] - ) - - -def test_explicit_charges(): - mp = descriptor() - potential = mp.compute(toy_system_single_frame()) - potential_charges = mp.compute(toy_system_single_frame_charges()) - - # Test metatdata - assert potential_charges.keys.names == ["center_type", "charges_channel"] - assert torch.all( - potential_charges.keys.values == torch.tensor([[1, 0], [1, 1], [8, 0], [8, 1]]) - ) - - # Test values - for block, block_charges in zip(potential, potential_charges): - assert block_charges.samples == block.samples - assert block_charges.components == block.components - assert block_charges.properties == block.properties - assert torch.all(block_charges.values == block.values) - - -def test_explicit_arbitrarycharges(): - mp = descriptor() - potential_charges = mp.compute(toy_system_single_frame_charges_arbitrary_charges()) - - # Test metatdata - assert potential_charges.keys.names == ["center_type", "charges_channel"] - assert torch.all( - potential_charges.keys.values - == torch.tensor( - [ - [1, 0], - [1, 1], - [1, 2], - [1, 3], - [1, 4], - [8, 0], - [8, 1], - [8, 2], - [8, 3], - [8, 4], - ] - ) - ) - - -def test_error_raise_charges_no_charges(): - systems = [toy_system_single_frame(), toy_system_single_frame_charges()] - match = "`systems` do not consistently contain `charges` data" - - with pytest.raises(ValueError, match=match): - descriptor().compute(systems) - - -def test_error_raise_charge_shape(): - system = toy_system_single_frame() - - # Create system with "hand" written one hot charges - charges = torch.tensor( - [[1.0, 0.0, 2.0], [1.0, 0.0, 2.0], [0.0, 1.0, 2.0], [0.0, 1.0, 2.0]] - ) - - # create a metatensor.TensorBlock wich and to add it to the system - samples = metatensor_torch.Labels( - "atom", torch.arange(charges.shape[0]).reshape(-1, 1) - ) - properties = metatensor_torch.Labels( - "charge", torch.arange(charges.shape[1]).reshape(-1, 1) - ) - - charges_block = metatensor_torch.TensorBlock( - samples=samples, - components=[], - properties=properties, - values=charges, - ) - - system.add_data("charges", charges_block) - - systems = [system, toy_system_single_frame_charges()] - - match = ( - r"number of charges-channels in system index 1 \(2\) is inconsistent with " - r"first system \(3\)" - ) - - with pytest.raises(ValueError, match=match): - descriptor().compute(systems) - - -# Make sure that the calculators are computing the features without raising errors, -# and returns the correct output format (TensorMap) -def check_operation(calculator): - descriptor = calculator.compute(toy_system_single_frame()) - assert isinstance(descriptor, torch.ScriptObject) - if version.parse(torch.__version__) >= version.parse("2.1"): - assert descriptor._type().name() == "TensorMap" - - -# Run the above test as a normal python script -def test_operation_as_python(): - check_operation(descriptor()) - - -# Similar to the above, but also testing that the code can be compiled as a torch script -# def test_operation_as_torch_script(): -# scripted = torch.jit.script(descriptor()) -# check_operation(scripted) - - -# Define a more complex toy system consisting of multiple frames, mixing three types. -def toy_system_2() -> List[metatensor_torch.atomistic.System]: - # First few frames containing Nitrogen - L = 2.0 - frames = [] - frames.append( - metatensor_torch.atomistic.System( - types=torch.tensor([7]), - positions=torch.zeros((1, 3)), - cell=L * 2 * torch.eye(3), - ) - ) - frames.append( - metatensor_torch.atomistic.System( - types=torch.tensor([7, 7]), - positions=torch.zeros((2, 3)), - cell=L * 2 * torch.eye(3), - ) - ) - frames.append( - metatensor_torch.atomistic.System( - types=torch.tensor([7, 7, 7]), - positions=torch.zeros((3, 3)), - cell=L * 2 * torch.eye(3), - ) - ) - - # One more frame containing Na and Cl - positions = torch.tensor([[0, 0, 0], [1.0, 0, 0]]) - cell = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]]) - frames.append( - metatensor_torch.atomistic.System( - types=torch.tensor([11, 17]), positions=positions, cell=cell - ) - ) - - return frames - - -class TestMultiFrameToySystem: - # Compute TensorMap containing features for various hyperparameters, including more - # extreme values. - tensormaps_list = [] - frames = toy_system_2() - for atomic_smearing in [0.01, 0.3, 3.7]: - for mesh_spacing in [15.3, 0.19]: - for interpolation_order in [1, 2, 3, 4, 5]: - MP = meshlode_metatensor.MeshPotential( - atomic_smearing=atomic_smearing, - mesh_spacing=mesh_spacing, - interpolation_order=interpolation_order, - subtract_self=False, - ) - tensormaps_list.append(MP.compute(frames)) - - @pytest.mark.parametrize("features", tensormaps_list) - def test_tensormap_labels(self, features): - # Test that the keys of the TensorMap for the toy system are correct - label_values = torch.tensor( - [ - [7, 7], - [7, 11], - [7, 17], - [11, 7], - [11, 11], - [11, 17], - [17, 7], - [17, 11], - [17, 17], - ] - ) - label_names = ["center_type", "neighbor_type"] - labels_ref = metatensor_torch.Labels(names=label_names, values=label_values) - - assert labels_ref == features.keys - - @pytest.mark.parametrize("features", tensormaps_list) - def test_zero_blocks(self, features): - # Since the first 3 frames contain Nitrogen only, while the last frame - # only contains Na and Cl, the features should be zero - for i in [11, 17]: - # For structures in which Nitrogen is present, there will be no Na or Cl - # neighbors. There are six such center atoms in total. - block = features.block({"center_type": 7, "neighbor_type": i}) - assert torch.equal(block.values, torch.zeros((6, 1))) - - # For structures in which Na or Cl are present, there will be no Nitrogen - # neighbors. - block = features.block({"center_type": i, "neighbor_type": 7}) - assert torch.equal(block.values, torch.zeros((1, 1))) - - @pytest.mark.parametrize("features", tensormaps_list) - def test_nitrogen_blocks(self, features): - # For this toy data set: - # - the first frame contains a single atom at the origin - # - the second frame contains two atoms at the origin - # - the third frame contains three atoms at the origin - # Thus, the features should almost be identical, up to a global factor - # that is the number of atoms (that are exactly on the same position). - block = features.block({"center_type": 7, "neighbor_type": 7}) - values = block.values[:, 0] # flatten to 1d - values_ref = torch.tensor([1.0, 2, 2, 3, 3, 3]) - - # We use a slightly higher relative tolerance due to numerical errors - torch.testing.assert_close(values / values[0], values_ref, rtol=1e-6, atol=0.0) - - @pytest.mark.parametrize("features", tensormaps_list) - def test_nacl_blocks(self, features): - # In the NaCl structure, swapping the positions of all Na and Cl atoms leads to - # an equivalent structure (up to global translation). This leads to symmetry - # in the features: the Na-density around Cl is the same as the Cl-density around - # Na and so on. - block_nana = features.block({"center_type": 11, "neighbor_type": 11}) - block_nacl = features.block({"center_type": 11, "neighbor_type": 17}) - block_clna = features.block({"center_type": 17, "neighbor_type": 11}) - block_clcl = features.block({"center_type": 17, "neighbor_type": 17}) - torch.testing.assert_close( - block_nacl.values, block_clna.values, rtol=1e-15, atol=0.0 - ) - torch.testing.assert_close( - block_nana.values, block_clcl.values, rtol=1e-15, atol=0.0 - ) From 12eb31824fb2f0af200b9829dec8d5e619f7e266 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Fri, 5 Jul 2024 17:47:03 +0200 Subject: [PATCH 25/35] Vectorize + move SR calc to base class --- src/meshlode/calculators/base.py | 89 ++++++++++- src/meshlode/calculators/ewaldpotential.py | 162 ++++++--------------- src/meshlode/calculators/pmepotential.py | 139 +++--------------- src/meshlode/lib/potentials.py | 4 +- tests/calculators/test_values_periodic.py | 16 +- tests/lib/test_potentials.py | 4 +- 6 files changed, 166 insertions(+), 248 deletions(-) diff --git a/src/meshlode/calculators/base.py b/src/meshlode/calculators/base.py index 5295007b..a71b4c84 100644 --- a/src/meshlode/calculators/base.py +++ b/src/meshlode/calculators/base.py @@ -1,6 +1,8 @@ from typing import List, Optional, Tuple, Union import torch +from ase import Atoms +from ase.neighborlist import neighbor_list from meshlode.lib import InversePowerLawPotential @@ -264,10 +266,15 @@ def _compute_impl( neighbor_indices: Union[None, List[torch.Tensor], torch.Tensor], neighbor_shifts: Union[None, List[torch.Tensor], torch.Tensor], ) -> Union[torch.Tensor, List[torch.Tensor]]: - types, positions, cell, charges, neighbor_indices, neighbor_shifts = ( - self._validate_compute_parameters( - types, positions, cell, charges, neighbor_indices, neighbor_shifts - ) + ( + types, + positions, + cell, + charges, + neighbor_indices, + neighbor_shifts, + ) = self._validate_compute_parameters( + types, positions, cell, charges, neighbor_indices, neighbor_shifts ) potentials = [] @@ -303,3 +310,77 @@ def _compute_single_system( neighbor_shifts: Union[None, torch.Tensor], ) -> torch.Tensor: raise NotImplementedError("only implemented in child classes") + + def _compute_sr( + self, + positions: torch.Tensor, + charges: torch.Tensor, + cell: torch.Tensor, + smearing: torch.Tensor, + sr_cutoff: torch.Tensor, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_shifts: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Compute the short-range part of the Ewald sum in realspace + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. + :param smearing: torch.Tensor smearing paramter determining the splitting + between the SR and LR parts. + :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + if neighbor_indices is None or neighbor_shifts is None: + # Get list of neighbors + struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) + atom_is, atom_js, neighbor_shifts = neighbor_list( + "ijS", struc, sr_cutoff.item(), self_interaction=False + ) + atom_is = torch.tensor(atom_is) + atom_js = torch.tensor(atom_js) + shifts = torch.tensor(neighbor_shifts, dtype=cell.dtype) # N x 3 + + else: + atom_is = neighbor_indices[0] + atom_js = neighbor_indices[1] + shifts = neighbor_shifts.T + shifts.dtype = cell.dtype + + # Compute energy + potential = torch.zeros_like(charges) + + pos_is = positions[atom_is] + pos_js = positions[atom_js] + dists = torch.linalg.norm(pos_js - pos_is + shifts @ cell, dim=1) + # If the contribution from all atoms within the cutoff is to be subtracted + # this short-range part will simply use -V_LR as the potential + if self.subtract_interior: + potentials_bare = -self.potential.potential_lr_from_dist(dists, smearing) + # In the remaining cases, we simply use the usual V_SR to get the full + # 1/r^p potential when combined with the long-range part implemented in + # reciprocal space + else: + potentials_bare = self.potential.potential_sr_from_dist(dists, smearing) + # potential.index_add_(0, atom_is, charges[atom_js] * potentials_bare) + + for i, j, potential_bare in zip(atom_is, atom_js, potentials_bare): + potential[i.item()] += charges[j.item()] * potential_bare + + return potential diff --git a/src/meshlode/calculators/ewaldpotential.py b/src/meshlode/calculators/ewaldpotential.py index c9156fcf..730e0a79 100644 --- a/src/meshlode/calculators/ewaldpotential.py +++ b/src/meshlode/calculators/ewaldpotential.py @@ -6,6 +6,7 @@ from ase import Atoms from ase.neighborlist import neighbor_list +from ..lib import generate_kvectors_squeezed from .base import CalculatorBase @@ -71,10 +72,11 @@ def __init__( subtract_self: Optional[bool] = True, subtract_interior: Optional[bool] = False, ): - super().__init__(all_types=all_types, exponent=exponent) - + if exponent < 0.0 or exponent > 3.0: + raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p < 3") if atomic_smearing is not None and atomic_smearing <= 0: raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") + super().__init__(all_types=all_types, exponent=exponent) self.atomic_smearing = atomic_smearing self.sr_cutoff = sr_cutoff @@ -165,20 +167,42 @@ def forward( def _compute_single_system( self, positions: torch.Tensor, - cell: Union[None, torch.Tensor], charges: torch.Tensor, - neighbor_indices: Union[None, torch.Tensor], - neighbor_shifts: Union[None, torch.Tensor], + cell: torch.Tensor, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_shifts: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # Check that the realspace cutoff (if provided) is not too large - # This is because the current implementation is not able to return multiple - # periodic images of the same atom as a neighbor - cell_dimensions = torch.linalg.norm(cell, dim=1) - cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6 - if self.sr_cutoff is not None: - if self.sr_cutoff > torch.min(cell_dimensions) / 2: - raise ValueError(f"sr_cutoff {self.sr_cutoff} has to be > {cutoff_max}") + """ + Compute the "electrostatic" potential at the position of all atoms in a + structure. + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. For standard LODE + that treats all (atomic) types separately, one example could be: If n_atoms + = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use + the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for + the charges. This would then separately compute the "Na" potential and "Cl" + potential. Subtracting these from each other, one could recover the more + standard electrostatic potential in which Na and Cl have charges of +1 and + -1, respectively. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ # Set the defaut values of convergence parameters # The total computational cost = cost of SR part + cost of LR part # Bigger smearing increases the cost of the SR part while decreasing the cost @@ -189,12 +213,13 @@ def _compute_single_system( # chosen to reach a convergence on the order of 1e-4 to 1e-5 for the test # structures. if self.sr_cutoff is None: - sr_cutoff = cutoff_max + cell_dimensions = torch.linalg.norm(cell, dim=1) + sr_cutoff = torch.min(cell_dimensions) / 2 - 1e-6 else: sr_cutoff = self.sr_cutoff if self.atomic_smearing is None: - smearing = cutoff_max / 5.0 + smearing = sr_cutoff / 5.0 else: smearing = self.atomic_smearing @@ -203,14 +228,18 @@ def _compute_single_system( else: lr_wavelength = self.lr_wavelength + # Compute short-range (SR) part using a real space sum potential_sr = self._compute_sr( positions=positions, charges=charges, cell=cell, smearing=smearing, sr_cutoff=sr_cutoff, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, ) + # Compute long-range (LR) part using a Fourier / reciprocal space sum potential_lr = self._compute_lr( positions=positions, charges=charges, @@ -222,52 +251,6 @@ def _compute_single_system( potential_ewald = potential_sr + potential_lr return potential_ewald - def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: - """ - For a given unit cell, compute all reciprocal space vectors that are used to - perform sums in the Fourier transformed space. - - Note that this function is different from the function implemented in the - FourierSpaceConvolution class of the same name, since in this case, we are - generating the full grid of k-vectors, rather than the one that is adapted - specifically to be used together with FFT. - - :param ns: torch.tensor of shape ``(3,)`` containing integers - ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and - z-direction, respectively. - :param cell: torch.tensor of shape ``(3, 3)`` Tensor specifying the real space - unit cell of a structure, where cell[i] is the i-th basis vector - - :return: torch.tensor of shape ``(N, 3)`` Contains all reciprocal space vectors - that will be used during Ewald summation (or related approaches). - ``k_vectors[i]`` contains the i-th vector, where the order has no special - significance. - The total number N of k-vectors is NOT simply nx*ny*nz, and roughly - corresponds to nx*ny*nz/2 due since the vectors +k and -k can be grouped - together during summation. - """ - # Check that the shapes of all inputs are correct - if ns.shape != (3,): - raise ValueError(f"ns of shape {list(ns.shape)} should be of shape (3, )") - - # Define basis vectors of the reciprocal cell - reciprocal_cell = 2 * torch.pi * cell.inverse().T - bx = reciprocal_cell[0] - by = reciprocal_cell[1] - bz = reciprocal_cell[2] - - # Generate all reciprocal space vectors - nxs_1d = ns[0] * torch.fft.fftfreq(ns[0], device=ns.device) - nys_1d = ns[1] * torch.fft.fftfreq(ns[1], device=ns.device) - nzs_1d = ns[2] * torch.fft.fftfreq(ns[2], device=ns.device) # real FFT - nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing="ij") - nxs = nxs.flatten().reshape((-1, 1)) - nys = nys.flatten().reshape((-1, 1)) - nzs = nzs.flatten().reshape((-1, 1)) - k_vectors = nxs * bx + nys * by + nzs * bz - - return k_vectors - def _compute_lr( self, positions: torch.Tensor, @@ -309,7 +292,8 @@ def _compute_lr( ns = torch.ceil(ns_float).long() # Generate k-vectors and evaluate - kvectors = self._generate_kvectors(ns=ns, cell=cell) + # kvectors = self._generate_kvectors(ns=ns, cell=cell) + kvectors = generate_kvectors_squeezed(ns=ns, cell=cell) knorm_sq = torch.sum(kvectors**2, dim=1) # G(k) is the Fourier transform of the Coulomb potential @@ -361,57 +345,3 @@ def _compute_lr( energy -= charges * self_contrib return energy - - def _compute_sr( - self, - positions: torch.Tensor, - charges: torch.Tensor, - cell: torch.Tensor, - smearing: torch.Tensor, - sr_cutoff: torch.Tensor, - ) -> torch.Tensor: - """ - Compute the short-range part of the Ewald sum in realspace - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - :param smearing: torch.Tensor smearing paramter determining the splitting - between the SR and LR parts. - :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ - # Get list of neighbors - struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) - atom_is, atom_js, shifts = neighbor_list( - "ijS", struc, sr_cutoff.item(), self_interaction=False - ) - - # Compute energy - potential = torch.zeros_like(charges) - for i, j, shift in zip(atom_is, atom_js, shifts): - dist = torch.linalg.norm( - positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell)) - ) - - # If the contribution from all atoms within the cutoff is to be subtracted - # this short-range part will simply use -V_LR as the potential - if self.subtract_interior: - potential_bare = -self.potential.potential_lr_from_dist(dist, smearing) - # In the remaining cases, we simply use the usual V_SR to get the full - # 1/r^p potential when combined with the long-range part implemented in - # reciprocal space - else: - potential_bare = self.potential.potential_sr_from_dist(dist, smearing) - potential[i] += charges[j] * potential_bare - - return potential diff --git a/src/meshlode/calculators/pmepotential.py b/src/meshlode/calculators/pmepotential.py index 25735235..7c6996aa 100644 --- a/src/meshlode/calculators/pmepotential.py +++ b/src/meshlode/calculators/pmepotential.py @@ -6,7 +6,9 @@ from ase import Atoms from ase.neighborlist import neighbor_list -from ..lib.mesh_interpolator import MeshInterpolator +from meshlode.lib.mesh_interpolator import MeshInterpolator + +from ..lib import generate_kvectors_for_mesh from .base import CalculatorBase @@ -52,17 +54,18 @@ def __init__( sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, mesh_spacing: Optional[float] = None, - interpolation_order: Optional[int] = 4, + interpolation_order: Optional[int] = 3, subtract_self: Optional[bool] = True, subtract_interior: Optional[bool] = False, ): - super().__init__(all_types=all_types, exponent=exponent) - # Check that all provided values are correct + if exponent < 0.0 or exponent > 3.0: + raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p < 3") if interpolation_order not in [1, 2, 3, 4, 5]: raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") if atomic_smearing is not None and atomic_smearing <= 0: raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") + super().__init__(all_types=all_types, exponent=exponent) self.atomic_smearing = atomic_smearing self.mesh_spacing = mesh_spacing @@ -151,63 +154,13 @@ def forward( neighbor_shifts=neighbor_shifts, ) - def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: - """ - For a given unit cell, compute all reciprocal space vectors that are used to - perform sums in the Fourier transformed space. - - :param ns: torch.tensor of shape ``(3,)`` - ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and - z-direction, respectively. For faster performance during the Fast Fourier - Transform (FFT) it is recommended to use values of nx, ny and nz that are - powers of 2. - :param cell: torch.tensor of shape ``(3, 3)`` Tensor specifying the real space - unit cell of a structure, where cell[i] is the i-th basis vector - - :return: torch.tensor of shape ``(N, 3)`` Contains all reciprocal space vectors - that will be used during Ewald summation (or related approaches). - ``k_vectors[i]`` contains the i-th vector, where the order has no special - significance. - """ - if ns.device != cell.device: - raise ValueError( - f"`ns` and `cell` are not on the same device, got {ns.device} and " - f"{cell.device}." - ) - - if ns.shape != (3,): - raise ValueError(f"ns of shape {list(ns.shape)} should be of shape (3, )") - - if cell.shape != (3, 3): - raise ValueError( - f"cell of shape {list(cell.shape)} should be of shape (3, 3)" - ) - - # Define basis vectors of the reciprocal cell - reciprocal_cell = 2 * torch.pi * cell.inverse().T - bx = reciprocal_cell[0] - by = reciprocal_cell[1] - bz = reciprocal_cell[2] - - # Generate all reciprocal space vectors - nxs_1d = ns[0] * torch.fft.fftfreq(ns[0], device=ns.device) - nys_1d = ns[1] * torch.fft.fftfreq(ns[1], device=ns.device) - nzs_1d = ns[2] * torch.fft.rfftfreq(ns[2], device=ns.device) # real FFT - nxs, nys, nzs = torch.meshgrid(nxs_1d, nys_1d, nzs_1d, indexing="ij") - nxs = nxs.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) - nys = nys.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) - nzs = nzs.reshape((int(ns[0]), int(ns[1]), len(nzs_1d), 1)) - k_vectors = nxs * bx + nys * by + nzs * bz - - return k_vectors - def _compute_single_system( self, positions: torch.Tensor, - cell: Union[None, torch.Tensor], charges: torch.Tensor, - neighbor_indices: Union[None, torch.Tensor], - neighbor_shifts: Union[None, torch.Tensor], + cell: torch.Tensor, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_shifts: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute the "electrostatic" potential at the position of all atoms in a @@ -230,6 +183,12 @@ def _compute_single_system( -1, respectively. :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the structure, where cell[i] is the i-th basis vector. + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential at the position of each atom for the `n_channels` independent meshes separately. @@ -251,7 +210,7 @@ def _compute_single_system( sr_cutoff = self.sr_cutoff if self.atomic_smearing is None: - smearing = cutoff_max / 5.0 + smearing = sr_cutoff / 5.0 else: smearing = self.atomic_smearing @@ -330,7 +289,8 @@ def _compute_lr( # Step 2: Perform Fourier space convolution (FSC) to get potential on mesh # Step 2.1: Generate k-vectors and evaluate kernel function - kvectors = self._generate_kvectors(ns=ns, cell=cell) + # kvectors = self._generate_kvectors(ns=ns, cell=cell) + kvectors = generate_kvectors_for_mesh(ns=ns, cell=cell) knorm_sq = torch.sum(kvectors**2, dim=3) # Step 2.2: Evaluate kernel function (careful, tensor shapes are different from @@ -359,64 +319,3 @@ def _compute_lr( interpolated_potential -= charges * self_contrib return interpolated_potential - - def _compute_sr( - self, - positions: torch.Tensor, - charges: torch.Tensor, - cell: torch.Tensor, - smearing: torch.Tensor, - sr_cutoff: torch.Tensor, - neighbor_indices: Optional[torch.Tensor] = None, - neighbor_shifts: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Compute the short-range part of the Ewald sum in realspace - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - :param smearing: torch.Tensor smearing paramter determining the splitting - between the SR and LR parts. - :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ - if neighbor_indices is None or neighbor_shifts is None: - # Get list of neighbors - struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) - atom_is, atom_js, neighbor_shifts = neighbor_list( - "ijS", struc, sr_cutoff.item(), self_interaction=False - ) - else: - atom_is = neighbor_indices[0] - atom_js = neighbor_indices[1] - - # Compute energy - potential = torch.zeros_like(charges) - for i, j, shift in zip(atom_is, atom_js, neighbor_shifts): - shift = shift.type(cell.dtype) - dist = torch.linalg.norm( - positions[j] - positions[i] + torch.tensor(shift @ cell) - ) - - # If the contribution from all atoms within the cutoff is to be subtracted - # this short-range part will simply use -V_LR as the potential - if self.subtract_interior: - potential_bare = -self.potential.potential_lr_from_dist(dist, smearing) - # In the remaining cases, we simply use the usual V_SR to get the full - # 1/r^p potential when combined with the long-range part implemented in - # reciprocal space - else: - potential_bare = self.potential.potential_sr_from_dist(dist, smearing) - potential[i] += charges[j] * potential_bare - - return potential diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py index 83789261..be0dcf95 100644 --- a/src/meshlode/lib/potentials.py +++ b/src/meshlode/lib/potentials.py @@ -121,7 +121,9 @@ def potential_fourier_from_k_sq( smearing parameter corresponds to the "width" of the Gaussian. """ peff = (3 - self.exponent) / 2 - prefac = (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff + prefac = ( + (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff + ) x = 0.5 * smearing**2 * k_sq fourier = prefac * gammaincc(peff, x) / x**peff * gamma(peff) diff --git a/tests/calculators/test_values_periodic.py b/tests/calculators/test_values_periodic.py index 246016ce..d5b96862 100644 --- a/tests/calculators/test_values_periodic.py +++ b/tests/calculators/test_values_periodic.py @@ -8,7 +8,7 @@ # Imports for random structure from ase.io import read -from meshlode import EwaldPotential, MeshEwaldPotential +from meshlode import EwaldPotential, PMEPotential def generate_orthogonal_transformations(): @@ -314,7 +314,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): rtol = 4e-6 elif calc_name == "pme": sr_cutoff = scaling_factor * torch.tensor(2.0, dtype=dtype) - calc = MeshEwaldPotential(sr_cutoff=sr_cutoff) + calc = PMEPotential(sr_cutoff=sr_cutoff) rtol = 9e-4 # Compute potential and compare against target value using default hypers @@ -374,8 +374,10 @@ def test_wigner(crystal_name, scaling_factor): smeareff *= scaling_factor # Compute potential and compare against reference - EP = EwaldPotential(atomic_smearing=smeareff) - potentials = EP.compute(types, positions, cell, charges) + calc = EwaldPotential(atomic_smearing=smeareff) + potentials = calc.compute( + types, positions=positions, cell=cell, charges=charges + ) energies = potentials * charges energies_ref = -torch.ones_like(energies) * madelung_ref torch.testing.assert_close(energies, energies_ref, atol=0.0, rtol=rtol) @@ -429,10 +431,12 @@ def test_random_structure(sr_cutoff, frame_index, scaling_factor, ortho, calc_na rtol_e = 2e-5 rtol_f = 3.6e-3 elif calc_name == "pme": - calc = MeshEwaldPotential(sr_cutoff=sr_cutoff) + calc = PMEPotential(sr_cutoff=sr_cutoff) rtol_e = 4.5e-3 # 1.5e-3 rtol_f = 2.5e-3 # 6e-3 - potentials = calc.compute(types, positions=positions, cell=cell, charges=charges) + potentials = calc.compute( + types=types, positions=positions, cell=cell, charges=charges + ) # Compute energy, taking into account the double counting of each pair energy = torch.sum(potentials * charges) / 2 diff --git a/tests/lib/test_potentials.py b/tests/lib/test_potentials.py index 99b2da4a..46527b9a 100644 --- a/tests/lib/test_potentials.py +++ b/tests/lib/test_potentials.py @@ -226,6 +226,8 @@ def test_lr_value_at_zero(exponent, smearing): potential_close_to_zero = ipl.potential_lr_from_dist(dist_small, smearing=smearing) # Compare to - exact_value = 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) + exact_value = ( + 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) + ) relerr = torch.abs(potential_close_to_zero - exact_value) / exact_value assert relerr.item() < 3e-14 From 36e08d9b212ff8e188a199acbc0daab80899f15d Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Sat, 6 Jul 2024 07:15:45 +0200 Subject: [PATCH 26/35] linting --- src/meshlode/calculators/ewaldpotential.py | 4 ---- src/meshlode/calculators/meshpotential.py | 1 - src/meshlode/calculators/pmepotential.py | 7 +------ src/meshlode/lib/potentials.py | 4 +--- src/meshlode/metatensor/pmepotential.py | 1 - tests/calculators/test_values_periodic.py | 3 --- tests/lib/test_potentials.py | 4 +--- 7 files changed, 3 insertions(+), 21 deletions(-) diff --git a/src/meshlode/calculators/ewaldpotential.py b/src/meshlode/calculators/ewaldpotential.py index 730e0a79..8d8acbda 100644 --- a/src/meshlode/calculators/ewaldpotential.py +++ b/src/meshlode/calculators/ewaldpotential.py @@ -2,10 +2,6 @@ import torch -# extra imports for neighbor list -from ase import Atoms -from ase.neighborlist import neighbor_list - from ..lib import generate_kvectors_squeezed from .base import CalculatorBase diff --git a/src/meshlode/calculators/meshpotential.py b/src/meshlode/calculators/meshpotential.py index c028e1e3..2d9bb7c4 100644 --- a/src/meshlode/calculators/meshpotential.py +++ b/src/meshlode/calculators/meshpotential.py @@ -206,4 +206,3 @@ def _compute_single_system( interpolated_potential -= charges * self_contrib return interpolated_potential - diff --git a/src/meshlode/calculators/pmepotential.py b/src/meshlode/calculators/pmepotential.py index 7c6996aa..9d7a8eba 100644 --- a/src/meshlode/calculators/pmepotential.py +++ b/src/meshlode/calculators/pmepotential.py @@ -2,13 +2,8 @@ import torch -# extra imports for neighbor list -from ase import Atoms -from ase.neighborlist import neighbor_list - -from meshlode.lib.mesh_interpolator import MeshInterpolator - from ..lib import generate_kvectors_for_mesh +from ..lib.mesh_interpolator import MeshInterpolator from .base import CalculatorBase diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py index be0dcf95..83789261 100644 --- a/src/meshlode/lib/potentials.py +++ b/src/meshlode/lib/potentials.py @@ -121,9 +121,7 @@ def potential_fourier_from_k_sq( smearing parameter corresponds to the "width" of the Gaussian. """ peff = (3 - self.exponent) / 2 - prefac = ( - (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff - ) + prefac = (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff x = 0.5 * smearing**2 * k_sq fourier = prefac * gammaincc(peff, x) / x**peff * gamma(peff) diff --git a/src/meshlode/metatensor/pmepotential.py b/src/meshlode/metatensor/pmepotential.py index 4df274b3..72dd4f6c 100644 --- a/src/meshlode/metatensor/pmepotential.py +++ b/src/meshlode/metatensor/pmepotential.py @@ -212,4 +212,3 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: labels_keys = Labels(key_names, key_values) return TensorMap(keys=labels_keys, blocks=blocks) - diff --git a/tests/calculators/test_values_periodic.py b/tests/calculators/test_values_periodic.py index d5b96862..2b941201 100644 --- a/tests/calculators/test_values_periodic.py +++ b/tests/calculators/test_values_periodic.py @@ -14,9 +14,6 @@ def generate_orthogonal_transformations(): dtype = torch.float64 - # first rotation matrix: identity - rot_1 = torch.eye(3, dtype=dtype) - # second rotation matrix: rotation by angle phi around z-axis phi = 0.82321 rot_2 = torch.zeros((3, 3), dtype=dtype) diff --git a/tests/lib/test_potentials.py b/tests/lib/test_potentials.py index 46527b9a..99b2da4a 100644 --- a/tests/lib/test_potentials.py +++ b/tests/lib/test_potentials.py @@ -226,8 +226,6 @@ def test_lr_value_at_zero(exponent, smearing): potential_close_to_zero = ipl.potential_lr_from_dist(dist_small, smearing=smearing) # Compare to - exact_value = ( - 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) - ) + exact_value = 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) relerr = torch.abs(potential_close_to_zero - exact_value) / exact_value assert relerr.item() < 3e-14 From 083a43a06cffbf710e99958896a07df9fb71ef48 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Mon, 8 Jul 2024 15:28:37 +0200 Subject: [PATCH 27/35] Add metatensor API and remove mehspotential --- .../references/calculators/meshpotential.rst | 6 - docs/src/references/metatensor/index.rst | 3 +- .../references/metatensor/meshpotential.rst | 6 +- examples/madelung.py | 10 +- examples/neighborlist_example.py | 2 +- pyproject.toml | 3 +- src/meshlode/__init__.py | 2 +- src/meshlode/calculators/__init__.py | 2 +- src/meshlode/calculators/base.py | 175 +++++------ src/meshlode/calculators/directpotential.py | 78 ++--- src/meshlode/calculators/ewaldpotential.py | 295 ++++++++++-------- src/meshlode/calculators/meshpotential.py | 208 ------------ src/meshlode/calculators/pmepotential.py | 269 +++++++++------- src/meshlode/metatensor/__init__.py | 4 +- .../{pmepotential.py => calculators.py} | 162 ++++++---- tests/calculators/test_calculator_base.py | 4 +- tests/calculators/test_values_periodic.py | 8 +- 17 files changed, 563 insertions(+), 674 deletions(-) delete mode 100644 docs/src/references/calculators/meshpotential.rst delete mode 100644 src/meshlode/calculators/meshpotential.py rename src/meshlode/metatensor/{pmepotential.py => calculators.py} (59%) diff --git a/docs/src/references/calculators/meshpotential.rst b/docs/src/references/calculators/meshpotential.rst deleted file mode 100644 index 3a8e52ba..00000000 --- a/docs/src/references/calculators/meshpotential.rst +++ /dev/null @@ -1,6 +0,0 @@ -MeshPotential -############# - -.. autoclass:: meshlode.MeshPotential - :members: - :undoc-members: diff --git a/docs/src/references/metatensor/index.rst b/docs/src/references/metatensor/index.rst index ee6d7f83..be0032e3 100644 --- a/docs/src/references/metatensor/index.rst +++ b/docs/src/references/metatensor/index.rst @@ -15,5 +15,6 @@ For a plain :py:class:`torch.Tensor` refer to :ref:`calculators`. .. toctree:: :maxdepth: 1 + :glob: - meshpotential + ./* diff --git a/docs/src/references/metatensor/meshpotential.rst b/docs/src/references/metatensor/meshpotential.rst index c96a3343..b10a00cf 100644 --- a/docs/src/references/metatensor/meshpotential.rst +++ b/docs/src/references/metatensor/meshpotential.rst @@ -1,6 +1,6 @@ -MeshPotential -############# +PMEPotential +############ -.. autoclass:: meshlode.metatensor.MeshPotential +.. autoclass:: meshlode.metatensor.PMEPotential :members: :undoc-members: diff --git a/examples/madelung.py b/examples/madelung.py index fbfc6fa4..c9b92a0a 100644 --- a/examples/madelung.py +++ b/examples/madelung.py @@ -2,8 +2,8 @@ Compute Madelung Constants ========================== In this tutorial we show how to calculate the Madelung constants and total electrostatic -energy of atomic structures using the :py:class:`meshlode.MeshPotential` and -:py:class:`meshlode.metatensor.MeshPotential` calculator. +energy of atomic structures using the :py:class:`meshlode.PMEPotential` and +:py:class:`meshlode.metatensor.PMEPotential` calculator. """ # %% @@ -17,7 +17,7 @@ # %% # Define simple example structure having the CsCl structure and compute the reference -# values. MeshPotential by default outputs the types sorted according to the atomic +# values. PMEPotential by default outputs the types sorted according to the atomic # number. Thus, we input the compound "CsCl" and "ClCs" since Cl and Cs have atomic # numbers 17 and 55, respectively. types = torch.tensor([17, 55]) # Cl and Cs @@ -44,7 +44,7 @@ # ------------------------------ # Compute features using -MP = meshlode.MeshPotential( +MP = meshlode.PMEPotential( atomic_smearing=atomic_smearing, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, @@ -92,7 +92,7 @@ system = System(types=types, positions=positions, cell=cell) -MP = meshlode.metatensor.MeshPotential( +MP = meshlode.metatensor.PMEPotential( atomic_smearing=atomic_smearing, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, diff --git a/examples/neighborlist_example.py b/examples/neighborlist_example.py index 99e67f55..6a6d4c7c 100644 --- a/examples/neighborlist_example.py +++ b/examples/neighborlist_example.py @@ -22,7 +22,7 @@ # %% # Define simple example structure having the CsCl structure and compute the reference -# values. MeshPotential by default outputs the types sorted according to the atomic +# values. PMEPotential by default outputs the types sorted according to the atomic # number. Thus, we input the compound "CsCl" and "ClCs" since Cl and Cs have atomic # numbers 17 and 55, respectively. diff --git a/pyproject.toml b/pyproject.toml index 0072a646..9a41ebd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,13 +35,12 @@ keywords = [ ] dependencies = [ "torch >=1.11", - "ase", + "ase >= 3.23.0", ] dynamic = ["version"] [project.optional-dependencies] examples = [ - "ase", "matplotlib", ] metatensor = [ diff --git a/src/meshlode/__init__.py b/src/meshlode/__init__.py index 96cea434..6d8f0988 100644 --- a/src/meshlode/__init__.py +++ b/src/meshlode/__init__.py @@ -8,5 +8,5 @@ pass -__all__ = ["MeshPotential", "EwaldPotential", "DirectPotential", "PMEPotential"] +__all__ = ["EwaldPotential", "DirectPotential", "PMEPotential"] __version__ = "0.0.0-dev" diff --git a/src/meshlode/calculators/__init__.py b/src/meshlode/calculators/__init__.py index 91dfbb23..c83447c8 100644 --- a/src/meshlode/calculators/__init__.py +++ b/src/meshlode/calculators/__init__.py @@ -2,4 +2,4 @@ from .directpotential import DirectPotential from .pmepotential import PMEPotential -__all__ = ["MeshPotential", "EwaldPotential", "DirectPotential", "PMEPotential"] +__all__ = ["EwaldPotential", "DirectPotential", "PMEPotential"] diff --git a/src/meshlode/calculators/base.py b/src/meshlode/calculators/base.py index a71b4c84..03098e52 100644 --- a/src/meshlode/calculators/base.py +++ b/src/meshlode/calculators/base.py @@ -26,8 +26,92 @@ def _is_subset(subset_candidate: List[int], superset: List[int]) -> bool: class CalculatorBase(torch.nn.Module): + """Base class providing general funtionality.""" + + def __init__(self, exponent): + self.exponent = exponent + self.potential = InversePowerLawPotential(exponent=exponent) + + super().__init__() + + def _compute_sr( + self, + positions: torch.Tensor, + charges: torch.Tensor, + cell: torch.Tensor, + smearing: torch.Tensor, + sr_cutoff: torch.Tensor, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_shifts: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Compute the short-range part of the Ewald sum in realspace + + :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian + coordinates of the atoms. The implementation also works if the positions + are not contained within the unit cell. + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest + case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the + charge of atom i. More generally, the potential for the same atom positions + is computed for n_channels independent meshes, and one can specify the + "charge" of each atom on each of the meshes independently. + :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the + structure, where cell[i] is the i-th basis vector. + :param smearing: torch.Tensor smearing paramter determining the splitting + between the SR and LR parts. + :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. + """ + if neighbor_indices is None or neighbor_shifts is None: + # Get list of neighbors + struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) + atom_is, atom_js, neighbor_shifts = neighbor_list( + "ijS", struc, sr_cutoff.item(), self_interaction=False + ) + atom_is = torch.tensor(atom_is) + atom_js = torch.tensor(atom_js) + shifts = torch.tensor(neighbor_shifts, dtype=cell.dtype) # N x 3 + + else: + atom_is = neighbor_indices[0] + atom_js = neighbor_indices[1] + shifts = neighbor_shifts.T + shifts.dtype = cell.dtype + + # Compute energy + potential = torch.zeros_like(charges) + + pos_is = positions[atom_is] + pos_js = positions[atom_js] + dists = torch.linalg.norm(pos_js - pos_is + shifts @ cell, dim=1) + # If the contribution from all atoms within the cutoff is to be subtracted + # this short-range part will simply use -V_LR as the potential + if self.subtract_interior: + potentials_bare = -self.potential.potential_lr_from_dist(dists, smearing) + # In the remaining cases, we simply use the usual V_SR to get the full + # 1/r^p potential when combined with the long-range part implemented in + # reciprocal space + else: + potentials_bare = self.potential.potential_sr_from_dist(dists, smearing) + # potential.index_add_(0, atom_is, charges[atom_js] * potentials_bare) + + for i, j, potential_bare in zip(atom_is, atom_js, potentials_bare): + potential[i.item()] += charges[j.item()] * potential_bare + + return potential + + +class CalculatorBaseTorch(CalculatorBase): """ - Base calculator + Base calculator for the torch interface to MeshLODE. :param all_types: Optional global list of all atomic types that should be considered for the computation. This option might be useful when running the calculation on @@ -38,16 +122,13 @@ class CalculatorBase(torch.nn.Module): """ def __init__(self, all_types: Union[None, List[int]], exponent: float): - super().__init__() + super().__init__(exponent) if all_types is None: self.all_types = None else: self.all_types = _1d_tolist(torch.unique(torch.tensor(all_types))) - self.exponent = exponent - self.potential = InversePowerLawPotential(exponent=exponent) - def _get_requested_types(self, types: List[torch.Tensor]) -> List[int]: """Extract a list of all unique and present types from the list of types.""" all_types = torch.hstack(types) @@ -300,87 +381,3 @@ def _compute_impl( return potentials[0] else: return potentials - - def _compute_single_system( - self, - positions: torch.Tensor, - cell: Union[None, torch.Tensor], - charges: torch.Tensor, - neighbor_indices: Union[None, torch.Tensor], - neighbor_shifts: Union[None, torch.Tensor], - ) -> torch.Tensor: - raise NotImplementedError("only implemented in child classes") - - def _compute_sr( - self, - positions: torch.Tensor, - charges: torch.Tensor, - cell: torch.Tensor, - smearing: torch.Tensor, - sr_cutoff: torch.Tensor, - neighbor_indices: Optional[torch.Tensor] = None, - neighbor_shifts: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Compute the short-range part of the Ewald sum in realspace - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - :param smearing: torch.Tensor smearing paramter determining the splitting - between the SR and LR parts. - :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ - if neighbor_indices is None or neighbor_shifts is None: - # Get list of neighbors - struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) - atom_is, atom_js, neighbor_shifts = neighbor_list( - "ijS", struc, sr_cutoff.item(), self_interaction=False - ) - atom_is = torch.tensor(atom_is) - atom_js = torch.tensor(atom_js) - shifts = torch.tensor(neighbor_shifts, dtype=cell.dtype) # N x 3 - - else: - atom_is = neighbor_indices[0] - atom_js = neighbor_indices[1] - shifts = neighbor_shifts.T - shifts.dtype = cell.dtype - - # Compute energy - potential = torch.zeros_like(charges) - - pos_is = positions[atom_is] - pos_js = positions[atom_js] - dists = torch.linalg.norm(pos_js - pos_is + shifts @ cell, dim=1) - # If the contribution from all atoms within the cutoff is to be subtracted - # this short-range part will simply use -V_LR as the potential - if self.subtract_interior: - potentials_bare = -self.potential.potential_lr_from_dist(dists, smearing) - # In the remaining cases, we simply use the usual V_SR to get the full - # 1/r^p potential when combined with the long-range part implemented in - # reciprocal space - else: - potentials_bare = self.potential.potential_sr_from_dist(dists, smearing) - # potential.index_add_(0, atom_is, charges[atom_js] * potentials_bare) - - for i, j, potential_bare in zip(atom_is, atom_js, potentials_bare): - potential[i.item()] += charges[j.item()] * potential_bare - - return potential diff --git a/src/meshlode/calculators/directpotential.py b/src/meshlode/calculators/directpotential.py index 8194cf6c..a0c79555 100644 --- a/src/meshlode/calculators/directpotential.py +++ b/src/meshlode/calculators/directpotential.py @@ -2,10 +2,48 @@ import torch -from .base import CalculatorBase +from .base import CalculatorBaseTorch -class DirectPotential(CalculatorBase): +class _DirectPotentialImpl: + def __init__(self, exponent): + self.exponent = exponent + + def _compute_single_system( + self, + positions: torch.Tensor, + cell: Union[None, torch.Tensor], + charges: torch.Tensor, + neighbor_indices: Union[None, torch.Tensor], + neighbor_shifts: Union[None, torch.Tensor], + ) -> torch.Tensor: + # Compute matrix containing the squared distances from the Gram matrix + # The squared distance and the inner product between two vectors r_i and r_j are + # related by: d_ij^2 = |r_i - r_j|^2 = r_i^2 + r_j^2 - 2*r_i*r_j + num_atoms = len(positions) + diagonal_indices = torch.arange(num_atoms) + gram_matrix = positions @ positions.T + squared_norms = gram_matrix[diagonal_indices, diagonal_indices].reshape(-1, 1) + ones = torch.ones((1, len(positions)), dtype=positions.dtype) + squared_norms_matrix = torch.matmul(squared_norms, ones) + distances_sq = squared_norms_matrix + squared_norms_matrix.T - 2 * gram_matrix + + # Add terms to diagonal in order to avoid division by zero + # Since these components in the target tensor need to be set to zero, we add + # a huge number such that after taking the inverse (since we evaluate 1/r^p), + # the components will effectively be set to zero. + # This is not the most elegant solution, but I am doing this since the more + # obvious alternative of setting the same components to zero after the division + # had issues with autograd. I would appreciate any better alternatives. + distances_sq[diagonal_indices, diagonal_indices] += 1e50 + + # Compute potential + potentials_by_pair = distances_sq.pow(-self.exponent / 2.0) + + return torch.matmul(potentials_by_pair, charges) + + +class DirectPotential(CalculatorBaseTorch, _DirectPotentialImpl): r"""Specie-wise long-range potential using a direct summation over all atoms. Scaling as :math:`\mathcal{O}(N^2)` with respect to the number of particles @@ -23,7 +61,8 @@ class DirectPotential(CalculatorBase): """ def __init__(self, all_types: Optional[List[int]] = None, exponent: float = 1.0): - super().__init__(all_types=all_types, exponent=exponent) + _DirectPotentialImpl.__init__(self, exponent=exponent) + CalculatorBaseTorch.__init__(self, all_types=all_types, exponent=exponent) def compute( self, @@ -81,36 +120,3 @@ def forward( positions=positions, charges=charges, ) - - def _compute_single_system( - self, - positions: torch.Tensor, - cell: Union[None, torch.Tensor], - charges: torch.Tensor, - neighbor_indices: Union[None, torch.Tensor], - neighbor_shifts: Union[None, torch.Tensor], - ) -> torch.Tensor: - # Compute matrix containing the squared distances from the Gram matrix - # The squared distance and the inner product between two vectors r_i and r_j are - # related by: d_ij^2 = |r_i - r_j|^2 = r_i^2 + r_j^2 - 2*r_i*r_j - num_atoms = len(positions) - diagonal_indices = torch.arange(num_atoms) - gram_matrix = positions @ positions.T - squared_norms = gram_matrix[diagonal_indices, diagonal_indices].reshape(-1, 1) - ones = torch.ones((1, len(positions)), dtype=positions.dtype) - squared_norms_matrix = torch.matmul(squared_norms, ones) - distances_sq = squared_norms_matrix + squared_norms_matrix.T - 2 * gram_matrix - - # Add terms to diagonal in order to avoid division by zero - # Since these components in the target tensor need to be set to zero, we add - # a huge number such that after taking the inverse (since we evaluate 1/r^p), - # the components will effectively be set to zero. - # This is not the most elegant solution, but I am doing this since the more - # obvious alternative of setting the same components to zero after the division - # had issues with autograd. I would appreciate any better alternatives. - distances_sq[diagonal_indices, diagonal_indices] += 1e50 - - # Compute potential - potentials_by_pair = distances_sq.pow(-self.exponent / 2.0) - - return torch.matmul(potentials_by_pair, charges) diff --git a/src/meshlode/calculators/ewaldpotential.py b/src/meshlode/calculators/ewaldpotential.py index 8d8acbda..a52d3315 100644 --- a/src/meshlode/calculators/ewaldpotential.py +++ b/src/meshlode/calculators/ewaldpotential.py @@ -3,76 +3,23 @@ import torch from ..lib import generate_kvectors_squeezed -from .base import CalculatorBase +from .base import CalculatorBaseTorch -class EwaldPotential(CalculatorBase): - r"""Specie-wise long-range potential computed using the Ewald sum. - - Scaling as :math:`\mathcal{O}(N^2)` with respect to the number of particles - :math:`N`. - - :param all_types: Optional global list of all atomic types that should be considered - for the computation. This option might be useful when running the calculation on - subset of a whole dataset and it required to keep the shape of the output - consistent. If this is not set the possible atomic types will be determined when - calling the :meth:`compute()`. - :param exponent: the exponent "p" in 1/r^p potentials - :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If - not set to a global value, it will be set to be half of the shortest lattice - vector defining the cell (separately for each structure). - :param atomic_smearing: Width of the atom-centered Gaussian used to split the - Coulomb potential into the short- and long-range parts. If not set to a global - value, it will be set to 1/5 times the sr_cutoff value (separately for each - structure) to ensure convergence of the short-range part to a relative precision - of 1e-5. - :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) - part of the Ewald sum. More conretely, all Fourier space vectors with a - wavelength >= this value will be kept. If not set to a global value, it will be - set to half the atomic_smearing parameter to ensure convergence of the - long-range part to a relative precision of 1e-5. - :param subtract_self: If set to :py:obj:`True`, subtract from the features of an - atom the contributions to the potential arising from that atom itself (but not - the periodic images). - :param subtract_interior: If set to :py:obj:`True`, subtract from the features of an - atom the contributions to the potential arising from all atoms within the cutoff - Note that if set to true, the self contribution (see previous) is also - subtracted by default. - - Example - ------- - >>> import torch - >>> from meshlode import EwaldPotential - - Define simple example structure having the CsCl (Cesium Chloride) structure - - >>> types = torch.tensor([55, 17]) # Cs and Cl - >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - >>> cell = torch.eye(3) - - Compute features - - >>> EP = EwaldPotential() - >>> EP.compute(types=types, positions=positions, cell=cell) - tensor([[-0.7391, -2.7745], - [-2.7745, -0.7391]]) - """ - +class _EwaldPotentialImpl: def __init__( self, - all_types: Optional[List[int]] = None, - exponent: float = 1.0, - sr_cutoff: Optional[torch.Tensor] = None, - atomic_smearing: Optional[float] = None, - lr_wavelength: Optional[float] = None, - subtract_self: Optional[bool] = True, - subtract_interior: Optional[bool] = False, + exponent: float, + sr_cutoff: Union[None, torch.Tensor], + atomic_smearing: Union[None, float], + lr_wavelength: Union[None, float], + subtract_self: Union[None, bool], + subtract_interior: Union[None, bool], ): if exponent < 0.0 or exponent > 3.0: raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p < 3") if atomic_smearing is not None and atomic_smearing <= 0: raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") - super().__init__(all_types=all_types, exponent=exponent) self.atomic_smearing = atomic_smearing self.sr_cutoff = sr_cutoff @@ -84,82 +31,6 @@ def __init__( self.subtract_self = subtract_self self.subtract_interior = subtract_interior - def compute( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute potential for all provided "systems" stacked inside list. - - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. - - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param cell: single or 2D tensor of shape (3, 3), describing the bounding - box/unit cell of the system. Each row should be one of the bounding box - vector; and columns should contain the x, y, and z components of these - vectors (i.e. the cell should be given in row-major order). - :param charges: Optional single or list of 2D tensor of shape (len(types), n), - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). - """ - - return self._compute_impl( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) - - # This function is kept to keep MeshLODE compatible with the broader pytorch - # infrastructure, which require a "forward" function. We name this function - # "compute" instead, for compatibility with other COSMO software. - def forward( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Forward just calls :py:meth:`compute`.""" - return self.compute( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) - def _compute_single_system( self, positions: torch.Tensor, @@ -341,3 +212,153 @@ def _compute_lr( energy -= charges * self_contrib return energy + + +class EwaldPotential(CalculatorBaseTorch, _EwaldPotentialImpl): + r"""Specie-wise long-range potential computed using the Ewald sum. + + Scaling as :math:`\mathcal{O}(N^2)` with respect to the number of particles + :math:`N`. + + :param all_types: Optional global list of all atomic types that should be considered + for the computation. This option might be useful when running the calculation on + subset of a whole dataset and it required to keep the shape of the output + consistent. If this is not set the possible atomic types will be determined when + calling the :meth:`compute()`. + :param exponent: the exponent "p" in 1/r^p potentials + :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If + not set to a global value, it will be set to be half of the shortest lattice + vector defining the cell (separately for each structure). + :param atomic_smearing: Width of the atom-centered Gaussian used to split the + Coulomb potential into the short- and long-range parts. If not set to a global + value, it will be set to 1/5 times the sr_cutoff value (separately for each + structure) to ensure convergence of the short-range part to a relative precision + of 1e-5. + :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) + part of the Ewald sum. More conretely, all Fourier space vectors with a + wavelength >= this value will be kept. If not set to a global value, it will be + set to half the atomic_smearing parameter to ensure convergence of the + long-range part to a relative precision of 1e-5. + :param subtract_self: If set to :py:obj:`True`, subtract from the features of an + atom the contributions to the potential arising from that atom itself (but not + the periodic images). + :param subtract_interior: If set to :py:obj:`True`, subtract from the features of an + atom the contributions to the potential arising from all atoms within the cutoff + Note that if set to true, the self contribution (see previous) is also + subtracted by default. + + Example + ------- + >>> import torch + >>> from meshlode import EwaldPotential + + Define simple example structure having the CsCl (Cesium Chloride) structure + + >>> types = torch.tensor([55, 17]) # Cs and Cl + >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + >>> cell = torch.eye(3) + + Compute features + + >>> EP = EwaldPotential() + >>> EP.compute(types=types, positions=positions, cell=cell) + tensor([[-0.7391, -2.7745], + [-2.7745, -0.7391]]) + """ + + def __init__( + self, + all_types: Optional[List[int]] = None, + exponent: float = 1.0, + sr_cutoff: Optional[torch.Tensor] = None, + atomic_smearing: Optional[float] = None, + lr_wavelength: Optional[float] = None, + subtract_self: Optional[bool] = True, + subtract_interior: Optional[bool] = False, + ): + _EwaldPotentialImpl.__init__( + self, + exponent=exponent, + sr_cutoff=sr_cutoff, + atomic_smearing=atomic_smearing, + lr_wavelength=lr_wavelength, + subtract_self=subtract_self, + subtract_interior=subtract_interior, + ) + CalculatorBaseTorch.__init__(self, all_types=all_types, exponent=exponent) + + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: single or 2D tensor of shape (3, 3), describing the bounding + box/unit cell of the system. Each row should be one of the bounding box + vector; and columns should contain the x, y, and z components of these + vectors (i.e. the cell should be given in row-major order). + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Forward just calls :py:meth:`compute`.""" + return self.compute( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) diff --git a/src/meshlode/calculators/meshpotential.py b/src/meshlode/calculators/meshpotential.py deleted file mode 100644 index 2d9bb7c4..00000000 --- a/src/meshlode/calculators/meshpotential.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import List, Optional, Union - -import torch - -from ..lib.fourier_convolution import FourierSpaceConvolution -from ..lib.mesh_interpolator import MeshInterpolator -from .base import CalculatorBase - - -class MeshPotential(CalculatorBase): - r"""Specie-wise long-range potential, computed on a grid. - - Method scaling as :math:`\mathcal{O}(NlogN)` with respect to the number of particles - :math:`N`. This class does not perform a usual Ewald style splitting into a short - and long range contribution but calculates the full contribution to the potential on - a grid. - - For a Particle Mesh Ewald (PME) use :py:class:`meshlode.PMEPotential`. - - :param atomic_smearing: Width of the atom-centered Gaussian used to create the - atomic density. - :param all_types: Optional global list of all atomic types that should be considered - for the computation. This option might be useful when running the calculation on - subset of a whole dataset and it required to keep the shape of the output - consistent. If this is not set the possible atomic types will be determined when - calling the :meth:`compute()`. - :param exponent: the exponent "p" in 1/r^p potentials - :param mesh_spacing: Value that determines the umber of Fourier-space grid points - that will be used along each axis. If set to None, it will automatically be set - to half of ``atomic_smearing``. - :param interpolation_order: Interpolation order for mapping onto the grid, where an - interpolation order of p corresponds to interpolation by a polynomial of degree - ``p - 1`` (e.g. ``p = 4`` for cubic interpolation). - :param subtract_self: If set to :py:obj:`True`, subtract from the features of an - atom the contributions to the potential arising from that atom itself (but not - the periodic images). - - Example - ------- - >>> import torch - >>> from meshlode import MeshPotential - - Define simple example structure having the CsCl (Cesium Chloride) structure - - >>> types = torch.tensor([55, 17]) # Cs and Cl - >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - >>> cell = torch.eye(3) - - Compute features - - >>> MP = MeshPotential(atomic_smearing=0.2, mesh_spacing=0.1, interpolation_order=4) - >>> MP.compute(types=types, positions=positions, cell=cell) - tensor([[-0.5467, 1.3755], - [ 1.3755, -0.5467]]) - """ - - def __init__( - self, - atomic_smearing: float, - all_types: Optional[List[int]] = None, - exponent: float = 1.0, - mesh_spacing: Optional[float] = None, - interpolation_order: Optional[int] = 4, - subtract_self: Optional[bool] = False, - ): - super().__init__(all_types=all_types, exponent=exponent) - - # Check that all provided values are correct - if interpolation_order not in [1, 2, 3, 4, 5]: - raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") - - # If no explicit mesh_spacing is given, set it such that it can resolve - # the smeared potentials. - if atomic_smearing <= 0: - raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") - - self.atomic_smearing = atomic_smearing - self.mesh_spacing = mesh_spacing - self.interpolation_order = interpolation_order - self.subtract_self = subtract_self - - # Initilize auxiliary objects - self.fourier_space_convolution = FourierSpaceConvolution() - - def compute( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute potential for all provided "systems" stacked inside list. - - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. - - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param cell: single or 2D tensor of shape (3, 3), describing the bounding - box/unit cell of the system. Each row should be one of the bounding box - vector; and columns should contain the x, y, and z components of these - vectors (i.e. the cell should be given in row-major order). - :param charges: Optional single or list of 2D tensor of shape (len(types), n), - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). - """ - - return self._compute_impl( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=None, - neighbor_shifts=None, - ) - - # This function is kept to keep MeshLODE compatible with the broader pytorch - # infrastructure, which require a "forward" function. We name this function - # "compute" instead, for compatibility with other COSMO software. - def forward( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Forward just calls :py:meth:`compute`.""" - return self.compute( - types=types, - positions=positions, - cell=cell, - charges=charges, - ) - - def _compute_single_system( - self, - positions: torch.Tensor, - cell: Union[None, torch.Tensor], - charges: torch.Tensor, - neighbor_indices: Union[None, torch.Tensor], - neighbor_shifts: Union[None, torch.Tensor], - ) -> torch.Tensor: - - if self.mesh_spacing is None: - mesh_spacing = self.atomic_smearing / 2 - else: - mesh_spacing = self.mesh_spacing - - # Initializations - k_cutoff = 2 * torch.pi / mesh_spacing - - # Compute number of times each basis vector of the - # reciprocal space can be scaled until the cutoff - # is reached - basis_norms = torch.linalg.norm(cell, dim=1) - ns_approx = k_cutoff * basis_norms / 2 / torch.pi - ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points - ns = 2 ** torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz] - - # Step 1: Smear particles onto mesh - MI = MeshInterpolator(cell, ns, interpolation_order=self.interpolation_order) - MI.compute_interpolation_weights(positions) - rho_mesh = MI.points_to_mesh(particle_weights=charges) - - # Step 2: Perform Fourier space convolution (FSC) - potential_mesh = self.fourier_space_convolution.compute( - mesh_values=rho_mesh, - cell=cell, - potential_exponent=1, - atomic_smearing=self.atomic_smearing, - ) - - # Step 3: Back interpolation - interpolated_potential = MI.mesh_to_points(potential_mesh) - - # Remove self contribution - if self.subtract_self: - self_contrib = ( - torch.sqrt( - torch.tensor( - 2.0 / torch.pi, dtype=positions.dtype, device=positions.device - ), - ) - / self.atomic_smearing - ) - interpolated_potential -= charges * self_contrib - - return interpolated_potential diff --git a/src/meshlode/calculators/pmepotential.py b/src/meshlode/calculators/pmepotential.py index 9d7a8eba..1c3d11ea 100644 --- a/src/meshlode/calculators/pmepotential.py +++ b/src/meshlode/calculators/pmepotential.py @@ -4,54 +4,19 @@ from ..lib import generate_kvectors_for_mesh from ..lib.mesh_interpolator import MeshInterpolator -from .base import CalculatorBase +from .base import CalculatorBaseTorch -class PMEPotential(CalculatorBase): - r"""Specie-wise long-range potential using a particle mesh-based Ewald (PME). - - Scaling as :math:`\mathcal{O}(NlogN)` with respect to the number of particles - :math:`N` used as a reference to test faster implementations. - - :param all_types: Optional global list of all atomic types that should be considered - for the computation. This option might be useful when running the calculation on - subset of a whole dataset and it required to keep the shape of the output - consistent. If this is not set the possible atomic types will be determined when - calling the :meth:`compute()`. - :param exponent: the exponent "p" in 1/r^p potentials - :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If - not set to a global value, it will be set to be half of the shortest lattice - vector defining the cell (separately for each structure). - :param atomic_smearing: Width of the atom-centered Gaussian used to split the - Coulomb potential into the short- and long-range parts. If not set to a global - value, it will be set to 1/5 times the sr_cutoff value (separately for each - structure) to ensure convergence of the short-range part to a relative precision - of 1e-5. - :param mesh_spacing: Value that determines the umber of Fourier-space grid points - that will be used along each axis. If set to None, it will automatically be set - to half of ``atomic_smearing``. - :param interpolation_order: Interpolation order for mapping onto the grid, where an - interpolation order of p corresponds to interpolation by a polynomial of degree - ``p - 1`` (e.g. ``p = 4`` for cubic interpolation). - :param subtract_self: If set to :py:obj:`True`, subtract from the features of an - atom the contributions to the potential arising from that atom itself (but not - the periodic images). - :param subtract_interior: If set to :py:obj:`True`, subtract from the features of an - atom the contributions to the potential arising from all atoms within the cutoff - Note that if set to true, the self contribution (see previous) is also - subtracted by default. - """ - +class _PMEPotentialImpl: def __init__( self, - all_types: Optional[List[int]] = None, - exponent: float = 1.0, - sr_cutoff: Optional[torch.Tensor] = None, - atomic_smearing: Optional[float] = None, - mesh_spacing: Optional[float] = None, - interpolation_order: Optional[int] = 3, - subtract_self: Optional[bool] = True, - subtract_interior: Optional[bool] = False, + exponent: float, + sr_cutoff: Union[None, torch.Tensor], + atomic_smearing: Union[None, float], + mesh_spacing: Union[None, float], + interpolation_order: Union[None, int], + subtract_self: Union[None, bool], + subtract_interior: Union[None, bool], ): # Check that all provided values are correct if exponent < 0.0 or exponent > 3.0: @@ -60,7 +25,6 @@ def __init__( raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") if atomic_smearing is not None and atomic_smearing <= 0: raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") - super().__init__(all_types=all_types, exponent=exponent) self.atomic_smearing = atomic_smearing self.mesh_spacing = mesh_spacing @@ -73,81 +37,16 @@ def __init__( self.subtract_self = subtract_self self.subtract_interior = subtract_interior - def compute( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute potential for all provided "systems" stacked inside list. - - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. - - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param cell: single or 2D tensor of shape (3, 3), describing the bounding - box/unit cell of the system. Each row should be one of the bounding box - vector; and columns should contain the x, y, and z components of these - vectors (i.e. the cell should be given in row-major order). - :param charges: Optional single or list of 2D tensor of shape (len(types), n), - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). - """ - - return self._compute_impl( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) + self.atomic_smearing = atomic_smearing + self.mesh_spacing = mesh_spacing + self.interpolation_order = interpolation_order + self.sr_cutoff = sr_cutoff - # This function is kept to keep MeshLODE compatible with the broader pytorch - # infrastructure, which require a "forward" function. We name this function - # "compute" instead, for compatibility with other COSMO software. - def forward( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Forward just calls :py:meth:`compute`.""" - return self.compute( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) + # If interior contributions are to be subtracted, also do so for self term + if subtract_interior: + subtract_self = True + self.subtract_self = subtract_self + self.subtract_interior = subtract_interior def _compute_single_system( self, @@ -314,3 +213,135 @@ def _compute_lr( interpolated_potential -= charges * self_contrib return interpolated_potential + + +class PMEPotential(CalculatorBaseTorch, _PMEPotentialImpl): + r"""Specie-wise long-range potential using a particle mesh-based Ewald (PME). + + Scaling as :math:`\mathcal{O}(NlogN)` with respect to the number of particles + :math:`N` used as a reference to test faster implementations. + + :param all_types: Optional global list of all atomic types that should be considered + for the computation. This option might be useful when running the calculation on + subset of a whole dataset and it required to keep the shape of the output + consistent. If this is not set the possible atomic types will be determined when + calling the :meth:`compute()`. + :param exponent: the exponent "p" in 1/r^p potentials + :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If + not set to a global value, it will be set to be half of the shortest lattice + vector defining the cell (separately for each structure). + :param atomic_smearing: Width of the atom-centered Gaussian used to split the + Coulomb potential into the short- and long-range parts. If not set to a global + value, it will be set to 1/5 times the sr_cutoff value (separately for each + structure) to ensure convergence of the short-range part to a relative precision + of 1e-5. + :param mesh_spacing: Value that determines the umber of Fourier-space grid points + that will be used along each axis. If set to None, it will automatically be set + to half of ``atomic_smearing``. + :param interpolation_order: Interpolation order for mapping onto the grid, where an + interpolation order of p corresponds to interpolation by a polynomial of degree + ``p - 1`` (e.g. ``p = 4`` for cubic interpolation). + :param subtract_self: If set to :py:obj:`True`, subtract from the features of an + atom the contributions to the potential arising from that atom itself (but not + the periodic images). + :param subtract_interior: If set to :py:obj:`True`, subtract from the features of an + atom the contributions to the potential arising from all atoms within the cutoff + Note that if set to true, the self contribution (see previous) is also + subtracted by default. + """ + + def __init__( + self, + all_types: Optional[List[int]] = None, + exponent: float = 1.0, + sr_cutoff: Optional[torch.Tensor] = None, + atomic_smearing: Optional[float] = None, + mesh_spacing: Optional[float] = None, + interpolation_order: Optional[int] = 3, + subtract_self: Optional[bool] = True, + subtract_interior: Optional[bool] = False, + ): + _PMEPotentialImpl.__init__( + self, + exponent=exponent, + sr_cutoff=sr_cutoff, + atomic_smearing=atomic_smearing, + mesh_spacing=mesh_spacing, + interpolation_order=interpolation_order, + subtract_self=subtract_self, + subtract_interior=subtract_interior, + ) + CalculatorBaseTorch.__init__(self, all_types=all_types, exponent=exponent) + + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: single or 2D tensor of shape (3, 3), describing the bounding + box/unit cell of the system. Each row should be one of the bounding box + vector; and columns should contain the x, y, and z components of these + vectors (i.e. the cell should be given in row-major order). + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Forward just calls :py:meth:`compute`.""" + return self.compute( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) diff --git a/src/meshlode/metatensor/__init__.py b/src/meshlode/metatensor/__init__.py index dec269fb..f52bf854 100644 --- a/src/meshlode/metatensor/__init__.py +++ b/src/meshlode/metatensor/__init__.py @@ -1,3 +1,3 @@ -from .pmepotential import PMEPotential +from .calculators import PMEPotential, EwaldPotential, DirectPotential -__all__ = ["MeshPotential", "PMEPotential"] +__all__ = ["DirectPotential", "EwaldPotential", "PMEPotential"] diff --git a/src/meshlode/metatensor/pmepotential.py b/src/meshlode/metatensor/calculators.py similarity index 59% rename from src/meshlode/metatensor/pmepotential.py rename to src/meshlode/metatensor/calculators.py index 72dd4f6c..353f0531 100644 --- a/src/meshlode/metatensor/pmepotential.py +++ b/src/meshlode/metatensor/calculators.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import torch @@ -12,24 +12,24 @@ "Try installing it with:\npip install metatensor[torch]" ) - -from .. import calculators +from ..calculators.base import CalculatorBase, _1d_tolist +from ..calculators.directpotential import _DirectPotentialImpl +from ..calculators.ewaldpotential import _EwaldPotentialImpl +from ..calculators.pmepotential import _PMEPotentialImpl # We are breaking the Liskov substitution principle here by changing the signature of -# "compute" compated to the supertype of "MeshPotential". +# "compute" method to the supertype of metatansor class. # mypy: disable-error-code="override" -class PMEPotential(calculators.PMEPotential): - """Specie-wise long-range potential using a particle mesh-based Ewald (PME). - - Refer to :class:`meshlode.MeshPotential` for full documentation. - """ +class CalculatorBaseMetatensor(CalculatorBase): + def __init__(self, exponent: float): + super().__init__(exponent) def forward(self, systems: Union[List[System], System]) -> TensorMap: - """forward just calls :py:meth:`compute()`""" - return self.compute(systems=systems) + """Forward just calls :py:meth:`compute`.""" + return self.compute(systems) def compute(self, systems: Union[List[System], System]) -> TensorMap: """Compute potential for all provided ``systems``. @@ -40,7 +40,7 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: ``systems``. If no "explicit" charges are set the potential will be calculated for each "types-channels". - Refer to :meth:`meshlode.MeshPotential.compute()` for additional details on how + Refer to :meth:`meshlode.PMEPotential.compute()` for additional details on how "charges-channel" and "types-channels" are computed. :param systems: single System or list of @@ -69,65 +69,46 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: f"{system.device} and {systems[0].device}`" ) - dtype = systems[0].positions.dtype device = systems[0].positions.device - requested_types = self._get_requested_types( - [system.types for system in systems] - ) - n_types = len(requested_types) + all_atomic_types = torch.hstack([system.types for system in systems]) + atomic_types = _1d_tolist(torch.unique(all_atomic_types)) + n_types = len(atomic_types) has_charges = torch.tensor(["charges" in s.known_data() for s in systems]) - all_charges = torch.all(has_charges) - any_charges = torch.any(has_charges) - if any_charges and not all_charges: + if not torch.all(has_charges): raise ValueError("`systems` do not consistently contain `charges` data") - if all_charges: - use_explicit_charges = True - n_charges_channels = systems[0].get_data("charges").values.shape[1] - spec_channels = list(range(n_charges_channels)) - key_names = ["center_type", "charges_channel"] - - for i_system, system in enumerate(systems): - n_channels = system.get_data("charges").values.shape[1] - if n_channels != n_charges_channels: - raise ValueError( - f"number of charges-channels in system index {i_system} " - f"({n_channels}) is inconsistent with first system " - f"({n_charges_channels})" - ) - else: - # Use one hot encoded type channel per species for charges channel - use_explicit_charges = False - n_charges_channels = n_types - spec_channels = requested_types - key_names = ["center_type", "neighbor_type"] + + n_charges_channels = systems[0].get_data("charges").values.shape[1] + spec_channels = list(range(n_charges_channels)) + key_names = ["center_type", "charges_channel"] + + for i_system, system in enumerate(systems): + n_channels = system.get_data("charges").values.shape[1] + if n_channels != n_charges_channels: + raise ValueError( + f"number of charges-channels in system index {i_system} " + f"({n_channels}) is inconsistent with first system " + f"({n_charges_channels})" + ) # Initialize dictionary for TensorBlock storage. # - # If `use_explicit_charges=False`, the blocks are sorted according to the - # (integer) center_type and neighbor_type. Blocks are assigned the array indices - # 0, 1, 2,... Example: for H2O: `H` is mapped to `0` and `O` is mapped to `1`. - # - # For `use_explicit_charges=True` the blocks are stored according to the - # center_type and charge_channel + # blocks are stored according to the `center_type` and `charge_channel` n_blocks = n_types * n_charges_channels feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_blocks)} for system in systems: - if use_explicit_charges: - charges = system.get_data("charges").values - else: - # One-hot encoding of charge information - charges = self._one_hot_charges( - system.types, requested_types, dtype, device - ) + charges = system.get_data("charges").values # try to extract neighbor list from system object neighbor_indices = None for neighbor_list_options in system.known_neighbor_lists(): - if neighbor_list_options.cutoff == self.sr_cutoff: + if ( + hasattr(self, "sr_cutoff") + and neighbor_list_options.cutoff == self.sr_cutoff + ): neighbor_list = system.get_neighbor_list(neighbor_list_options) neighbor_indices = neighbor_list.samples.values[:, :2] @@ -153,7 +134,7 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: ) # Reorder data into metatensor format - for spec_center, at_num_center in enumerate(requested_types): + for spec_center, at_num_center in enumerate(atomic_types): for spec_channel in range(len(spec_channels)): a_pair = spec_center * n_charges_channels + spec_channel feat_dic[a_pair] += [ @@ -164,7 +145,7 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: # of center_type and neighbor_type/charge_channel blocks: List[TensorBlock] = [] for keys, values in feat_dic.items(): - spec_center = requested_types[keys // n_charges_channels] + spec_center = atomic_types[keys // n_charges_channels] # Generate the Labels objects for the samples and properties of the # TensorBlock. @@ -202,7 +183,7 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: # Generate TensorMap from TensorBlocks by defining suitable keys key_values: List[torch.Tensor] = [] - for spec_center in requested_types: + for spec_center in atomic_types: for spec_channel in spec_channels: key_values.append( torch.tensor([spec_center, spec_channel], device=device) @@ -212,3 +193,70 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: labels_keys = Labels(key_names, key_values) return TensorMap(keys=labels_keys, blocks=blocks) + + +class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl): + """Specie-wise long-range potential using a direct summation over all atoms. + + Refer to :class:`meshlode.DirectPotential` for full documentation. + """ + + def __init__(self, exponent: float = 1.0): + self._DirectPotentialImpl.__init__(self, exponent=exponent) + CalculatorBaseMetatensor.__init__(self, exponent=exponent) + + +class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl): + """Specie-wise long-range potential computed using the Ewald sum. + + Refer to :class:`meshlode.EwaldPotential` for full documentation. + """ + + def __init__( + self, + exponent: float = 1.0, + sr_cutoff: Optional[torch.Tensor] = None, + atomic_smearing: Optional[float] = None, + lr_wavelength: Optional[float] = None, + subtract_self: Optional[bool] = True, + subtract_interior: Optional[bool] = False, + ): + _EwaldPotentialImpl.__init__( + self, + exponent=exponent, + sr_cutoff=sr_cutoff, + atomic_smearing=atomic_smearing, + lr_wavelength=lr_wavelength, + subtract_self=subtract_self, + subtract_interior=subtract_interior, + ) + CalculatorBaseMetatensor.__init__(self, exponent=exponent) + + +class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): + """Specie-wise long-range potential using a particle mesh-based Ewald (PME). + + Refer to :class:`meshlode.PMEPotential` for full documentation. + """ + + def __init__( + self, + exponent: float = 1.0, + sr_cutoff: Optional[torch.Tensor] = None, + atomic_smearing: Optional[float] = None, + mesh_spacing: Optional[float] = None, + interpolation_order: Optional[int] = 3, + subtract_self: Optional[bool] = True, + subtract_interior: Optional[bool] = False, + ): + _PMEPotentialImpl.__init__( + self, + exponent=exponent, + sr_cutoff=sr_cutoff, + atomic_smearing=atomic_smearing, + mesh_spacing=mesh_spacing, + interpolation_order=interpolation_order, + subtract_self=subtract_self, + subtract_interior=subtract_interior, + ) + CalculatorBaseMetatensor.__init__(self, exponent=exponent) diff --git a/tests/calculators/test_calculator_base.py b/tests/calculators/test_calculator_base.py index 0de08d09..574d1551 100644 --- a/tests/calculators/test_calculator_base.py +++ b/tests/calculators/test_calculator_base.py @@ -1,10 +1,10 @@ import pytest import torch -from meshlode.calculators.base import CalculatorBase +from meshlode.calculators.base import CalculatorBaseTorch -class TestCalculator(CalculatorBase): +class TestCalculator(CalculatorBaseTorch): def compute( self, types, positions, cell, charges, neighbor_indices, neighbor_shifts ): diff --git a/tests/calculators/test_values_periodic.py b/tests/calculators/test_values_periodic.py index 2b941201..5ac355c4 100644 --- a/tests/calculators/test_values_periodic.py +++ b/tests/calculators/test_values_periodic.py @@ -408,11 +408,11 @@ def test_random_structure(sr_cutoff, frame_index, scaling_factor, ortho, calc_na frame = read(os.path.join(struc_path, "coulomb_test_frames.xyz"), frame_index) # Energies in Gaussian units (without e²/[4 π ɛ_0] prefactor) - energy_target = torch.tensor(frame.info["energy"], dtype=dtype) / scaling_factor - # Forces in Gaussian units per Å - forces_target = ( - torch.tensor(frame.arrays["forces"], dtype=dtype) / scaling_factor**2 + energy_target = ( + torch.tensor(frame.get_potential_energy(), dtype=dtype) / scaling_factor ) + # Forces in Gaussian units per Å + forces_target = torch.tensor(frame.get_forces(), dtype=dtype) / scaling_factor**2 # Convert into input format suitable for MeshLODE positions = scaling_factor * (torch.tensor(frame.positions, dtype=dtype) @ ortho) From 0d63a0e8fedcffa920f52a748b4db1244899386b Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Tue, 9 Jul 2024 20:12:01 +0200 Subject: [PATCH 28/35] Remove types from input parameters --- src/meshlode/calculators/base.py | 290 +++++++-------- src/meshlode/calculators/directpotential.py | 10 +- src/meshlode/calculators/ewaldpotential.py | 16 +- src/meshlode/calculators/pmepotential.py | 9 +- src/meshlode/lib/potentials.py | 4 +- src/meshlode/metatensor/calculators.py | 11 +- tests/calculators/test_calculator_base.py | 337 +++++++++++++----- .../calculators/test_calculators_workflow.py | 75 ++-- tests/calculators/test_values_aperiodic.py | 4 +- tests/calculators/test_values_periodic.py | 11 +- tests/lib/test_potentials.py | 4 +- 11 files changed, 427 insertions(+), 344 deletions(-) diff --git a/src/meshlode/calculators/base.py b/src/meshlode/calculators/base.py index 03098e52..e3eee578 100644 --- a/src/meshlode/calculators/base.py +++ b/src/meshlode/calculators/base.py @@ -7,28 +7,15 @@ from meshlode.lib import InversePowerLawPotential -@torch.jit.script -def _1d_tolist(x: torch.Tensor) -> List[int]: - """Auxilary function to convert 1d torch tensor to list of integers.""" - result: List[int] = [] - for i in x: - result.append(i.item()) - return result - - -@torch.jit.script -def _is_subset(subset_candidate: List[int], superset: List[int]) -> bool: - """Checks whether all elements of `subset_candidate` are part of `superset`.""" - for element in subset_candidate: - if element not in superset: - return False - return True - - class CalculatorBase(torch.nn.Module): """Base class providing general funtionality.""" - def __init__(self, exponent): + def __init__( + self, + exponent: float, + ): + # Attach the function handling all computations related to the + # power-law potential for later convenience self.exponent = exponent self.potential = InversePowerLawPotential(exponent=exponent) @@ -113,252 +100,227 @@ class CalculatorBaseTorch(CalculatorBase): """ Base calculator for the torch interface to MeshLODE. - :param all_types: Optional global list of all atomic types that should be considered - for the computation. This option might be useful when running the calculation on - subset of a whole dataset and it required to keep the shape of the output - consistent. If this is not set the possible atomic types will be determined when - calling the :meth:`compute()`. :param exponent: the exponent "p" in 1/r^p potentials """ - def __init__(self, all_types: Union[None, List[int]], exponent: float): - super().__init__(exponent) - - if all_types is None: - self.all_types = None - else: - self.all_types = _1d_tolist(torch.unique(torch.tensor(all_types))) - - def _get_requested_types(self, types: List[torch.Tensor]) -> List[int]: - """Extract a list of all unique and present types from the list of types.""" - all_types = torch.hstack(types) - types_requested = _1d_tolist(torch.unique(all_types)) - - if self.all_types is not None: - if not _is_subset(types_requested, self.all_types): - raise ValueError( - f"Global list of types {self.all_types} does not contain all " - f"types for the provided systems {types_requested}." - ) - return self.all_types - else: - return types_requested - - def _one_hot_charges( + def __init__( self, - types: torch.Tensor, - requested_types: List[int], - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - ) -> torch.Tensor: - n_types = len(requested_types) - one_hot_charges = torch.zeros((len(types), n_types), dtype=dtype, device=device) - - for i_type, atomic_type in enumerate(requested_types): - one_hot_charges[types == atomic_type, i_type] = 1.0 - - return one_hot_charges + exponent: float, + ): + super().__init__(exponent=exponent) def _validate_compute_parameters( self, - types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[None, List[torch.Tensor], torch.Tensor], - charges: Union[None, List[torch.Tensor], torch.Tensor], + charges: Union[List[torch.Tensor], torch.Tensor], neighbor_indices: Union[None, List[torch.Tensor], torch.Tensor], neighbor_shifts: Union[None, List[torch.Tensor], torch.Tensor], ) -> Tuple[ - List[torch.Tensor], List[torch.Tensor], Union[List[None], List[torch.Tensor]], List[torch.Tensor], Union[List[None], List[torch.Tensor]], Union[List[None], List[torch.Tensor]], ]: - # validate types and positions - if not isinstance(types, list): - types = [types] + # make sure that the provided positions are a list if not isinstance(positions, list): positions = [positions] - if len(types) != len(positions): - raise ValueError( - f"Got inconsistent lengths of types ({len(types)}) " - f"positions ({len(positions)})" - ) + # In actual computations, the data type (dtype) and device (e.g. CPU, GPU) of + # all remaining variables need to be consistent + self.device = positions[0].device + self.dtype = positions[0].dtype + # make sure that provided cells are a list of same length as positions if cell is None: - cell = len(types) * [None] + cell = len(positions) * [None] elif not isinstance(cell, list): cell = [cell] - if len(types) != len(cell): + if len(positions) != len(cell): raise ValueError( - f"Got inconsistent lengths of types ({len(types)}) and " + f"Got inconsistent numbers of positions ({len(positions)}) and " f"cell ({len(cell)})" ) + # make sure that provided charges are a list of same length as positions + if not isinstance(charges, list): + charges = [charges] + + if len(positions) != len(charges): + raise ValueError( + f"Got inconsistent numbers of positions ({len(positions)}) and " + f"charges ({len(charges)})" + ) + + # check neighbor_indices if neighbor_indices is None: - neighbor_indices = len(types) * [None] + neighbor_indices = len(positions) * [None] elif not isinstance(neighbor_indices, list): neighbor_indices = [neighbor_indices] - if len(types) != len(neighbor_indices): + if len(positions) != len(neighbor_indices): raise ValueError( - f"Got inconsistent lengths of types ({len(types)}) and " + f"Got inconsistent numbers of positions ({len(positions)}) and " f"neighbor_indices ({len(neighbor_indices)})" ) + # check neighbor_shifts if neighbor_shifts is None: - neighbor_shifts = len(types) * [None] + neighbor_shifts = len(positions) * [None] elif not isinstance(neighbor_shifts, list): neighbor_shifts = [neighbor_shifts] - if len(types) != len(neighbor_shifts): + if len(positions) != len(neighbor_shifts): raise ValueError( - f"Got inconsistent lengths of types ({len(types)}) and " - f"neighbor_indices ({len(neighbor_shifts)})" + f"Got inconsistent numbers of positions ({len(positions)}) and " + f"neighbor_shifts ({len(neighbor_shifts)})" ) - # Check that all inputs are consistent. We don't require and test that all - # dtypes and devices are consistent if a list of inputs. Each single "frame" is - # processed independently. + # check that all devices and data types (dtypes) are consistent for ( - types_single, positions_single, cell_single, + charges_single, neighbor_indices_single, neighbor_shifts_single, - ) in zip(types, positions, cell, neighbor_indices, neighbor_shifts): - if len(types_single.shape) != 1: + ) in zip(positions, cell, charges, neighbor_indices, neighbor_shifts): + # check shape, dtype and device of positions + num_atoms = len(positions_single) + if list(positions_single.shape) != [num_atoms, 3]: raise ValueError( - "each `types` must be a 1 dimensional tensor, got at least " - f"one tensor with {len(types_single.shape)} dimensions" + "each `positions` must be a (n_atoms x 3) tensor, got at least " + f"one tensor with shape {tuple(positions_single.shape)}" ) - if positions_single.shape != (len(types_single), 3): + if positions_single.dtype != self.dtype: raise ValueError( - "each `positions` must be a (n_types x 3) tensor, got at least " - f"one tensor with shape {list(positions_single.shape)}" + f"each `positions` must have the same type {self.dtype} as the " + "first provided one. Got at least one tensor of type " + f"{positions_single.dtype}" ) - if types_single.device != positions_single.device: + if positions_single.device != self.device: raise ValueError( - f"Inconsistent devices of types ({types_single.device}) and " - f"positions ({positions_single.device})" + f"each `positions` must be on the same device {self.device} as the " + "first provided one. Got at least one tensor on device " + f"{positions_single.device}" ) + # check shape, dtype and device of cell if cell_single is not None: - if cell_single.shape != (3, 3): + if list(cell_single.shape) != [3, 3]: raise ValueError( - "each `cell` must be a (3 x 3) tensor, got at least " - f"one tensor with shape {list(cell_single.shape)}" + f"each `cell` must be a (3 x 3) tensor, got at least one tensor " + f"with shape {tuple(cell_single.shape)}" ) - if cell_single.dtype != positions_single.dtype: + if cell_single.dtype != self.dtype: raise ValueError( - "`cell` must be have the same dtype as `positions`, got " - f"{cell_single.dtype} and {positions_single.dtype}" + f"each `cell` must have the same type {self.dtype} as positions, " + f"got at least one tensor of type {cell_single.dtype}" ) - if types_single.device != cell_single.device: + if cell_single.device != self.device: raise ValueError( - f"Inconsistent devices of types ({types_single.device}) and " - f"cell ({cell_single.device})" + f"each `cell` must be on the same device {self.device} as positions, " + f"got at least one tensor with device {cell_single.device}" ) + # check shape, dtype and device of charges + if charges_single.dim() != 2: + raise ValueError( + f"each `charges` needs to be a 2-dimensional tensor, got at least " + f"one tensor with {charges_single.dim()} dimension(s) and shape " + f"{tuple(charges_single.shape)}" + ) + + if list(charges_single.shape) != [num_atoms, charges_single.shape[1]]: + raise ValueError( + f"each `charges` must be a (n_atoms x n_channels) tensor, with" + f"`n_atoms` being the same as the variable `positions`. Got at " + f"least one tensor with shape {tuple(charges_single.shape)} where " + f"positions contains {len(positions_single)} atoms" + ) + + if charges_single.dtype != self.dtype: + raise ValueError( + f"each `charges` must have the same type {self.dtype} as positions, " + f"got at least one tensor of type {charges_single.dtype}" + ) + + if charges_single.device != self.device: + raise ValueError( + f"each `charges` must be on the same device {self.device} as positions, " + f"got at least one tensor with device {charges_single.device}" + ) + + # check shape, dtype and device of neighbor_indices and neighbor_shifts if neighbor_indices_single is not None: - if neighbor_indices_single.shape != (2, len(types_single)): + if neighbor_shifts_single is None: raise ValueError( - "Expected shape of neighbor_indices is " - f"{2, len(types_single)}, but got " - f"{list(neighbor_indices_single.shape)}" + "Need to provide both neighbor_indices and neighbor_shifts together." ) - if types_single.device != neighbor_indices_single.device: + if neighbor_indices_single.shape[0] != 2: raise ValueError( - f"Inconsistent devices of types ({types_single.device}) and " - f"neighbor_indices ({neighbor_indices_single.device})" + "neighbor_indices is expected to have shape (2, num_neighbors)" + f", but got {tuple(neighbor_indices_single.shape)} for one structure" ) - if neighbor_shifts_single is not None: - if neighbor_shifts_single.shape != (3, len(types_single)): + if neighbor_shifts_single.shape[1] != 3: raise ValueError( - "Expected shape of neighbor_shifts is " - f"{3, len(types_single)}, but got " - f"{list(neighbor_shifts_single.shape)}" + "neighbor_shifts is expected to have shape (num_neighbors, 3)" + f", but got {tuple(neighbor_shifts_single.shape)} for one structure" ) - if types_single.device != neighbor_shifts_single.device: + if neighbor_shifts_single.shape[0] != neighbor_indices_single.shape[1]: raise ValueError( - f"Inconsistent devices of types ({types_single.device}) and " - f"neighbor_shifts_single ({neighbor_shifts_single.device})" + f"`neighbor_indices` and `neighbor_shifts` need to have shapes " + f"(2, num_neighbors) and (num_neighbors, 3). For at least one" + f"structure, got {tuple(neighbor_indices_single.shape)} and " + f"{tuple(neighbor_shifts_single.shape)}, which is inconsistent" ) - # If charges are not provided, we assume that all types are treated separately - if charges is None: - charges = [] - for types_single, positions_single in zip(types, positions): - # One-hot encoding of charge information - charges_single = self._one_hot_charges( - types=types_single, - requested_types=self._get_requested_types(types), - dtype=positions_single.dtype, - device=positions_single.device, - ) - charges.append(charges_single) + if neighbor_indices_single.device != self.device: + raise ValueError( + f"each `neighbor_indices` must be on the same device {self.device} as positions, " + f"got at least one tensor with device {neighbor_indices_single.device}" + ) - # If charges are provided, we need to make sure that they are consistent with - # the provided types - else: - if not isinstance(charges, list): - charges = [charges] - if len(charges) != len(types): - raise ValueError( - "The number of `types` and `charges` tensors must be the same, " - f"got {len(types)} and {len(charges)}." - ) - for charges_single, types_single in zip(charges, types): - if charges_single.shape[0] != len(types_single): + if neighbor_shifts_single.device != self.device: raise ValueError( - "The first dimension of `charges` must be the same as the " - f"length of `types`, got {charges_single.shape[0]} and " - f"{len(types_single)}." + f"each `neighbor_shifts` must be on the same device {self.device} as positions, " + f"got at least one tensor with device {neighbor_shifts_single.device}" ) - if charges[0].dtype != positions[0].dtype: - raise ValueError( - "`charges` must be have the same dtype as `positions`, got " - f"{charges[0].dtype} and {positions[0].dtype}." - ) - if charges[0].device != positions[0].device: - raise ValueError( - "`charges` must be on the same device as `positions`, got " - f"{charges[0].device} and {positions[0].device}." - ) - return types, positions, cell, charges, neighbor_indices, neighbor_shifts + return positions, cell, charges, neighbor_indices, neighbor_shifts def _compute_impl( self, - types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[None, List[torch.Tensor], torch.Tensor], - charges: Union[None, Union[List[torch.Tensor], torch.Tensor]], + charges: Union[Union[List[torch.Tensor], torch.Tensor]], neighbor_indices: Union[None, List[torch.Tensor], torch.Tensor], neighbor_shifts: Union[None, List[torch.Tensor], torch.Tensor], ) -> Union[torch.Tensor, List[torch.Tensor]]: + # Check that all shapes, data types and devices are consistent + # Furthermore, to handle the special case in which only the inputs for a single + # structure are provided, turn inputs into a list to be consistent with the + # more general case ( - types, positions, cell, charges, neighbor_indices, neighbor_shifts, ) = self._validate_compute_parameters( - types, positions, cell, charges, neighbor_indices, neighbor_shifts + positions, cell, charges, neighbor_indices, neighbor_shifts ) - potentials = [] + # compute and append into a list the features of each structure + potentials = [] for ( positions_single, cell_single, @@ -377,7 +339,9 @@ def _compute_impl( ) ) - if len(types) == 1: + # if only a single structure if provided as input, we directly return a single + # tensor containing its features rather than a list of tensors + if len(positions) == 1: return potentials[0] else: return potentials diff --git a/src/meshlode/calculators/directpotential.py b/src/meshlode/calculators/directpotential.py index a0c79555..a7309241 100644 --- a/src/meshlode/calculators/directpotential.py +++ b/src/meshlode/calculators/directpotential.py @@ -60,13 +60,12 @@ class DirectPotential(CalculatorBaseTorch, _DirectPotentialImpl): :param exponent: the exponent "p" in 1/r^p potentials """ - def __init__(self, all_types: Optional[List[int]] = None, exponent: float = 1.0): + def __init__(self, exponent: float = 1.0): _DirectPotentialImpl.__init__(self, exponent=exponent) - CalculatorBaseTorch.__init__(self, all_types=all_types, exponent=exponent) + CalculatorBaseTorch.__init__(self, exponent=exponent) def compute( self, - types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -75,8 +74,6 @@ def compute( The computation is performed on the same ``device`` as ``systems`` is stored on. The ``dtype`` of the output tensors will be the same as the input. - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. :param positions: single or 2D tensor of shape (len(types), 3) containing the Cartesian positions of all particles in the system. :param charges: Optional single or list of 2D tensor of shape (len(types), n), @@ -97,7 +94,6 @@ def compute( """ return self._compute_impl( - types=types, positions=positions, cell=None, charges=charges, @@ -110,13 +106,11 @@ def compute( # "compute" instead, for compatibility with other COSMO software. def forward( self, - types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Forward just calls :py:meth:`compute`.""" return self.compute( - types=types, positions=positions, charges=charges, ) diff --git a/src/meshlode/calculators/ewaldpotential.py b/src/meshlode/calculators/ewaldpotential.py index a52d3315..7ca01b39 100644 --- a/src/meshlode/calculators/ewaldpotential.py +++ b/src/meshlode/calculators/ewaldpotential.py @@ -207,7 +207,7 @@ def _compute_lr( # TODO: modify to expression for general p if subtract_self: self_contrib = ( - torch.sqrt(torch.tensor(2.0 / torch.pi, device=cell.device)) / smearing + torch.sqrt(torch.tensor(2.0 / torch.pi, device=self.device)) / smearing ) energy -= charges * self_contrib @@ -220,11 +220,6 @@ class EwaldPotential(CalculatorBaseTorch, _EwaldPotentialImpl): Scaling as :math:`\mathcal{O}(N^2)` with respect to the number of particles :math:`N`. - :param all_types: Optional global list of all atomic types that should be considered - for the computation. This option might be useful when running the calculation on - subset of a whole dataset and it required to keep the shape of the output - consistent. If this is not set the possible atomic types will be determined when - calling the :meth:`compute()`. :param exponent: the exponent "p" in 1/r^p potentials :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If not set to a global value, it will be set to be half of the shortest lattice @@ -268,7 +263,6 @@ class EwaldPotential(CalculatorBaseTorch, _EwaldPotentialImpl): def __init__( self, - all_types: Optional[List[int]] = None, exponent: float = 1.0, sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, @@ -285,11 +279,10 @@ def __init__( subtract_self=subtract_self, subtract_interior=subtract_interior, ) - CalculatorBaseTorch.__init__(self, all_types=all_types, exponent=exponent) + CalculatorBaseTorch.__init__(self, exponent=exponent) def compute( self, - types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, @@ -301,8 +294,6 @@ def compute( The computation is performed on the same ``device`` as ``systems`` is stored on. The ``dtype`` of the output tensors will be the same as the input. - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. :param positions: single or 2D tensor of shape (len(types), 3) containing the Cartesian positions of all particles in the system. :param cell: single or 2D tensor of shape (3, 3), describing the bounding @@ -333,7 +324,6 @@ def compute( """ return self._compute_impl( - types=types, positions=positions, cell=cell, charges=charges, @@ -346,7 +336,6 @@ def compute( # "compute" instead, for compatibility with other COSMO software. def forward( self, - types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, @@ -355,7 +344,6 @@ def forward( ) -> Union[torch.Tensor, List[torch.Tensor]]: """Forward just calls :py:meth:`compute`.""" return self.compute( - types=types, positions=positions, cell=cell, charges=charges, diff --git a/src/meshlode/calculators/pmepotential.py b/src/meshlode/calculators/pmepotential.py index 1c3d11ea..c19539ab 100644 --- a/src/meshlode/calculators/pmepotential.py +++ b/src/meshlode/calculators/pmepotential.py @@ -252,7 +252,6 @@ class PMEPotential(CalculatorBaseTorch, _PMEPotentialImpl): def __init__( self, - all_types: Optional[List[int]] = None, exponent: float = 1.0, sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, @@ -271,11 +270,10 @@ def __init__( subtract_self=subtract_self, subtract_interior=subtract_interior, ) - CalculatorBaseTorch.__init__(self, all_types=all_types, exponent=exponent) + CalculatorBaseTorch.__init__(self, exponent=exponent) def compute( self, - types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, @@ -287,8 +285,6 @@ def compute( The computation is performed on the same ``device`` as ``systems`` is stored on. The ``dtype`` of the output tensors will be the same as the input. - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. :param positions: single or 2D tensor of shape (len(types), 3) containing the Cartesian positions of all particles in the system. :param cell: single or 2D tensor of shape (3, 3), describing the bounding @@ -319,7 +315,6 @@ def compute( """ return self._compute_impl( - types=types, positions=positions, cell=cell, charges=charges, @@ -329,7 +324,6 @@ def compute( def forward( self, - types: Union[List[torch.Tensor], torch.Tensor], positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, @@ -338,7 +332,6 @@ def forward( ) -> Union[torch.Tensor, List[torch.Tensor]]: """Forward just calls :py:meth:`compute`.""" return self.compute( - types=types, positions=positions, cell=cell, charges=charges, diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py index 83789261..be0dcf95 100644 --- a/src/meshlode/lib/potentials.py +++ b/src/meshlode/lib/potentials.py @@ -121,7 +121,9 @@ def potential_fourier_from_k_sq( smearing parameter corresponds to the "width" of the Gaussian. """ peff = (3 - self.exponent) / 2 - prefac = (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff + prefac = ( + (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff + ) x = 0.5 * smearing**2 * k_sq fourier = prefac * gammaincc(peff, x) / x**peff * gamma(peff) diff --git a/src/meshlode/metatensor/calculators.py b/src/meshlode/metatensor/calculators.py index 353f0531..a35cd282 100644 --- a/src/meshlode/metatensor/calculators.py +++ b/src/meshlode/metatensor/calculators.py @@ -12,11 +12,20 @@ "Try installing it with:\npip install metatensor[torch]" ) -from ..calculators.base import CalculatorBase, _1d_tolist + +from ..calculators.base import CalculatorBase from ..calculators.directpotential import _DirectPotentialImpl from ..calculators.ewaldpotential import _EwaldPotentialImpl from ..calculators.pmepotential import _PMEPotentialImpl +@torch.jit.script +def _1d_tolist(x: torch.Tensor) -> List[int]: + """Auxilary function to convert 1d torch tensor to list of integers.""" + result: List[int] = [] + for i in x: + result.append(i.item()) + return result + # We are breaking the Liskov substitution principle here by changing the signature of # "compute" method to the supertype of metatansor class. diff --git a/tests/calculators/test_calculator_base.py b/tests/calculators/test_calculator_base.py index 574d1551..644bd03b 100644 --- a/tests/calculators/test_calculator_base.py +++ b/tests/calculators/test_calculator_base.py @@ -4,12 +4,20 @@ from meshlode.calculators.base import CalculatorBaseTorch +# Define some example parameters +dtype = torch.float32 +device = "cpu" +charges_1 = torch.ones((4, 1), dtype=dtype, device=device) +positions_1 = 0.3 * torch.arange(12, dtype=dtype, device=device).reshape((4, 3)) +charges_2 = torch.ones((5, 3), dtype=dtype, device=device) +positions_2 = 0.7 * torch.arange(15, dtype=dtype, device=device).reshape((5, 3)) +cell_1 = torch.eye(3, dtype=dtype, device=device) +cell_2 = torch.arange(9, dtype=dtype, device=device).reshape((3, 3)) + + class TestCalculator(CalculatorBaseTorch): - def compute( - self, types, positions, cell, charges, neighbor_indices, neighbor_shifts - ): + def compute(self, positions, cell, charges, neighbor_indices, neighbor_shifts): return self._compute_impl( - types=types, positions=positions, cell=cell, charges=charges, @@ -17,11 +25,8 @@ def compute( neighbor_shifts=neighbor_shifts, ) - def forward( - self, types, positions, cell, charges, neighbor_indices, neighbor_shifts - ): + def forward(self, positions, cell, charges, neighbor_indices, neighbor_shifts): return self._compute_impl( - types=types, positions=positions, cell=cell, charges=charges, @@ -37,23 +42,21 @@ def _compute_single_system( @pytest.mark.parametrize("method_name", ["compute", "forward"]) @pytest.mark.parametrize( - "types, positions, charges", + "positions, charges", [ - (torch.arange(2), torch.ones([2, 3]), torch.ones(2)), - ([torch.arange(2)], [torch.ones([2, 3])], [torch.ones(2)]), + (torch.ones([2, 3]), torch.ones(2).reshape((-1, 1))), + ([torch.ones([2, 3])], [torch.ones(2).reshape((-1, 1))]), ( - [torch.arange(2), torch.arange(4)], [torch.ones([2, 3]), torch.ones([4, 3])], - [torch.ones(2), torch.ones(4)], + [torch.ones(2).reshape((-1, 1)), torch.ones(4).reshape((-1, 1))], ), ], ) -def test_compute(method_name, types, positions, charges): - calculator = TestCalculator(all_types=None, exponent=1.0) +def test_compute_output_shapes(method_name, positions, charges): + calculator = TestCalculator(exponent=1.0) method = getattr(calculator, method_name) result = method( - types=types, positions=positions, cell=None, charges=charges, @@ -69,156 +72,320 @@ def test_compute(method_name, types, positions, charges): assert result.shape == charges.shape -def test_mismatched_lengths_types_positions(): - calculator = TestCalculator(all_types=None, exponent=1.0) - match = r"inconsistent lengths of types \(\d+\) positions \(\d+\)" +# Tests for a mismatch in the number of provided inputs for different variables +def test_mismatched_numbers_cell(): + calculator = TestCalculator(exponent=1.0) + match = r"Got inconsistent numbers of positions \(2\) and cell \(3\)" with pytest.raises(ValueError, match=match): calculator.compute( - types=torch.arange(2), - positions=[torch.ones([2, 3]), torch.ones([3, 3])], + positions=[positions_1, positions_2], + cell=[cell_1, cell_2, torch.eye(3)], + charges=[charges_1, charges_2], + neighbor_indices=None, + neighbor_shifts=None, + ) + + +def test_mismatched_numbers_charges(): + calculator = TestCalculator(exponent=1.0) + match = r"Got inconsistent numbers of positions \(2\) and charges \(3\)" + with pytest.raises(ValueError, match=match): + calculator.compute( + positions=[positions_1, positions_2], cell=None, - charges=None, + charges=[charges_1, charges_2, charges_2], neighbor_indices=None, neighbor_shifts=None, ) +def test_mismatched_numbers_neighbor_indices(): + calculator = TestCalculator(exponent=1.0) + match = r"Got inconsistent numbers of positions \(2\) and neighbor_indices \(3\)" + with pytest.raises(ValueError, match=match): + calculator.compute( + positions=[positions_1, positions_2], + cell=None, + charges=[charges_1, charges_2], + neighbor_indices=[charges_1, charges_2, positions_1], + neighbor_shifts=None, + ) + + +def test_mismatched_numbers_neighbor_shiftss(): + calculator = TestCalculator(exponent=1.0) + match = r"Got inconsistent numbers of positions \(2\) and neighbor_shifts \(3\)" + with pytest.raises(ValueError, match=match): + calculator.compute( + positions=[positions_1, positions_2], + cell=None, + charges=[charges_1, charges_2], + neighbor_indices=None, + neighbor_shifts=[charges_1, charges_2, positions_1], + ) + + +# Tests for invalid shape, dtype and device of positions def test_invalid_shape_positions(): - calculator = TestCalculator(all_types=None, exponent=1.0) + calculator = TestCalculator(exponent=1.0) match = ( - r"each `positions` must be a \(n_types x 3\) tensor, got at least one tensor " - r"with shape \[3, 3\]" + r"each `positions` must be a \(n_atoms x 3\) tensor, got at least " + r"one tensor with shape \(4, 5\)" ) with pytest.raises(ValueError, match=match): calculator.compute( - types=torch.arange(2), - positions=torch.ones([3, 3]), + positions=torch.ones((4, 5), dtype=dtype, device=device), cell=None, - charges=None, + charges=charges_1, neighbor_indices=None, neighbor_shifts=None, ) -def test_mismatched_lengths_types_cell(): - calculator = TestCalculator(all_types=None, exponent=1.0) - match = r"inconsistent lengths of types \(\d+\) and cell \(\d+\)" +def test_invalid_dtype_positions(): + calculator = TestCalculator(exponent=1.0) + match = ( + r"each `positions` must have the same type torch.float32 as the " + r"first provided one. Got at least one tensor of type " + r"torch.float64" + ) + positions_2_wrong_dtype = torch.ones((5, 3), dtype=torch.float64, device=device) with pytest.raises(ValueError, match=match): calculator.compute( - types=torch.arange(2), - positions=torch.ones([2, 3]), - cell=[torch.ones([3, 3]), torch.ones([3, 3])], - charges=None, + positions=[positions_1, positions_2_wrong_dtype], + cell=None, + charges=[charges_1, charges_2], neighbor_indices=None, neighbor_shifts=None, ) -def test_inconsistent_devices(): - calculator = TestCalculator(all_types=None, exponent=1.0) - match = r"Inconsistent devices of types \([a-zA-Z:]+\) and positions \([a-zA-Z:]+\)" +def test_invalid_device_positions(): + calculator = TestCalculator(exponent=1.0) + match = ( + r"each `positions` must be on the same device cpu as the " + r"first provided one. Got at least one tensor on device " + r"meta" + ) + positions_2_wrong_device = torch.ones((5, 3), dtype=dtype, device="meta") with pytest.raises(ValueError, match=match): calculator.compute( - types=torch.arange(2, device="meta"), - positions=torch.ones([2, 3], device="cpu"), + positions=[positions_1, positions_2_wrong_device], cell=None, - charges=None, + charges=[charges_1, charges_2], neighbor_indices=None, neighbor_shifts=None, ) -def test_inconsistent_dtypes_cell(): - calculator = TestCalculator(all_types=None, exponent=1.0) +# Tests for invalid shape, dtype and device of cell +def test_invalid_shape_cell(): + calculator = TestCalculator(exponent=1.0) match = ( - r"`cell` must be have the same dtype as `positions`, got " - r"torch.float32 and torch.float64" + r"each `cell` must be a \(3 x 3\) tensor, got at least one tensor with " + r"shape \(2, 2\)" ) with pytest.raises(ValueError, match=match): calculator.compute( - types=torch.arange(2), - positions=torch.ones([2, 3], dtype=torch.float64), - cell=torch.ones([3, 3], dtype=torch.float32), - charges=None, + positions=positions_1, + cell=torch.ones([2, 2], dtype=dtype, device=device), + charges=charges_1, neighbor_indices=None, neighbor_shifts=None, ) -def test_inconsistent_dtypes_charges(): - calculator = TestCalculator(all_types=None, exponent=1.0) +def test_invalid_dtype_cell(): + calculator = TestCalculator(exponent=1.0) match = ( - r"`charges` must be have the same dtype as `positions`, got " - r"torch.float32 and torch.float64" + r"each `cell` must have the same type torch.float32 as positions, " + r"got at least one tensor of type torch.float64" ) with pytest.raises(ValueError, match=match): calculator.compute( - types=torch.arange(2), - positions=torch.ones([2, 3], dtype=torch.float64), + positions=positions_1, + cell=torch.ones([3, 3], dtype=torch.float64, device=device), + charges=charges_1, + neighbor_indices=None, + neighbor_shifts=None, + ) + + +def test_invalid_device_cell(): + calculator = TestCalculator(exponent=1.0) + match = ( + r"each `cell` must be on the same device cpu as positions, " + r"got at least one tensor with device meta" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + positions=positions_1, + cell=torch.ones([3, 3], dtype=dtype, device="meta"), + charges=charges_1, + neighbor_indices=None, + neighbor_shifts=None, + ) + + +# Tests for invalid shape, dtype and device of charges +def test_invalid_dim_charges(): + calculator = TestCalculator(exponent=1.0) + match = ( + r"each `charges` needs to be a 2-dimensional tensor, got at least " + r"one tensor with 1 dimension\(s\) and shape " + r"\(4,\)" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + positions=positions_1, cell=None, - charges=torch.ones([2], dtype=torch.float32), + charges=torch.ones(len(positions_1), dtype=dtype, device=device), neighbor_indices=None, neighbor_shifts=None, ) -def test_mismatched_lengths_types_charges(): - calculator = TestCalculator(all_types=None, exponent=1.0) +def test_invalid_shape_charges(): + calculator = TestCalculator(exponent=1.0) match = ( - r"The first dimension of `charges` must be the same as the length of `types`, " - r"got \d+ and \d+" + r"each `charges` must be a \(n_atoms x n_channels\) tensor, with" + r"`n_atoms` being the same as the variable `positions`. Got at " + r"least one tensor with shape \(6, 2\) where " + r"positions contains 4 atoms" ) with pytest.raises(ValueError, match=match): calculator.compute( - types=torch.arange(2), - positions=torch.ones([2, 3]), + positions=positions_1, cell=None, - charges=torch.ones([3]), + charges=torch.ones((6, 2), dtype=dtype, device=device), neighbor_indices=None, neighbor_shifts=None, ) -def test_invalid_shape_cell(): - calculator = TestCalculator(all_types=None, exponent=1.0) +def test_invalid_dtype_charges(): + calculator = TestCalculator(exponent=1.0) match = ( - r"each `cell` must be a \(3 x 3\) tensor, got at least one tensor with " - r"shape \[2, 2\]" + r"each `charges` must have the same type torch.float32 as positions, " + r"got at least one tensor of type torch.float64" ) with pytest.raises(ValueError, match=match): calculator.compute( - types=torch.arange(2), - positions=torch.ones([2, 3]), - cell=torch.ones([2, 2]), - charges=None, + positions=positions_1, + cell=None, + charges=torch.ones((4, 2), dtype=torch.float64, device=device), neighbor_indices=None, neighbor_shifts=None, ) -def test_invalid_shape_neighbor_indices(): - calculator = TestCalculator(all_types=None, exponent=1.0) - match = r"Expected shape of neighbor_indices is \(2, \d+\), but got \[\d+, \d+\]" +def test_invalid_dtype_charges(): + calculator = TestCalculator(exponent=1.0) + match = ( + r"each `charges` must be on the same device cpu as positions, " + r"got at least one tensor with device meta" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + positions=positions_1, + cell=None, + charges=torch.ones((4, 2), dtype=dtype, device="meta"), + neighbor_indices=None, + neighbor_shifts=None, + ) + + +# Tests for invalid shape, dtype and device of neighbor_indices and neighbor_shifts +def test_need_both_neighbor_indices_and_shifts(): + calculator = TestCalculator(exponent=1.0) + match = r"Need to provide both neighbor_indices and neighbor_shifts together." with pytest.raises(ValueError, match=match): calculator.compute( - types=torch.arange(2), - positions=torch.ones([2, 3]), + positions=positions_1, cell=None, - charges=None, - neighbor_indices=torch.ones([3, 2]), + charges=charges_1, + neighbor_indices=torch.ones((2, 10), dtype=dtype, device=device), neighbor_shifts=None, ) +def test_invalid_shape_neighbor_indices(): + calculator = TestCalculator(exponent=1.0) + match = ( + r"neighbor_indices is expected to have shape \(2, num_neighbors\)" + r", but got \(4, 10\) for one structure" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + positions=positions_1, + cell=None, + charges=charges_1, + neighbor_indices=torch.ones((4, 10), dtype=dtype, device=device), + neighbor_shifts=torch.ones((10, 3), dtype=dtype, device=device), + ) + + def test_invalid_shape_neighbor_shifts(): - calculator = TestCalculator(all_types=None, exponent=1.0) - match = r"Expected shape of neighbor_shifts is \(3, \d+\), but got \[\d+, \d+\]" + calculator = TestCalculator(exponent=1.0) + match = ( + r"neighbor_shifts is expected to have shape \(num_neighbors, 3\)" + r", but got \(10, 2\) for one structure" + ) with pytest.raises(ValueError, match=match): calculator.compute( - types=torch.arange(2), - positions=torch.ones([2, 3]), + positions=positions_1, cell=None, - charges=None, - neighbor_indices=None, - neighbor_shifts=torch.ones([3, 3]), + charges=charges_1, + neighbor_indices=torch.ones((2, 10), dtype=dtype, device=device), + neighbor_shifts=torch.ones((10, 2), dtype=dtype, device=device), + ) + + +def test_invalid_shape_neighbor_shifts(): + calculator = TestCalculator(exponent=1.0) + match = ( + r"`neighbor_indices` and `neighbor_shifts` need to have shapes " + r"\(2, num_neighbors\) and \(num_neighbors, 3\). For at least one" + r"structure, got \(2, 10\) and " + r"\(11, 3\), which is inconsistent" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + positions=positions_1, + cell=None, + charges=charges_1, + neighbor_indices=torch.ones((2, 10), dtype=dtype, device=device), + neighbor_shifts=torch.ones((11, 3), dtype=dtype, device=device), + ) + + +def test_invalid_device_neighbor_indices(): + calculator = TestCalculator(exponent=1.0) + match = ( + r"each `neighbor_indices` must be on the same device cpu as positions, " + r"got at least one tensor with device meta" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + positions=positions_1, + cell=None, + charges=charges_1, + neighbor_indices=torch.ones((2, 10), dtype=dtype, device="meta"), + neighbor_shifts=torch.ones((10, 3), dtype=dtype, device=device), + ) + + +def test_invalid_device_neighbor_shifts(): + calculator = TestCalculator(exponent=1.0) + match = ( + r"each `neighbor_shifts` must be on the same device cpu as positions, " + r"got at least one tensor with device meta" + ) + with pytest.raises(ValueError, match=match): + calculator.compute( + positions=positions_1, + cell=None, + charges=charges_1, + neighbor_indices=torch.ones((2, 10), dtype=dtype, device=device), + neighbor_shifts=torch.ones((10, 3), dtype=dtype, device="meta"), ) diff --git a/tests/calculators/test_calculators_workflow.py b/tests/calculators/test_calculators_workflow.py index 770d0901..1b531051 100644 --- a/tests/calculators/test_calculators_workflow.py +++ b/tests/calculators/test_calculators_workflow.py @@ -49,19 +49,13 @@ class TestWorkflow: def cscl_system(self, periodic): """CsCl crystal. Same as in the madelung test""" - types = torch.tensor([55, 17]) positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) cell = torch.eye(3) - + charges = torch.tensor([1.0, -1.0]).reshape((-1, 1)) if periodic: - return types, positions, cell + return positions, cell, charges else: - return types, positions - - def cscl_system_with_charges(self, periodic): - """CsCl crystal with charges.""" - charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) - return self.cscl_system(periodic) + (charges,) + return positions, charges def calculator(self, CalculatorClass, periodic, params): if periodic: @@ -89,59 +83,26 @@ def test_interpolation_order_error(self, CalculatorClass, params, periodic): with pytest.raises(ValueError, match=match): CalculatorClass(atomic_smearing=1, interpolation_order=10) - def test_all_types(self, CalculatorClass, params, periodic): - if periodic: - descriptor = CalculatorClass(atomic_smearing=0.1, all_types=[8, 55, 17]) - values = descriptor.compute(*self.cscl_system(periodic)) - assert values.shape == (2, 3) - assert torch.equal(values[:, 0], torch.zeros(2)) - - def test_all_types_error(self, CalculatorClass, params, periodic): - if periodic: - descriptor = CalculatorClass(atomic_smearing=0.1, all_types=[17]) - with pytest.raises(ValueError, match="Global list of types"): - descriptor.compute(*self.cscl_system(periodic)) - - def test_single_frame(self, CalculatorClass, periodic, params): - calculator = self.calculator(CalculatorClass, periodic, params) - values = calculator.compute(*self.cscl_system(periodic)) - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - - def test_single_frame_with_charges(self, CalculatorClass, periodic, params): - calculator = self.calculator(CalculatorClass, periodic, params) - values = calculator.compute(*self.cscl_system_with_charges(periodic)) - assert_close( - MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, - rtol=1e-5, - ) - def test_multi_frame(self, CalculatorClass, periodic, params): calculator = self.calculator(CalculatorClass, periodic, params) if periodic: - types, positions, cell = self.cscl_system(periodic) + positions, cell, charges = self.cscl_system(periodic) l_values = calculator.compute( - types=[types, types], positions=[positions, positions], cell=[cell, cell], + charges=[charges, charges], ) else: - types, positions = self.cscl_system(periodic) + positions, charges = self.cscl_system(periodic) l_values = calculator.compute( - types=[types, types], positions=[positions, positions] + positions=[positions, positions], charges=[charges, charges] ) for values in l_values: assert_close( MADELUNG_CSCL, - CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1], - atol=1e4, + -torch.sum(charges * values) / 2, + atol=1, rtol=1e-5, ) @@ -150,15 +111,17 @@ def test_dtype_device(self, CalculatorClass, periodic, params): device = "cpu" dtype = torch.float64 - types = torch.tensor([1], dtype=dtype, device=device) positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=dtype, device=device) + charges = torch.ones((1, 2), dtype=dtype, device=device) calculator = self.calculator(CalculatorClass, periodic, params) if periodic: cell = torch.eye(3, dtype=dtype, device=device) - potential = calculator.compute(types=types, positions=positions, cell=cell) + potential = calculator.compute( + positions=positions, cell=cell, charges=charges + ) else: - potential = calculator.compute(types=types, positions=positions) + potential = calculator.compute(positions=positions, charges=charges) assert potential.dtype == dtype assert potential.device.type == device @@ -169,11 +132,13 @@ def check_operation(self, CalculatorClass, periodic, params): calculator = self.calculator(CalculatorClass, periodic, params) if periodic: - types, positions, cell = self.cscl_system(periodic) - descriptor = calculator.compute(types=types, positions=positions, cell=cell) + positions, cell, charges = self.cscl_system(periodic) + descriptor = calculator.compute( + positions=positions, cell=cell, charges=charges + ) else: - types, positions = self.cscl_system(periodic) - descriptor = calculator.compute(types=types, positions=positions) + positions, charges = self.cscl_system(periodic) + descriptor = calculator.compute(positions=positions, charges=charges) assert type(descriptor) is torch.Tensor diff --git a/tests/calculators/test_values_aperiodic.py b/tests/calculators/test_values_aperiodic.py index e4612b2b..1b899c24 100644 --- a/tests/calculators/test_values_aperiodic.py +++ b/tests/calculators/test_values_aperiodic.py @@ -99,6 +99,8 @@ def define_molecule(molecule_name="dimer"): charges *= -1.0 potentials *= -1.0 + charges = charges.reshape((-1, 1)) + potentials = potentials.reshape((-1, 1)) return types, positions, charges, potentials @@ -163,7 +165,7 @@ def test_coulomb_exact( molecule_name = molecule + molecule_charge types, positions, charges, ref_potentials = define_molecule(molecule_name) positions = scaling_factor * (positions @ orthogonal_transformation) - potentials = DP.compute(types, positions, charges=charges) + potentials = DP.compute(positions, charges=charges) ref_potentials /= scaling_factor torch.testing.assert_close(potentials, ref_potentials, atol=2e-15, rtol=1e-14) diff --git a/tests/calculators/test_values_periodic.py b/tests/calculators/test_values_periodic.py index 5ac355c4..b86a1609 100644 --- a/tests/calculators/test_values_periodic.py +++ b/tests/calculators/test_values_periodic.py @@ -276,6 +276,7 @@ def define_crystal(crystal_name="CsCl"): raise ValueError(f"crystal_name = {crystal_name} is not supported!") madelung_ref = torch.tensor(madelung_ref, dtype=dtype) + charges = charges.reshape((-1, 1)) return types, positions, charges, cell, madelung_ref, num_formula_units @@ -315,7 +316,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): rtol = 9e-4 # Compute potential and compare against target value using default hypers - potentials = calc.compute(types, positions=pos, cell=cell, charges=charges) + potentials = calc.compute(positions=pos, cell=cell, charges=charges) energies = potentials * charges madelung = -torch.sum(energies) / 2 / num_units @@ -372,9 +373,7 @@ def test_wigner(crystal_name, scaling_factor): # Compute potential and compare against reference calc = EwaldPotential(atomic_smearing=smeareff) - potentials = calc.compute( - types, positions=positions, cell=cell, charges=charges - ) + potentials = calc.compute(positions=positions, cell=cell, charges=charges) energies = potentials * charges energies_ref = -torch.ones_like(energies) * madelung_ref torch.testing.assert_close(energies, energies_ref, atol=0.0, rtol=rtol) @@ -431,9 +430,7 @@ def test_random_structure(sr_cutoff, frame_index, scaling_factor, ortho, calc_na calc = PMEPotential(sr_cutoff=sr_cutoff) rtol_e = 4.5e-3 # 1.5e-3 rtol_f = 2.5e-3 # 6e-3 - potentials = calc.compute( - types=types, positions=positions, cell=cell, charges=charges - ) + potentials = calc.compute(positions=positions, cell=cell, charges=charges) # Compute energy, taking into account the double counting of each pair energy = torch.sum(potentials * charges) / 2 diff --git a/tests/lib/test_potentials.py b/tests/lib/test_potentials.py index 99b2da4a..46527b9a 100644 --- a/tests/lib/test_potentials.py +++ b/tests/lib/test_potentials.py @@ -226,6 +226,8 @@ def test_lr_value_at_zero(exponent, smearing): potential_close_to_zero = ipl.potential_lr_from_dist(dist_small, smearing=smearing) # Compare to - exact_value = 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) + exact_value = ( + 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) + ) relerr = torch.abs(potential_close_to_zero - exact_value) / exact_value assert relerr.item() < 3e-14 From d8364a3bcfb45946c90b5758911b06b8ff9fae30 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Tue, 9 Jul 2024 17:20:24 +0200 Subject: [PATCH 29/35] change metatensor output shape --- src/meshlode/metatensor/calculators.py | 178 +++++++++---------------- 1 file changed, 62 insertions(+), 116 deletions(-) diff --git a/src/meshlode/metatensor/calculators.py b/src/meshlode/metatensor/calculators.py index a35cd282..fe569dd8 100644 --- a/src/meshlode/metatensor/calculators.py +++ b/src/meshlode/metatensor/calculators.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union import torch @@ -12,7 +12,6 @@ "Try installing it with:\npip install metatensor[torch]" ) - from ..calculators.base import CalculatorBase from ..calculators.directpotential import _DirectPotentialImpl from ..calculators.ewaldpotential import _EwaldPotentialImpl @@ -27,11 +26,6 @@ def _1d_tolist(x: torch.Tensor) -> List[int]: return result -# We are breaking the Liskov substitution principle here by changing the signature of -# "compute" method to the supertype of metatansor class. -# mypy: disable-error-code="override" - - class CalculatorBaseMetatensor(CalculatorBase): def __init__(self, exponent: float): super().__init__(exponent) @@ -40,31 +34,15 @@ def forward(self, systems: Union[List[System], System]) -> TensorMap: """Forward just calls :py:meth:`compute`.""" return self.compute(systems) - def compute(self, systems: Union[List[System], System]) -> TensorMap: - """Compute potential for all provided ``systems``. - - All ``systems`` must have the same ``dtype`` and the same ``device``. If each - system contains a custom data field `charges` the potential will be calculated - for each "charges-channel". The number of `charges-channels` must be same in all - ``systems``. If no "explicit" charges are set the potential will be calculated - for each "types-channels". - - Refer to :meth:`meshlode.PMEPotential.compute()` for additional details on how - "charges-channel" and "types-channels" are computed. - - :param systems: single System or list of - :py:class:`metatensor.torch.atomisic.System` on which to run the - calculations. - - :return: TensorMap containing the potential of all types. The keys of the - TensorMap are "center_type" and "neighbor_type" if no charges are asociated - with - """ + def _validate_compute_parameters( + self, systems: Union[List[System], System] + ) -> List[System]: # Make sure that the compute function also works if only a single frame is # provided as input (for convenience of users testing out the code) if not isinstance(systems, list): systems = [systems] + self._device = systems[0].positions.device for system in systems: if system.dtype != systems[0].dtype: raise ValueError( @@ -72,41 +50,50 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: f"{system.dtype} and {systems[0].dtype}`" ) - if system.device != systems[0].device: + if system.device != self._device: raise ValueError( "`device` of all systems must be the same, got " f"{system.device} and {systems[0].device}`" ) - device = systems[0].positions.device - - all_atomic_types = torch.hstack([system.types for system in systems]) - atomic_types = _1d_tolist(torch.unique(all_atomic_types)) - n_types = len(atomic_types) - has_charges = torch.tensor(["charges" in s.known_data() for s in systems]) - if not torch.all(has_charges): raise ValueError("`systems` do not consistently contain `charges` data") - n_charges_channels = systems[0].get_data("charges").values.shape[1] - spec_channels = list(range(n_charges_channels)) - key_names = ["center_type", "charges_channel"] - + self._n_charges_channels = systems[0].get_data("charges").values.shape[1] for i_system, system in enumerate(systems): n_channels = system.get_data("charges").values.shape[1] - if n_channels != n_charges_channels: + if n_channels != self._n_charges_channels: raise ValueError( f"number of charges-channels in system index {i_system} " f"({n_channels}) is inconsistent with first system " - f"({n_charges_channels})" + f"({self._n_charges_channels})" ) - # Initialize dictionary for TensorBlock storage. - # - # blocks are stored according to the `center_type` and `charge_channel` - n_blocks = n_types * n_charges_channels - feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_blocks)} + return systems + + def compute(self, systems: Union[List[System], System]) -> TensorMap: + """Compute potential for all provided ``systems``. + + All ``systems`` must have the same ``dtype`` and the same ``device``. If each + system contains a custom data field `charges` the potential will be calculated + for each "charges-channel". The number of `charges-channels` must be same in all + ``systems``. If no "explicit" charges are set the potential will be calculated + for each "types-channels". + + Refer to :meth:`meshlode.PMEPotential.compute()` for additional details on how + "charges-channel" and "types-channels" are computed. + + :param systems: single System or list of + :py:class:`metatensor.torch.atomisic.System` on which to run the + calculations. + + :return: TensorMap containing the potential of all types. The keys of the + TensorMap are "center_type" and "neighbor_type" if no charges are asociated + with + """ + systems = self._validate_compute_parameters(systems) + potentials: List[torch.Tensor] = [] for system in systems: charges = system.get_data("charges").values @@ -126,82 +113,41 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: break if neighbor_indices is None: - potential = self._compute_single_system( - positions=system.positions, - cell=system.cell, - charges=charges, - neighbor_indices=None, - neighbor_shifts=None, + potentials.append( + self._compute_single_system( + positions=system.positions, + cell=system.cell, + charges=charges, + neighbor_indices=None, + neighbor_shifts=None, + ) ) else: - potential = self._compute_single_system( - positions=system.positions, - charges=charges, - cell=system.cell, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, + potentials.append( + self._compute_single_system( + positions=system.positions, + charges=charges, + cell=system.cell, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) ) - # Reorder data into metatensor format - for spec_center, at_num_center in enumerate(atomic_types): - for spec_channel in range(len(spec_channels)): - a_pair = spec_center * n_charges_channels + spec_channel - feat_dic[a_pair] += [ - potential[system.types == at_num_center, spec_channel] - ] - - # Assemble all computed potential values into TensorBlocks for each combination - # of center_type and neighbor_type/charge_channel - blocks: List[TensorBlock] = [] - for keys, values in feat_dic.items(): - spec_center = atomic_types[keys // n_charges_channels] - - # Generate the Labels objects for the samples and properties of the - # TensorBlock. - values_samples: List[List[int]] = [] - for i_frame, system in enumerate(systems): - for i_atom in range(len(system)): - if system.types[i_atom] == spec_center: - values_samples.append([i_frame, i_atom]) - - samples_vals_tensor = torch.tensor( - values_samples, dtype=torch.int32, device=device - ) - - # If no atoms are found that match the types pair `samples_vals_tensor` - # will be empty. We have to reshape the empty tensor to be a valid input for - # `Labels`. - if len(samples_vals_tensor) == 0: - samples_vals_tensor = samples_vals_tensor.reshape(-1, 2) - - labels_samples = Labels(["system", "atom"], samples_vals_tensor) - labels_properties = Labels( - ["potential"], torch.tensor([[0]], device=device) - ) - - block = TensorBlock( - samples=labels_samples, - components=[], - properties=labels_properties, - values=torch.hstack(values).reshape((-1, 1)), - ) - - blocks.append(block) - - assert len(blocks) == n_blocks - - # Generate TensorMap from TensorBlocks by defining suitable keys - key_values: List[torch.Tensor] = [] - for spec_center in atomic_types: - for spec_channel in spec_channels: - key_values.append( - torch.tensor([spec_center, spec_channel], device=device) - ) - key_values = torch.vstack(key_values) + values_samples: List[List[int]] = [] + for i_system in range(len(systems)): + for i_atom in range(len(system)): + values_samples.append([i_system, i_atom]) + + samples_vals_tensor = torch.tensor(values_samples, device=self._device) - labels_keys = Labels(key_names, key_values) + block = TensorBlock( + values=torch.vstack(potentials), + samples=Labels(["system", "atom"], samples_vals_tensor), + components=[], + properties=Labels.range("charges_channel", self._n_charges_channels), + ) - return TensorMap(keys=labels_keys, blocks=blocks) + return TensorMap(keys=Labels.single(), blocks=[block]) class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl): From 342188e59a406d6d7948619a89d635ef87c88be9 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Wed, 10 Jul 2024 10:01:44 +0200 Subject: [PATCH 30/35] Continue types removal --- examples/madelung.py | 30 ++- pyproject.toml | 7 + src/meshlode/calculators/base.py | 78 +++--- src/meshlode/calculators/directpotential.py | 57 +++-- src/meshlode/calculators/ewaldpotential.py | 110 ++------- src/meshlode/calculators/pmepotential.py | 108 +++------ src/meshlode/lib/potentials.py | 4 +- src/meshlode/metatensor/calculators.py | 50 ++-- .../{test_calculator_base.py => test_base.py} | 200 +++++++-------- tests/calculators/test_values_aperiodic.py | 28 +-- tests/calculators/test_values_periodic.py | 8 +- ...lculators_workflow.py => test_workflow.py} | 14 +- tests/lib/test_potentials.py | 4 +- tests/metatensor/test_calculators.py | 99 ++++++++ tests/metatensor/test_madelung.py | 228 ------------------ tox.ini | 44 ++-- 16 files changed, 434 insertions(+), 635 deletions(-) rename tests/calculators/{test_calculator_base.py => test_base.py} (65%) rename tests/calculators/{test_calculators_workflow.py => test_workflow.py} (92%) create mode 100644 tests/metatensor/test_calculators.py delete mode 100644 tests/metatensor/test_madelung.py diff --git a/examples/madelung.py b/examples/madelung.py index c9b92a0a..b9785f4c 100644 --- a/examples/madelung.py +++ b/examples/madelung.py @@ -10,6 +10,7 @@ import math import torch +from metatensor.torch import Labels, TensorBlock from metatensor.torch.atomistic import System import meshlode @@ -22,12 +23,12 @@ # numbers 17 and 55, respectively. types = torch.tensor([17, 55]) # Cl and Cs positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) -charges = torch.tensor([-1.0, 1.0]) +charges = torch.tensor([-1.0, 1.0]).reshape(-1, 1) cell = torch.eye(3) # %% # Define the expected values of the energy -n_atoms = len(types) +n_atoms = len(positions) madelung = 2 * 1.7626 / math.sqrt(3) energies_ref = -madelung * torch.ones((n_atoms, 1)) @@ -44,13 +45,15 @@ # ------------------------------ # Compute features using -MP = meshlode.PMEPotential( +pme = meshlode.PMEPotential( atomic_smearing=atomic_smearing, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, subtract_self=True, ) -potentials_torch = MP.compute(types=types, positions=positions, cell=cell) +potentials_torch: torch.Tensor = pme.compute( + positions=positions, charges=charges, cell=cell +) # %% # The "potentials" that have been computed so far are not the actual electrostatic @@ -92,13 +95,28 @@ system = System(types=types, positions=positions, cell=cell) -MP = meshlode.metatensor.PMEPotential( +# %% +# Attach charges to the system. + +data = TensorBlock( + values=charges, + samples=Labels.range("atom", len(system)), + components=[], + properties=Labels("charge", torch.tensor([[0]])), +) +system.add_data(name="charges", data=data) + + +# %% +# Perform the calculation. + +pme = meshlode.metatensor.PMEPotential( atomic_smearing=atomic_smearing, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, subtract_self=True, ) -potential_metatensor = MP.compute(system) +potential_metatensor = pme.compute(system) # %% diff --git a/pyproject.toml b/pyproject.toml index 9a41ebd7..3f73d1b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,9 @@ data_file = 'tests/.coverage' [tool.coverage.xml] output = 'tests/coverage.xml' +[tool.black] +exclude = 'docs/src/examples' + [tool.isort] skip = "__init__.py" profile = "black" @@ -76,6 +79,10 @@ lines_after_imports = 2 known_first_party = "meshlode" [tool.mypy] +exclude = [ + "docs/src/examples" +] +follow_imports = 'skip' ignore_missing_imports = true [tool.pytest.ini_options] diff --git a/src/meshlode/calculators/base.py b/src/meshlode/calculators/base.py index e3eee578..0201d458 100644 --- a/src/meshlode/calculators/base.py +++ b/src/meshlode/calculators/base.py @@ -112,8 +112,8 @@ def __init__( def _validate_compute_parameters( self, positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[None, List[torch.Tensor], torch.Tensor], charges: Union[List[torch.Tensor], torch.Tensor], + cell: Union[None, List[torch.Tensor], torch.Tensor], neighbor_indices: Union[None, List[torch.Tensor], torch.Tensor], neighbor_shifts: Union[None, List[torch.Tensor], torch.Tensor], ) -> Tuple[ @@ -129,8 +129,8 @@ def _validate_compute_parameters( # In actual computations, the data type (dtype) and device (e.g. CPU, GPU) of # all remaining variables need to be consistent - self.device = positions[0].device - self.dtype = positions[0].dtype + self._device = positions[0].device + self._dtype = positions[0].dtype # make sure that provided cells are a list of same length as positions if cell is None: @@ -194,17 +194,17 @@ def _validate_compute_parameters( f"one tensor with shape {tuple(positions_single.shape)}" ) - if positions_single.dtype != self.dtype: + if positions_single.dtype != self._dtype: raise ValueError( - f"each `positions` must have the same type {self.dtype} as the " + f"each `positions` must have the same type {self._dtype} as the " "first provided one. Got at least one tensor of type " f"{positions_single.dtype}" ) - if positions_single.device != self.device: + if positions_single.device != self._device: raise ValueError( - f"each `positions` must be on the same device {self.device} as the " - "first provided one. Got at least one tensor on device " + f"each `positions` must be on the same device {self._device} as " + "the first provided one. Got at least one tensor on device " f"{positions_single.device}" ) @@ -212,20 +212,22 @@ def _validate_compute_parameters( if cell_single is not None: if list(cell_single.shape) != [3, 3]: raise ValueError( - f"each `cell` must be a (3 x 3) tensor, got at least one tensor " - f"with shape {tuple(cell_single.shape)}" + f"each `cell` must be a (3 x 3) tensor, got at least one " + f"tensor with shape {tuple(cell_single.shape)}" ) - if cell_single.dtype != self.dtype: + if cell_single.dtype != self._dtype: raise ValueError( - f"each `cell` must have the same type {self.dtype} as positions, " - f"got at least one tensor of type {cell_single.dtype}" + f"each `cell` must have the same type {self._dtype} as " + "positions, got at least one tensor of type " + f"{cell_single.dtype}" ) - if cell_single.device != self.device: + if cell_single.device != self._device: raise ValueError( - f"each `cell` must be on the same device {self.device} as positions, " - f"got at least one tensor with device {cell_single.device}" + f"each `cell` must be on the same device {self._device} as " + "positions, got at least one tensor with device " + f"{cell_single.device}" ) # check shape, dtype and device of charges @@ -244,35 +246,39 @@ def _validate_compute_parameters( f"positions contains {len(positions_single)} atoms" ) - if charges_single.dtype != self.dtype: + if charges_single.dtype != self._dtype: raise ValueError( - f"each `charges` must have the same type {self.dtype} as positions, " - f"got at least one tensor of type {charges_single.dtype}" + f"each `charges` must have the same type {self._dtype} as " + f"positions, got at least one tensor of type {charges_single.dtype}" ) - if charges_single.device != self.device: + if charges_single.device != self._device: raise ValueError( - f"each `charges` must be on the same device {self.device} as positions, " - f"got at least one tensor with device {charges_single.device}" + f"each `charges` must be on the same device {self._device} as " + f"positions, got at least one tensor with device " + f"{charges_single.device}" ) # check shape, dtype and device of neighbor_indices and neighbor_shifts if neighbor_indices_single is not None: if neighbor_shifts_single is None: raise ValueError( - "Need to provide both neighbor_indices and neighbor_shifts together." + "Need to provide both `neighbor_indices` and `neighbor_shifts` " + "together." ) if neighbor_indices_single.shape[0] != 2: raise ValueError( "neighbor_indices is expected to have shape (2, num_neighbors)" - f", but got {tuple(neighbor_indices_single.shape)} for one structure" + f", but got {tuple(neighbor_indices_single.shape)} for one " + "structure" ) if neighbor_shifts_single.shape[1] != 3: raise ValueError( "neighbor_shifts is expected to have shape (num_neighbors, 3)" - f", but got {tuple(neighbor_shifts_single.shape)} for one structure" + f", but got {tuple(neighbor_shifts_single.shape)} for one " + "structure" ) if neighbor_shifts_single.shape[0] != neighbor_indices_single.shape[1]: @@ -283,16 +289,18 @@ def _validate_compute_parameters( f"{tuple(neighbor_shifts_single.shape)}, which is inconsistent" ) - if neighbor_indices_single.device != self.device: + if neighbor_indices_single.device != self._device: raise ValueError( - f"each `neighbor_indices` must be on the same device {self.device} as positions, " - f"got at least one tensor with device {neighbor_indices_single.device}" + f"each `neighbor_indices` must be on the same device " + f"{self._device} as positions, got at least one tensor with " + f"device {neighbor_indices_single.device}" ) - if neighbor_shifts_single.device != self.device: + if neighbor_shifts_single.device != self._device: raise ValueError( - f"each `neighbor_shifts` must be on the same device {self.device} as positions, " - f"got at least one tensor with device {neighbor_shifts_single.device}" + f"each `neighbor_shifts` must be on the same device " + f"{self._device} as positions, got at least one tensor with " + f"device {neighbor_shifts_single.device}" ) return positions, cell, charges, neighbor_indices, neighbor_shifts @@ -300,8 +308,8 @@ def _validate_compute_parameters( def _compute_impl( self, positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[None, List[torch.Tensor], torch.Tensor], charges: Union[Union[List[torch.Tensor], torch.Tensor]], + cell: Union[None, List[torch.Tensor], torch.Tensor], neighbor_indices: Union[None, List[torch.Tensor], torch.Tensor], neighbor_shifts: Union[None, List[torch.Tensor], torch.Tensor], ) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -316,7 +324,11 @@ def _compute_impl( neighbor_indices, neighbor_shifts, ) = self._validate_compute_parameters( - positions, cell, charges, neighbor_indices, neighbor_shifts + positions=positions, + charges=charges, + cell=cell, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, ) # compute and append into a list the features of each structure diff --git a/src/meshlode/calculators/directpotential.py b/src/meshlode/calculators/directpotential.py index a7309241..d242f1da 100644 --- a/src/meshlode/calculators/directpotential.py +++ b/src/meshlode/calculators/directpotential.py @@ -52,12 +52,23 @@ class DirectPotential(CalculatorBaseTorch, _DirectPotentialImpl): infinitely extended three-dimensional Euclidean space. While slow, this implementation used as a reference to test faster algorithms. - :param all_types: Optional global list of all atomic types that should be considered - for the computation. This option might be useful when running the calculation on - subset of a whole dataset and it required to keep the shape of the output - consistent. If this is not set the possible atomic types will be determined when - calling the :meth:`compute()`. :param exponent: the exponent "p" in 1/r^p potentials + + Example + ------- + >>> import torch + + Define simple example structure + + >>> positions = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) + >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + + Compute features + + >>> direct = DirectPotential() + >>> direct.compute(positions=positions, charges=charges) + tensor([[-1.1547], + [ 1.1547]]) """ def __init__(self, exponent: float = 1.0): @@ -67,30 +78,24 @@ def __init__(self, exponent: float = 1.0): def compute( self, positions: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + charges: Union[List[torch.Tensor], torch.Tensor], ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute potential for all provided "systems" stacked inside list. - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. - - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param charges: Optional single or list of 2D tensor of shape (len(types), n), - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). + The computation is performed on the same ``device`` as ``dtype`` is the input is + stored on. The ``dtype`` of the output tensors will be the same as the input. + + :param positions: Single or 2D tensor of shape (``len(charges), 3``) containing + the Cartesian positions of all point charges in the system. + :param charges: Single 2D tensor or list of 2D tensor of shape (``n_channels, + len(positions))``. ``n_channels`` is the number of charge channels the + potential should be calculated for a standard potential ``n_channels=1``. If + more than one "channel" is provided multiple potentials for the same + position but different are computed. + :return: Single or List of torch Tensors containing the potential(s) for all + positions. Each tensor in the list is of shape ``(len(positions), + len(charges))``, where If the inputs are only single tensors only a single + torch tensor with the potentials is returned. """ return self._compute_impl( diff --git a/src/meshlode/calculators/ewaldpotential.py b/src/meshlode/calculators/ewaldpotential.py index 7ca01b39..be2aaf2d 100644 --- a/src/meshlode/calculators/ewaldpotential.py +++ b/src/meshlode/calculators/ewaldpotential.py @@ -39,37 +39,6 @@ def _compute_single_system( neighbor_indices: Optional[torch.Tensor] = None, neighbor_shifts: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ - Compute the "electrostatic" potential at the position of all atoms in a - structure. - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. For standard LODE - that treats all (atomic) types separately, one example could be: If n_atoms - = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use - the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for - the charges. This would then separately compute the "Na" potential and "Cl" - potential. Subtracting these from each other, one could recover the more - standard electrostatic potential in which Na and Cl have charges of +1 and - -1, respectively. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ # Set the defaut values of convergence parameters # The total computational cost = cost of SR part + cost of LR part # Bigger smearing increases the cost of the SR part while decreasing the cost @@ -127,28 +96,6 @@ def _compute_lr( lr_wavelength: torch.Tensor, subtract_self=True, ) -> torch.Tensor: - """ - Compute the long-range part of the Ewald sum in realspace - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - :param smearing: torch.Tensor smearing paramter determining the splitting - between the SR and LR parts. - :param lr_wavelength: Spatial resolution used for the long-range (reciprocal - space) part of the Ewald sum. More conretely, all Fourier space vectors with - a wavelength >= this value will be kept. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ # Define k-space cutoff from required real-space resolution k_cutoff = 2 * torch.pi / lr_wavelength @@ -207,7 +154,7 @@ def _compute_lr( # TODO: modify to expression for general p if subtract_self: self_contrib = ( - torch.sqrt(torch.tensor(2.0 / torch.pi, device=self.device)) / smearing + torch.sqrt(torch.tensor(2.0 / torch.pi, device=self._device)) / smearing ) energy -= charges * self_contrib @@ -245,20 +192,19 @@ class EwaldPotential(CalculatorBaseTorch, _EwaldPotentialImpl): Example ------- >>> import torch - >>> from meshlode import EwaldPotential - Define simple example structure having the CsCl (Cesium Chloride) structure + Define simple example structure having the CsCl (Cesium-Chloride) structure - >>> types = torch.tensor([55, 17]) # Cs and Cl - >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + >>> positions = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) + >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> cell = torch.eye(3) Compute features - >>> EP = EwaldPotential() - >>> EP.compute(types=types, positions=positions, cell=cell) - tensor([[-0.7391, -2.7745], - [-2.7745, -0.7391]]) + >>> ewald = EwaldPotential() + >>> ewald.compute(positions=positions, charges=charges, cell=cell) + tensor([[-2.0354], + [ 2.0354]]) """ def __init__( @@ -284,49 +230,43 @@ def __init__( def compute( self, positions: Union[List[torch.Tensor], torch.Tensor], + charges: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute potential for all provided "systems" stacked inside list. - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. + The computation is performed on the same ``device`` as ``dtype`` is the input is + stored on. The ``dtype`` of the output tensors will be the same as the input. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. + :param positions: Single or 2D tensor of shape (``len(charges), 3``) containing + the Cartesian positions of all point charges in the system. + :param charges: Single 2D tensor or list of 2D tensor of shape (``n_channels, + len(positions))``. ``n_channels`` is the number of charge channels the + potential should be calculated for a standard potential ``n_channels=1``. If + more than one "channel" is provided multiple potentials for the same + position but different are computed. :param cell: single or 2D tensor of shape (3, 3), describing the bounding box/unit cell of the system. Each row should be one of the bounding box vector; and columns should contain the x, y, and z components of these vectors (i.e. the cell should be given in row-major order). - :param charges: Optional single or list of 2D tensor of shape (len(types), n), :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), where n is the number of atoms. The 2 rows correspond to the indices of the two atoms which are considered neighbors (e.g. within a cutoff distance) :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), where n is the number of atoms. The 3 rows correspond to the shift indices for periodic images. - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). + :return: Single or List of torch Tensors containing the potential(s) for all + positions. Each tensor in the list is of shape ``(len(positions), + len(charges))``, where If the inputs are only single tensors only a single + torch tensor with the potentials is returned. """ return self._compute_impl( positions=positions, - cell=cell, charges=charges, + cell=cell, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts, ) @@ -337,16 +277,16 @@ def compute( def forward( self, positions: Union[List[torch.Tensor], torch.Tensor], + charges: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Forward just calls :py:meth:`compute`.""" return self.compute( positions=positions, - cell=cell, charges=charges, + cell=cell, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts, ) diff --git a/src/meshlode/calculators/pmepotential.py b/src/meshlode/calculators/pmepotential.py index c19539ab..637dfe8f 100644 --- a/src/meshlode/calculators/pmepotential.py +++ b/src/meshlode/calculators/pmepotential.py @@ -56,37 +56,6 @@ def _compute_single_system( neighbor_indices: Optional[torch.Tensor] = None, neighbor_shifts: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ - Compute the "electrostatic" potential at the position of all atoms in a - structure. - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. For standard LODE - that treats all (atomic) types separately, one example could be: If n_atoms - = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use - the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for - the charges. This would then separately compute the "Na" potential and "Cl" - potential. Subtracting these from each other, one could recover the more - standard electrostatic potential in which Na and Cl have charges of +1 and - -1, respectively. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ # Set the defaut values of convergence parameters # The total computational cost = cost of SR part + cost of LR part # Bigger smearing increases the cost of the SR part while decreasing the cost @@ -146,28 +115,6 @@ def _compute_lr( lr_wavelength: torch.Tensor, subtract_self=True, ) -> torch.Tensor: - """ - Compute the long-range part of the Ewald sum in realspace - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - :param smearing: torch.Tensor smearing paramter determining the splitting - between the SR and LR parts. - :param lr_wavelength: Spatial resolution used for the long-range (reciprocal - space) part of the Ewald sum. More conretely, all Fourier space vectors with - a wavelength >= this value will be kept. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ # Step 0 (Preparation): Compute number of times each basis vector of the # reciprocal space can be scaled until the cutoff is reached k_cutoff = 2 * torch.pi / lr_wavelength @@ -248,6 +195,23 @@ class PMEPotential(CalculatorBaseTorch, _PMEPotentialImpl): atom the contributions to the potential arising from all atoms within the cutoff Note that if set to true, the self contribution (see previous) is also subtracted by default. + + Example + ------- + >>> import torch + + Define simple example structure having the CsCl (Cesium-Chloride) structure + + >>> positions = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) + >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + >>> cell = torch.eye(3) + + Compute features + + >>> pme = PMEPotential() + >>> pme.compute(positions=positions, charges=charges, cell=cell) + tensor([[-2.0384], + [ 2.0384]]) """ def __init__( @@ -275,43 +239,37 @@ def __init__( def compute( self, positions: Union[List[torch.Tensor], torch.Tensor], + charges: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute potential for all provided "systems" stacked inside list. - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. + The computation is performed on the same ``device`` as ``dtype`` is the input is + stored on. The ``dtype`` of the output tensors will be the same as the input. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. + :param positions: Single or 2D tensor of shape (``len(charges), 3``) containing + the Cartesian positions of all point charges in the system. + :param charges: Single 2D tensor or list of 2D tensor of shape (``n_channels, + len(positions))``. ``n_channels`` is the number of charge channels the + potential should be calculated for a standard potential ``n_channels=1``. If + more than one "channel" is provided multiple potentials for the same + position but different are computed. :param cell: single or 2D tensor of shape (3, 3), describing the bounding box/unit cell of the system. Each row should be one of the bounding box vector; and columns should contain the x, y, and z components of these vectors (i.e. the cell should be given in row-major order). - :param charges: Optional single or list of 2D tensor of shape (len(types), n), :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), where n is the number of atoms. The 2 rows correspond to the indices of the two atoms which are considered neighbors (e.g. within a cutoff distance) :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), where n is the number of atoms. The 3 rows correspond to the shift indices for periodic images. - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). + :return: Single or List of torch Tensors containing the potential(s) for all + positions. Each tensor in the list is of shape ``(len(positions), + len(charges))``, where If the inputs are only single tensors only a single + torch tensor with the potentials is returned. """ return self._compute_impl( @@ -325,16 +283,16 @@ def compute( def forward( self, positions: Union[List[torch.Tensor], torch.Tensor], + charges: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Forward just calls :py:meth:`compute`.""" return self.compute( positions=positions, - cell=cell, charges=charges, + cell=cell, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts, ) diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py index be0dcf95..83789261 100644 --- a/src/meshlode/lib/potentials.py +++ b/src/meshlode/lib/potentials.py @@ -121,9 +121,7 @@ def potential_fourier_from_k_sq( smearing parameter corresponds to the "width" of the Gaussian. """ peff = (3 - self.exponent) / 2 - prefac = ( - (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff - ) + prefac = (math.pi) ** 1.5 / gamma(self.exponent / 2) * (2 * smearing**2) ** peff x = 0.5 * smearing**2 * k_sq fourier = prefac * gammaincc(peff, x) / x**peff * gamma(peff) diff --git a/src/meshlode/metatensor/calculators.py b/src/meshlode/metatensor/calculators.py index fe569dd8..35d70767 100644 --- a/src/meshlode/metatensor/calculators.py +++ b/src/meshlode/metatensor/calculators.py @@ -17,14 +17,6 @@ from ..calculators.ewaldpotential import _EwaldPotentialImpl from ..calculators.pmepotential import _PMEPotentialImpl -@torch.jit.script -def _1d_tolist(x: torch.Tensor) -> List[int]: - """Auxilary function to convert 1d torch tensor to list of integers.""" - result: List[int] = [] - for i in x: - result.append(i.item()) - return result - class CalculatorBaseMetatensor(CalculatorBase): def __init__(self, exponent: float): @@ -76,7 +68,7 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: """Compute potential for all provided ``systems``. All ``systems`` must have the same ``dtype`` and the same ``device``. If each - system contains a custom data field `charges` the potential will be calculated + system contains a custom data field ``charges`` the potential will be calculated for each "charges-channel". The number of `charges-channels` must be same in all ``systems``. If no "explicit" charges are set the potential will be calculated for each "types-channels". @@ -88,9 +80,7 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: :py:class:`metatensor.torch.atomisic.System` on which to run the calculations. - :return: TensorMap containing the potential of all types. The keys of the - TensorMap are "center_type" and "neighbor_type" if no charges are asociated - with + :return: TensorMap containing the potential of all types. """ systems = self._validate_compute_parameters(systems) potentials: List[torch.Tensor] = [] @@ -100,6 +90,7 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: # try to extract neighbor list from system object neighbor_indices = None + neighbor_shifts = None for neighbor_list_options in system.known_neighbor_lists(): if ( hasattr(self, "sr_cutoff") @@ -112,26 +103,15 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: break - if neighbor_indices is None: - potentials.append( - self._compute_single_system( - positions=system.positions, - cell=system.cell, - charges=charges, - neighbor_indices=None, - neighbor_shifts=None, - ) - ) - else: - potentials.append( - self._compute_single_system( - positions=system.positions, - charges=charges, - cell=system.cell, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) + potentials.append( + self._compute_single_system( + positions=system.positions, + charges=charges, + cell=system.cell, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, ) + ) values_samples: List[List[int]] = [] for i_system in range(len(systems)): @@ -153,18 +133,18 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl): """Specie-wise long-range potential using a direct summation over all atoms. - Refer to :class:`meshlode.DirectPotential` for full documentation. + Refer to :class:`meshlode.DirectPotential` for parameter documentation. """ def __init__(self, exponent: float = 1.0): - self._DirectPotentialImpl.__init__(self, exponent=exponent) + _DirectPotentialImpl.__init__(self, exponent=exponent) CalculatorBaseMetatensor.__init__(self, exponent=exponent) class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl): """Specie-wise long-range potential computed using the Ewald sum. - Refer to :class:`meshlode.EwaldPotential` for full documentation. + Refer to :class:`meshlode.EwaldPotential` for parameter documentation. """ def __init__( @@ -191,7 +171,7 @@ def __init__( class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): """Specie-wise long-range potential using a particle mesh-based Ewald (PME). - Refer to :class:`meshlode.PMEPotential` for full documentation. + Refer to :class:`meshlode.PMEPotential` for parameter documentation. """ def __init__( diff --git a/tests/calculators/test_calculator_base.py b/tests/calculators/test_base.py similarity index 65% rename from tests/calculators/test_calculator_base.py rename to tests/calculators/test_base.py index 644bd03b..09d07fd1 100644 --- a/tests/calculators/test_calculator_base.py +++ b/tests/calculators/test_base.py @@ -5,37 +5,37 @@ # Define some example parameters -dtype = torch.float32 -device = "cpu" -charges_1 = torch.ones((4, 1), dtype=dtype, device=device) -positions_1 = 0.3 * torch.arange(12, dtype=dtype, device=device).reshape((4, 3)) -charges_2 = torch.ones((5, 3), dtype=dtype, device=device) -positions_2 = 0.7 * torch.arange(15, dtype=dtype, device=device).reshape((5, 3)) -cell_1 = torch.eye(3, dtype=dtype, device=device) -cell_2 = torch.arange(9, dtype=dtype, device=device).reshape((3, 3)) - - -class TestCalculator(CalculatorBaseTorch): - def compute(self, positions, cell, charges, neighbor_indices, neighbor_shifts): +DTYPE = torch.float32 +DEVICE = "cpu" +CHARGES_1 = torch.ones((4, 1), dtype=DTYPE, device=DEVICE) +POSITIONS_1 = 0.3 * torch.arange(12, dtype=DTYPE, device=DEVICE).reshape((4, 3)) +CHARGES_2 = torch.ones((5, 3), dtype=DTYPE, device=DEVICE) +POSITIONS_2 = 0.7 * torch.arange(15, dtype=DTYPE, device=DEVICE).reshape((5, 3)) +CELL_1 = torch.eye(3, dtype=DTYPE, device=DEVICE) +CELL_2 = torch.arange(9, dtype=DTYPE, device=DEVICE).reshape((3, 3)) + + +class CalculatorTest(CalculatorBaseTorch): + def compute(self, positions, charges, cell, neighbor_indices, neighbor_shifts): return self._compute_impl( positions=positions, - cell=cell, charges=charges, + cell=cell, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts, ) - def forward(self, positions, cell, charges, neighbor_indices, neighbor_shifts): + def forward(self, positions, charges, cell, neighbor_indices, neighbor_shifts): return self._compute_impl( positions=positions, - cell=cell, charges=charges, + cell=cell, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts, ) def _compute_single_system( - self, positions, cell, charges, neighbor_indices, neighbor_shifts + self, positions, charges, cell, neighbor_indices, neighbor_shifts ): return charges @@ -53,13 +53,13 @@ def _compute_single_system( ], ) def test_compute_output_shapes(method_name, positions, charges): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) method = getattr(calculator, method_name) result = method( positions=positions, - cell=None, charges=charges, + cell=None, neighbor_indices=None, neighbor_shifts=None, ) @@ -74,105 +74,105 @@ def test_compute_output_shapes(method_name, positions, charges): # Tests for a mismatch in the number of provided inputs for different variables def test_mismatched_numbers_cell(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = r"Got inconsistent numbers of positions \(2\) and cell \(3\)" with pytest.raises(ValueError, match=match): calculator.compute( - positions=[positions_1, positions_2], - cell=[cell_1, cell_2, torch.eye(3)], - charges=[charges_1, charges_2], + positions=[POSITIONS_1, POSITIONS_2], + charges=[CHARGES_1, CHARGES_2], + cell=[CELL_1, CELL_2, torch.eye(3)], neighbor_indices=None, neighbor_shifts=None, ) def test_mismatched_numbers_charges(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = r"Got inconsistent numbers of positions \(2\) and charges \(3\)" with pytest.raises(ValueError, match=match): calculator.compute( - positions=[positions_1, positions_2], + positions=[POSITIONS_1, POSITIONS_2], + charges=[CHARGES_1, CHARGES_2, CHARGES_2], cell=None, - charges=[charges_1, charges_2, charges_2], neighbor_indices=None, neighbor_shifts=None, ) def test_mismatched_numbers_neighbor_indices(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = r"Got inconsistent numbers of positions \(2\) and neighbor_indices \(3\)" with pytest.raises(ValueError, match=match): calculator.compute( - positions=[positions_1, positions_2], + positions=[POSITIONS_1, POSITIONS_2], + charges=[CHARGES_1, CHARGES_2], cell=None, - charges=[charges_1, charges_2], - neighbor_indices=[charges_1, charges_2, positions_1], + neighbor_indices=[CHARGES_1, CHARGES_2, POSITIONS_1], neighbor_shifts=None, ) def test_mismatched_numbers_neighbor_shiftss(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = r"Got inconsistent numbers of positions \(2\) and neighbor_shifts \(3\)" with pytest.raises(ValueError, match=match): calculator.compute( - positions=[positions_1, positions_2], + positions=[POSITIONS_1, POSITIONS_2], + charges=[CHARGES_1, CHARGES_2], cell=None, - charges=[charges_1, charges_2], neighbor_indices=None, - neighbor_shifts=[charges_1, charges_2, positions_1], + neighbor_shifts=[CHARGES_1, CHARGES_2, POSITIONS_1], ) # Tests for invalid shape, dtype and device of positions def test_invalid_shape_positions(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `positions` must be a \(n_atoms x 3\) tensor, got at least " r"one tensor with shape \(4, 5\)" ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=torch.ones((4, 5), dtype=dtype, device=device), + positions=torch.ones((4, 5), dtype=DTYPE, device=DEVICE), + charges=CHARGES_1, cell=None, - charges=charges_1, neighbor_indices=None, neighbor_shifts=None, ) def test_invalid_dtype_positions(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `positions` must have the same type torch.float32 as the " r"first provided one. Got at least one tensor of type " r"torch.float64" ) - positions_2_wrong_dtype = torch.ones((5, 3), dtype=torch.float64, device=device) + positions_2_wrong_dtype = torch.ones((5, 3), dtype=torch.float64, device=DEVICE) with pytest.raises(ValueError, match=match): calculator.compute( - positions=[positions_1, positions_2_wrong_dtype], + positions=[POSITIONS_1, positions_2_wrong_dtype], + charges=[CHARGES_1, CHARGES_2], cell=None, - charges=[charges_1, charges_2], neighbor_indices=None, neighbor_shifts=None, ) def test_invalid_device_positions(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `positions` must be on the same device cpu as the " r"first provided one. Got at least one tensor on device " r"meta" ) - positions_2_wrong_device = torch.ones((5, 3), dtype=dtype, device="meta") + positions_2_wrong_device = torch.ones((5, 3), dtype=DTYPE, device="meta") with pytest.raises(ValueError, match=match): calculator.compute( - positions=[positions_1, positions_2_wrong_device], + positions=[POSITIONS_1, positions_2_wrong_device], + charges=[CHARGES_1, CHARGES_2], cell=None, - charges=[charges_1, charges_2], neighbor_indices=None, neighbor_shifts=None, ) @@ -180,48 +180,48 @@ def test_invalid_device_positions(): # Tests for invalid shape, dtype and device of cell def test_invalid_shape_cell(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `cell` must be a \(3 x 3\) tensor, got at least one tensor with " r"shape \(2, 2\)" ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, - cell=torch.ones([2, 2], dtype=dtype, device=device), - charges=charges_1, + positions=POSITIONS_1, + charges=CHARGES_1, + cell=torch.ones([2, 2], dtype=DTYPE, device=DEVICE), neighbor_indices=None, neighbor_shifts=None, ) def test_invalid_dtype_cell(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `cell` must have the same type torch.float32 as positions, " r"got at least one tensor of type torch.float64" ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, - cell=torch.ones([3, 3], dtype=torch.float64, device=device), - charges=charges_1, + positions=POSITIONS_1, + charges=CHARGES_1, + cell=torch.ones([3, 3], dtype=torch.float64, device=DEVICE), neighbor_indices=None, neighbor_shifts=None, ) def test_invalid_device_cell(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `cell` must be on the same device cpu as positions, " r"got at least one tensor with device meta" ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, - cell=torch.ones([3, 3], dtype=dtype, device="meta"), - charges=charges_1, + positions=POSITIONS_1, + charges=CHARGES_1, + cell=torch.ones([3, 3], dtype=DTYPE, device="meta"), neighbor_indices=None, neighbor_shifts=None, ) @@ -229,7 +229,7 @@ def test_invalid_device_cell(): # Tests for invalid shape, dtype and device of charges def test_invalid_dim_charges(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `charges` needs to be a 2-dimensional tensor, got at least " r"one tensor with 1 dimension\(s\) and shape " @@ -237,16 +237,16 @@ def test_invalid_dim_charges(): ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, + positions=POSITIONS_1, + charges=torch.ones(len(POSITIONS_1), dtype=DTYPE, device=DEVICE), cell=None, - charges=torch.ones(len(positions_1), dtype=dtype, device=device), neighbor_indices=None, neighbor_shifts=None, ) def test_invalid_shape_charges(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `charges` must be a \(n_atoms x n_channels\) tensor, with" r"`n_atoms` being the same as the variable `positions`. Got at " @@ -255,41 +255,41 @@ def test_invalid_shape_charges(): ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, + positions=POSITIONS_1, + charges=torch.ones((6, 2), dtype=DTYPE, device=DEVICE), cell=None, - charges=torch.ones((6, 2), dtype=dtype, device=device), neighbor_indices=None, neighbor_shifts=None, ) def test_invalid_dtype_charges(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `charges` must have the same type torch.float32 as positions, " r"got at least one tensor of type torch.float64" ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, + positions=POSITIONS_1, + charges=torch.ones((4, 2), dtype=torch.float64, device=DEVICE), cell=None, - charges=torch.ones((4, 2), dtype=torch.float64, device=device), neighbor_indices=None, neighbor_shifts=None, ) -def test_invalid_dtype_charges(): - calculator = TestCalculator(exponent=1.0) +def test_invalid_device_charges(): + calculator = CalculatorTest(exponent=1.0) match = ( r"each `charges` must be on the same device cpu as positions, " r"got at least one tensor with device meta" ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, + positions=POSITIONS_1, + charges=torch.ones((4, 2), dtype=DTYPE, device="meta"), cell=None, - charges=torch.ones((4, 2), dtype=dtype, device="meta"), neighbor_indices=None, neighbor_shifts=None, ) @@ -297,52 +297,52 @@ def test_invalid_dtype_charges(): # Tests for invalid shape, dtype and device of neighbor_indices and neighbor_shifts def test_need_both_neighbor_indices_and_shifts(): - calculator = TestCalculator(exponent=1.0) - match = r"Need to provide both neighbor_indices and neighbor_shifts together." + calculator = CalculatorTest(exponent=1.0) + match = r"Need to provide both `neighbor_indices` and `neighbor_shifts` together." with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, + positions=POSITIONS_1, + charges=CHARGES_1, cell=None, - charges=charges_1, - neighbor_indices=torch.ones((2, 10), dtype=dtype, device=device), + neighbor_indices=torch.ones((2, 10), dtype=DTYPE, device=DEVICE), neighbor_shifts=None, ) def test_invalid_shape_neighbor_indices(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"neighbor_indices is expected to have shape \(2, num_neighbors\)" r", but got \(4, 10\) for one structure" ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, + positions=POSITIONS_1, + charges=CHARGES_1, cell=None, - charges=charges_1, - neighbor_indices=torch.ones((4, 10), dtype=dtype, device=device), - neighbor_shifts=torch.ones((10, 3), dtype=dtype, device=device), + neighbor_indices=torch.ones((4, 10), dtype=DTYPE, device=DEVICE), + neighbor_shifts=torch.ones((10, 3), dtype=DTYPE, device=DEVICE), ) def test_invalid_shape_neighbor_shifts(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"neighbor_shifts is expected to have shape \(num_neighbors, 3\)" r", but got \(10, 2\) for one structure" ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, + positions=POSITIONS_1, + charges=CHARGES_1, cell=None, - charges=charges_1, - neighbor_indices=torch.ones((2, 10), dtype=dtype, device=device), - neighbor_shifts=torch.ones((10, 2), dtype=dtype, device=device), + neighbor_indices=torch.ones((2, 10), dtype=DTYPE, device=DEVICE), + neighbor_shifts=torch.ones((10, 2), dtype=DTYPE, device=DEVICE), ) -def test_invalid_shape_neighbor_shifts(): - calculator = TestCalculator(exponent=1.0) +def test_invalid_shape_neighbor_indices_neighbor_shifts(): + calculator = CalculatorTest(exponent=1.0) match = ( r"`neighbor_indices` and `neighbor_shifts` need to have shapes " r"\(2, num_neighbors\) and \(num_neighbors, 3\). For at least one" @@ -351,41 +351,41 @@ def test_invalid_shape_neighbor_shifts(): ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, + positions=POSITIONS_1, + charges=CHARGES_1, cell=None, - charges=charges_1, - neighbor_indices=torch.ones((2, 10), dtype=dtype, device=device), - neighbor_shifts=torch.ones((11, 3), dtype=dtype, device=device), + neighbor_indices=torch.ones((2, 10), dtype=DTYPE, device=DEVICE), + neighbor_shifts=torch.ones((11, 3), dtype=DTYPE, device=DEVICE), ) def test_invalid_device_neighbor_indices(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `neighbor_indices` must be on the same device cpu as positions, " r"got at least one tensor with device meta" ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, + positions=POSITIONS_1, + charges=CHARGES_1, cell=None, - charges=charges_1, - neighbor_indices=torch.ones((2, 10), dtype=dtype, device="meta"), - neighbor_shifts=torch.ones((10, 3), dtype=dtype, device=device), + neighbor_indices=torch.ones((2, 10), dtype=DTYPE, device="meta"), + neighbor_shifts=torch.ones((10, 3), dtype=DTYPE, device=DEVICE), ) def test_invalid_device_neighbor_shifts(): - calculator = TestCalculator(exponent=1.0) + calculator = CalculatorTest(exponent=1.0) match = ( r"each `neighbor_shifts` must be on the same device cpu as positions, " r"got at least one tensor with device meta" ) with pytest.raises(ValueError, match=match): calculator.compute( - positions=positions_1, + positions=POSITIONS_1, + charges=CHARGES_1, cell=None, - charges=charges_1, - neighbor_indices=torch.ones((2, 10), dtype=dtype, device=device), - neighbor_shifts=torch.ones((10, 3), dtype=dtype, device="meta"), + neighbor_indices=torch.ones((2, 10), dtype=DTYPE, device=DEVICE), + neighbor_shifts=torch.ones((10, 3), dtype=DTYPE, device="meta"), ) diff --git a/tests/calculators/test_values_aperiodic.py b/tests/calculators/test_values_aperiodic.py index 1b899c24..c2332fc4 100644 --- a/tests/calculators/test_values_aperiodic.py +++ b/tests/calculators/test_values_aperiodic.py @@ -20,24 +20,22 @@ def define_molecule(molecule_name="dimer"): # Start defining molecules # Dimer if molecule_name == "dimer": - types = torch.tensor([1, 1]) positions = torch.tensor([[0.0, 0, 0], [0, 0, 1.0]], dtype=dtype) charges = torch.tensor([1.0, -1.0], dtype=dtype) potentials = torch.tensor([-1.0, 1], dtype=dtype) elif molecule_name == "dimer_positive": - types, positions, charges, potentials = define_molecule("dimer") + positions, charges, potentials = define_molecule("dimer") charges = torch.tensor([1.0, 1], dtype=dtype) potentials = torch.tensor([1.0, 1], dtype=dtype) elif molecule_name == "dimer_negative": - types, positions, charges, potentials = define_molecule("dimer_positive") + positions, charges, potentials = define_molecule("dimer_positive") charges *= -1.0 potentials *= -1.0 # Equilateral triangle elif molecule_name == "triangle": - types = torch.tensor([1, 1, 1]) positions = torch.tensor( [[0.0, 0, 0], [1, 0, 0], [1 / 2, SQRT3 / 2, 0]], dtype=dtype ) @@ -45,18 +43,17 @@ def define_molecule(molecule_name="dimer"): potentials = torch.tensor([-1.0, 1, 0], dtype=dtype) elif molecule_name == "triangle_positive": - types, positions, charges, potentials = define_molecule("triangle") + positions, charges, potentials = define_molecule("triangle") charges = torch.tensor([1.0, 1, 1], dtype=dtype) potentials = torch.tensor([2.0, 2, 2], dtype=dtype) elif molecule_name == "triangle_negative": - types, positions, charges, potentials = define_molecule("triangle_positive") + positions, charges, potentials = define_molecule("triangle_positive") charges *= -1.0 potentials *= -1.0 # Squares (planar) elif molecule_name == "square": - types = torch.tensor([1, 1, 1, 1]) positions = torch.tensor( [[1, 1, 0], [1, -1, 0], [-1, 1, 0], [-1, -1, 0]], dtype=dtype ) @@ -65,18 +62,17 @@ def define_molecule(molecule_name="dimer"): potentials = charges * (1.0 / SQRT2 - 2.0) elif molecule_name == "square_positive": - types, positions, charges, potentials = define_molecule("square") + positions, charges, potentials = define_molecule("square") charges = torch.tensor([1.0, 1, 1, 1], dtype=dtype) potentials = (2.0 + 1.0 / SQRT2) * torch.ones(4, dtype=dtype) elif molecule_name == "square_negative": - types, positions, charges, potentials = define_molecule("square_positive") + positions, charges, potentials = define_molecule("square_positive") charges *= -1.0 potentials *= -1.0 # Tetrahedra elif molecule_name == "tetrahedron": - types = torch.tensor([1, 1, 1, 1]) positions = torch.tensor( [ [0.0, 0, 0], @@ -90,18 +86,18 @@ def define_molecule(molecule_name="dimer"): potentials = -charges elif molecule_name == "tetrahedron_positive": - types, positions, charges, potentials = define_molecule("tetrahedron") + positions, charges, potentials = define_molecule("tetrahedron") charges = torch.ones(4, dtype=dtype) potentials = 3 * torch.ones(4, dtype=dtype) elif molecule_name == "tetrahedron_negative": - types, positions, charges, potentials = define_molecule("tetrahedron_positive") + positions, charges, potentials = define_molecule("tetrahedron_positive") charges *= -1.0 potentials *= -1.0 charges = charges.reshape((-1, 1)) potentials = potentials.reshape((-1, 1)) - return types, positions, charges, potentials + return positions, charges, potentials def generate_orthogonal_transformations(): @@ -159,13 +155,13 @@ def test_coulomb_exact( """ # Call Ewald potential class without specifying any of the convergence parameters # so that they are chosen by default (in a structure-dependent way) - DP = DirectPotential() + direct = DirectPotential() # Compute potential at the position of the atoms for the specified structure molecule_name = molecule + molecule_charge - types, positions, charges, ref_potentials = define_molecule(molecule_name) + positions, charges, ref_potentials = define_molecule(molecule_name) positions = scaling_factor * (positions @ orthogonal_transformation) - potentials = DP.compute(positions, charges=charges) + potentials = direct.compute(positions, charges=charges) ref_potentials /= scaling_factor torch.testing.assert_close(potentials, ref_potentials, atol=2e-15, rtol=1e-14) diff --git a/tests/calculators/test_values_periodic.py b/tests/calculators/test_values_periodic.py index b86a1609..7f9245d2 100644 --- a/tests/calculators/test_values_periodic.py +++ b/tests/calculators/test_values_periodic.py @@ -316,7 +316,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): rtol = 9e-4 # Compute potential and compare against target value using default hypers - potentials = calc.compute(positions=pos, cell=cell, charges=charges) + potentials = calc.compute(positions=pos, charges=charges, cell=cell) energies = potentials * charges madelung = -torch.sum(energies) / 2 / num_units @@ -373,7 +373,7 @@ def test_wigner(crystal_name, scaling_factor): # Compute potential and compare against reference calc = EwaldPotential(atomic_smearing=smeareff) - potentials = calc.compute(positions=positions, cell=cell, charges=charges) + potentials = calc.compute(positions=positions, charges=charges, cell=cell) energies = potentials * charges energies_ref = -torch.ones_like(energies) * madelung_ref torch.testing.assert_close(energies, energies_ref, atol=0.0, rtol=rtol) @@ -418,7 +418,6 @@ def test_random_structure(sr_cutoff, frame_index, scaling_factor, ortho, calc_na positions.requires_grad = True cell = scaling_factor * torch.tensor(np.array(frame.cell), dtype=dtype) @ ortho charges = torch.tensor([1, 1, 1, 1, -1, -1, -1, -1], dtype=dtype).reshape((-1, 1)) - types = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2]) # Compute potential using MeshLODE and compare against reference values sr_cutoff = scaling_factor * torch.tensor(sr_cutoff, dtype=dtype) @@ -430,7 +429,8 @@ def test_random_structure(sr_cutoff, frame_index, scaling_factor, ortho, calc_na calc = PMEPotential(sr_cutoff=sr_cutoff) rtol_e = 4.5e-3 # 1.5e-3 rtol_f = 2.5e-3 # 6e-3 - potentials = calc.compute(positions=positions, cell=cell, charges=charges) + + potentials = calc.compute(positions=positions, charges=charges, cell=cell) # Compute energy, taking into account the double counting of each pair energy = torch.sum(potentials * charges) / 2 diff --git a/tests/calculators/test_calculators_workflow.py b/tests/calculators/test_workflow.py similarity index 92% rename from tests/calculators/test_calculators_workflow.py rename to tests/calculators/test_workflow.py index 1b531051..26985820 100644 --- a/tests/calculators/test_calculators_workflow.py +++ b/tests/calculators/test_workflow.py @@ -50,10 +50,10 @@ class TestWorkflow: def cscl_system(self, periodic): """CsCl crystal. Same as in the madelung test""" positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - cell = torch.eye(3) charges = torch.tensor([1.0, -1.0]).reshape((-1, 1)) if periodic: - return positions, cell, charges + cell = torch.eye(3) + return positions, charges, cell else: return positions, charges @@ -86,7 +86,7 @@ def test_interpolation_order_error(self, CalculatorClass, params, periodic): def test_multi_frame(self, CalculatorClass, periodic, params): calculator = self.calculator(CalculatorClass, periodic, params) if periodic: - positions, cell, charges = self.cscl_system(periodic) + positions, charges, cell = self.cscl_system(periodic) l_values = calculator.compute( positions=[positions, positions], cell=[cell, cell], @@ -118,7 +118,7 @@ def test_dtype_device(self, CalculatorClass, periodic, params): if periodic: cell = torch.eye(3, dtype=dtype, device=device) potential = calculator.compute( - positions=positions, cell=cell, charges=charges + positions=positions, charges=charges, cell=cell ) else: potential = calculator.compute(positions=positions, charges=charges) @@ -127,14 +127,14 @@ def test_dtype_device(self, CalculatorClass, periodic, params): assert potential.device.type == device # Make sure that the calculators are computing the features without raising errors, - # and returns the correct output format (TensorMap) + # and returns the correct output format (torch.Tensor) def check_operation(self, CalculatorClass, periodic, params): calculator = self.calculator(CalculatorClass, periodic, params) if periodic: - positions, cell, charges = self.cscl_system(periodic) + positions, charges, cell = self.cscl_system(periodic) descriptor = calculator.compute( - positions=positions, cell=cell, charges=charges + positions=positions, charges=charges, cell=cell ) else: positions, charges = self.cscl_system(periodic) diff --git a/tests/lib/test_potentials.py b/tests/lib/test_potentials.py index 46527b9a..99b2da4a 100644 --- a/tests/lib/test_potentials.py +++ b/tests/lib/test_potentials.py @@ -226,8 +226,6 @@ def test_lr_value_at_zero(exponent, smearing): potential_close_to_zero = ipl.potential_lr_from_dist(dist_small, smearing=smearing) # Compare to - exact_value = ( - 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) - ) + exact_value = 1.0 / (2 * smearing**2) ** (exponent / 2) / gamma(exponent / 2 + 1.0) relerr = torch.abs(potential_close_to_zero - exact_value) / exact_value assert relerr.item() < 3e-14 diff --git a/tests/metatensor/test_calculators.py b/tests/metatensor/test_calculators.py new file mode 100644 index 00000000..09d95fa2 --- /dev/null +++ b/tests/metatensor/test_calculators.py @@ -0,0 +1,99 @@ +""" +Madelung tests +""" + +import pytest +import torch +from packaging import version + + +meshlode_metatensor = pytest.importorskip("meshlode.metatensor") +mts_torch = pytest.importorskip("metatensor.torch") +mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") + + +ATOMIC_SMEARING = 0.1 +LR_WAVELENGTH = ATOMIC_SMEARING / 4 +MESH_SPACING = ATOMIC_SMEARING / 4 +INTERPOLATION_ORDER = 2 +SUBTRACT_SELF = True + + +@pytest.mark.parametrize( + "CalculatorClass, params", + [ + (meshlode_metatensor.DirectPotential, {}), + ( + meshlode_metatensor.EwaldPotential, + { + "atomic_smearing": ATOMIC_SMEARING, + "lr_wavelength": LR_WAVELENGTH, + "subtract_self": SUBTRACT_SELF, + }, + ), + ( + meshlode_metatensor.PMEPotential, + { + "atomic_smearing": ATOMIC_SMEARING, + "mesh_spacing": MESH_SPACING, + "interpolation_order": INTERPOLATION_ORDER, + "subtract_self": SUBTRACT_SELF, + }, + ), + ], +) +class TestWorkflow: + def cscl_system(self): + """CsCl crystal. Same as in the madelung test""" + + system = mts_atomistic.System( + types=torch.tensor([17, 55]), + positions=torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]), + cell=torch.eye(3), + ) + + data = mts_torch.TensorBlock( + values=torch.tensor([-1.0, 1.0]).reshape(-1, 1), + samples=mts_torch.Labels.range("atom", len(system)), + components=[], + properties=mts_torch.Labels("charge", torch.tensor([[0]])), + ) + system.add_data(name="charges", data=data) + + return system + + def calculator(self, CalculatorClass, params): + return CalculatorClass(**params) + + def test_forward(self, CalculatorClass, params): + calculator = self.calculator(CalculatorClass, params) + descriptor_compute = calculator.compute(self.cscl_system()) + descriptor_forward = calculator.forward(self.cscl_system()) + + assert isinstance(descriptor_compute, torch.ScriptObject) + assert isinstance(descriptor_forward, torch.ScriptObject) + if version.parse(torch.__version__) >= version.parse("2.1"): + assert descriptor_compute._type().name() == "TensorMap" + assert descriptor_forward._type().name() == "TensorMap" + + assert mts_torch.equal(descriptor_forward, descriptor_compute) + + # Make sure that the calculators are computing the features without raising errors, + # and returns the correct output format (TensorMap) + def check_operation(self, CalculatorClass, params): + calculator = self.calculator(CalculatorClass, params) + descriptor = calculator.compute(self.cscl_system()) + + assert isinstance(descriptor, torch.ScriptObject) + if version.parse(torch.__version__) >= version.parse("2.1"): + assert descriptor._type().name() == "TensorMap" + + # Run the above test as a normal python script + def test_operation_as_python(self, CalculatorClass, params): + self.check_operation(CalculatorClass, params) + + # Similar to the above, but also testing that the code can be compiled as a torch + # script + # def test_operation_as_torch_script(self, CalculatorClass, params): + # scripted = torch.jit.script(CalculatorClass, params) + # self.check_operation(scripted) diff --git a/tests/metatensor/test_madelung.py b/tests/metatensor/test_madelung.py deleted file mode 100644 index cb855f68..00000000 --- a/tests/metatensor/test_madelung.py +++ /dev/null @@ -1,228 +0,0 @@ -""" -Madelung tests -""" - -import pytest -import torch -from torch.testing import assert_close - - -meshlode_metatensor = pytest.importorskip("meshlode.metatensor") -mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") - - -class TestMadelung: - """ - Test features computed in PMEPotential correspond to the "electrostatic" potential - of the structures. We thus compare the computed potential against the known exact - values for some simple crystal structures. - """ - - scaling_factors = torch.tensor([0.5, 1.2, 3.3]) - crystal_list = ["NaCl", "CsCl", "ZnS", "ZnSO4"] - crystal_list_powers_of_2 = ["NaCl", "CsCl", "ZnS"] - - @pytest.fixture - def crystal_dictionary(self): - """ - Define the parameters of the three binary crystal structures: - NaCl, CsCl and ZnCl. The reference values of the Madelung - constants is taken from the book "Solid State Physics" - by Ashcroft and Mermin. - - Note: Symbols and charges keys have to be sorted according to their - atomic number in ascending alternating order! For an example see - ZnS04 in the wurtzite structure. - """ - # Initialize dictionary for crystal paramaters - d = {k: {} for k in self.crystal_list} - SQRT3 = torch.sqrt(torch.tensor(3)) - - # NaCl structure - # Using a primitive unit cell, the distance between the - # closest Na-Cl pair is exactly 1. The cubic unit cell - # in these units would have a length of 2. - d["NaCl"]["symbols"] = ["Na", "Cl"] - d["NaCl"]["types"] = torch.tensor([11, 17]) - d["NaCl"]["charges"] = torch.tensor([[1.0, -1]]).T - d["NaCl"]["positions"] = torch.tensor([[0, 0, 0], [1.0, 0, 0]]) - d["NaCl"]["cell"] = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]]) - d["NaCl"]["madelung"] = 1.7476 - - # CsCl structure - # This structure is simple since the primitive unit cell - # is just the usual cubic cell with side length set to one. - # The closest Cs-Cl distance is sqrt(3)/2. We thus divide - # the Madelung constant by this value to match the reference. - d["CsCl"]["symbols"] = ["Cs", "Cl"] - d["CsCl"]["types"] = torch.tensor([55, 17]) - d["CsCl"]["charges"] = torch.tensor([[1.0, -1]]).T - d["CsCl"]["positions"] = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - d["CsCl"]["cell"] = torch.eye(3) - d["CsCl"]["madelung"] = 2 * 1.7626 / SQRT3 - - # ZnS (zincblende) structure - # As for NaCl, a primitive unit cell is used which makes - # the lattice parameter of the cubic cell equal to 2. - # In these units, the closest Zn-S distance is sqrt(3)/2. - # We thus divide the Madelung constant by this value. - # If, on the other han_pylode_without_centerd, we set the lattice constant of - # the cubic cell equal to 1, the Zn-S distance is sqrt(3)/4. - d["ZnS"]["symbols"] = ["S", "Zn"] - d["ZnS"]["types"] = torch.tensor([16, 30]) - d["ZnS"]["charges"] = torch.tensor([[1.0, -1]]).T - d["ZnS"]["positions"] = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - d["ZnS"]["cell"] = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]]) - d["ZnS"]["madelung"] = 2 * 1.6381 / SQRT3 - - # ZnS (O4) in wurtzite structure (triclinic cell) - u = torch.tensor([3 / 8]) - c = torch.sqrt(1 / u) - d["ZnSO4"]["symbols"] = ["S", "Zn", "S", "Zn"] - d["ZnSO4"]["types"] = torch.tensor([16, 30, 16, 30]) - d["ZnSO4"]["charges"] = torch.tensor([[1.0, -1, 1, -1]]).T - d["ZnSO4"]["positions"] = torch.tensor( - [ - [0.5, 0.5 / SQRT3, 0.0], - [0.5, 0.5 / SQRT3, u * c], - [0.5, -0.5 / SQRT3, 0.5 * c], - [0.5, -0.5 / SQRT3, (0.5 + u) * c], - ] - ) - d["ZnSO4"]["cell"] = torch.tensor( - [[0.5, -0.5 * SQRT3, 0], [0.5, 0.5 * SQRT3, 0], [0, 0, c]] - ) - - d["ZnSO4"]["madelung"] = 1.6413 / (u * c) - - return d - - @pytest.mark.parametrize("crystal_name", crystal_list_powers_of_2) - @pytest.mark.parametrize("atomic_smearing", [0.1, 0.05]) - @pytest.mark.parametrize("interpolation_order", [1, 2]) - @pytest.mark.parametrize("scaling_factor", scaling_factors) - def test_madelung_low_order( - self, - crystal_dictionary, - crystal_name, - atomic_smearing, - scaling_factor, - interpolation_order, - ): - """ - For low interpolation orders, if the atoms already lie exactly on a mesh point, - there are no additional errors due to atomic_smearing the charges. Thus, we can - reach a relatively high accuracy. - """ - dic = crystal_dictionary[crystal_name] - positions = dic["positions"] * scaling_factor - cell = dic["cell"] * scaling_factor - charges = dic["charges"] - madelung = dic["madelung"] / scaling_factor - mesh_spacing = atomic_smearing / 2 * scaling_factor - smearing_eff = atomic_smearing * scaling_factor - MP = meshlode_metatensor.PMEPotential( - atomic_smearing=smearing_eff, - mesh_spacing=mesh_spacing, - interpolation_order=interpolation_order, - subtract_self=True, - ) - potentials_mesh = MP._compute_single_system( - positions=positions, - charges=charges, - cell=cell, - neighbor_indices=None, - neighbor_shifts=None, - ) - energies = potentials_mesh * charges - energies_target = -torch.ones_like(energies) * madelung - assert_close(energies, energies_target, rtol=1e-4, atol=1e-6) - - @pytest.mark.parametrize("crystal_name", crystal_list) - @pytest.mark.parametrize("atomic_smearing", [0.2, 0.12]) - @pytest.mark.parametrize("interpolation_order", [3, 4, 5]) - @pytest.mark.parametrize("scaling_factor", scaling_factors) - def test_madelung_high_order( - self, - crystal_dictionary, - crystal_name, - atomic_smearing, - scaling_factor, - interpolation_order, - ): - """ - For high interpolation order, the current naive implementation used to subtract - the center contribution introduces additional errors since an atom is smeared - onto multiple mesh points, turning the short-range correction into a more - complicated expression that has not yet been implemented. Thus, we use a much - larger tolerance of 1e-2 for the precision needed in the calculation. - """ - dic = crystal_dictionary[crystal_name] - positions = dic["positions"] * scaling_factor - cell = dic["cell"] * scaling_factor - charges = dic["charges"] - madelung = dic["madelung"] / scaling_factor - mesh_spacing = atomic_smearing / 10 * scaling_factor - smearing_eff = atomic_smearing * scaling_factor - MP = meshlode_metatensor.PMEPotential( - atomic_smearing=smearing_eff, - mesh_spacing=mesh_spacing, - interpolation_order=interpolation_order, - subtract_self=True, - ) - potentials_mesh = MP._compute_single_system( - positions=positions, - charges=charges, - cell=cell, - neighbor_indices=None, - neighbor_shifts=None, - ) - energies = potentials_mesh * charges - energies_target = -torch.ones_like(energies) * madelung - assert_close(energies, energies_target, rtol=1e-2, atol=1e-3) - - @pytest.mark.parametrize("crystal_name", crystal_list_powers_of_2) - @pytest.mark.parametrize("atomic_smearing", [0.1, 0.05]) - @pytest.mark.parametrize("interpolation_order", [1, 2]) - @pytest.mark.parametrize("scaling_factor", scaling_factors) - def test_madelung_low_order_metatensor( - self, - crystal_dictionary, - crystal_name, - atomic_smearing, - scaling_factor, - interpolation_order, - ): - """ - Same test as above but now using the main compute function of the class that is - actually facing the user and outputting in metatensor format. - """ - dic = crystal_dictionary[crystal_name] - positions = dic["positions"] * scaling_factor - cell = dic["cell"] * scaling_factor - types = dic["types"] - charges = dic["charges"] - madelung = dic["madelung"] / scaling_factor - mesh_spacing = atomic_smearing / 2 * scaling_factor - smearing_eff = atomic_smearing * scaling_factor - n_atoms = len(positions) - system = mts_atomistic.System(types=types, positions=positions, cell=cell) - MP = meshlode_metatensor.PMEPotential( - atomic_smearing=smearing_eff, - mesh_spacing=mesh_spacing, - interpolation_order=interpolation_order, - subtract_self=True, - ) - potentials_mesh = MP.compute(system) - - # Compute the actual potential from the features - energies = torch.zeros((n_atoms, 1)) - for idx_c, c in enumerate(types): - for idx_n, n in enumerate(types): - block = potentials_mesh.block( - {"center_type": int(c), "neighbor_type": int(n)} - ) - energies[idx_c] += charges[idx_c] * charges[idx_n] * block.values[0, 0] - - energies_ref = -madelung * torch.ones((n_atoms, 1)) - assert_close(energies, energies_ref, rtol=1e-4, atol=1e-6) diff --git a/tox.ini b/tox.ini index b47f295d..6f87f465 100644 --- a/tox.ini +++ b/tox.ini @@ -4,15 +4,28 @@ envlist = build tests + + + +[testenv] +passenv = * lint_folders = "{toxinidir}/src" \ "{toxinidir}/tests" \ "{toxinidir}/docs/src/" \ "{toxinidir}/examples" - -[testenv] -passenv = * -test_options = --cov --cov-append --cov-report= --import-mode=append +warning_options = \ + -W "ignore:ast.Str is deprecated and will be removed in Python 3.14:DeprecationWarning" \ + -W "ignore:Attribute s is deprecated and will be removed in Python 3.14:DeprecationWarning" \ + -W "ignore:ast.NameConstant is deprecated and will be removed in Python 3.14:DeprecationWarning" +# the "-W ignore" flags above are for PyTorch, which triggers a bunch of +# internal warnings with Python 3.12 +test_options = \ + --cov \ + --cov-append \ + --cov-report= \ + --import-mode=append \ + {[testenv]warning_options} [testenv:build] description = Asserts package build integrity. @@ -72,13 +85,16 @@ deps = isort sphinx-lint commands = - flake8 {[tox]lint_folders} - black --check --diff {[tox]lint_folders} - blackdoc --check --diff {[tox]lint_folders} "{toxinidir}/README.rst" - isort --check-only --diff {[tox]lint_folders} - mypy src/meshlode - sphinx-lint --enable line-too-long --max-line-length 88 \ - -i {[tox]lint_folders} "{toxinidir}/README.rst" + flake8 {[testenv]lint_folders} + black --check --diff {[testenv]lint_folders} + blackdoc --check --diff {[testenv]lint_folders} "{toxinidir}/README.rst" + isort --check-only --diff {[testenv]lint_folders} + mypy {[testenv]lint_folders} + sphinx-lint \ + --enable all \ + --max-line-length 88 \ + -i "{toxinidir}/docs/src/examples" \ + {[testenv]lint_folders} "{toxinidir}/README.rst" [testenv:format] description = Abuse tox to do actual formatting on all files. @@ -88,9 +104,9 @@ deps = blackdoc isort commands = - black {[tox]lint_folders} - blackdoc {[tox]lint_folders} "{toxinidir}/README.rst" - isort {[tox]lint_folders} + black {[testenv]lint_folders} + blackdoc {[testenv]lint_folders} "{toxinidir}/README.rst" + isort {[testenv]lint_folders} [testenv:docs] description = Building the package documentation. From ef5be851ef0d26b1df7b493be051d87945ebc7f2 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Wed, 10 Jul 2024 11:56:08 +0200 Subject: [PATCH 31/35] add docs --- docs/src/references/calculators/index.rst | 2 - docs/src/references/lib/index.rst | 4 +- docs/src/references/lib/kvectors.rst | 6 + .../references/metatensor/directpotential.rst | 6 + .../references/metatensor/ewaldpotential.rst | 6 + .../{meshpotential.rst => pmepotential.rst} | 0 examples/madelung.py | 152 ------------------ examples/neighborlist_example.py | 61 ++++--- src/meshlode/calculators/base.py | 8 +- src/meshlode/calculators/directpotential.py | 13 +- src/meshlode/calculators/ewaldpotential.py | 11 +- src/meshlode/calculators/pmepotential.py | 11 +- src/meshlode/lib/potentials.py | 11 +- src/meshlode/metatensor/__init__.py | 6 +- .../metatensor/{calculators.py => base.py} | 88 ++-------- src/meshlode/metatensor/directpotential.py | 55 +++++++ src/meshlode/metatensor/ewaldpotential.py | 75 +++++++++ src/meshlode/metatensor/pmepotential.py | 77 +++++++++ tests/calculators/test_base.py | 7 +- tests/{init.py => test_init.py} | 0 tox.ini | 7 +- 21 files changed, 316 insertions(+), 290 deletions(-) create mode 100644 docs/src/references/lib/kvectors.rst create mode 100644 docs/src/references/metatensor/directpotential.rst create mode 100644 docs/src/references/metatensor/ewaldpotential.rst rename docs/src/references/metatensor/{meshpotential.rst => pmepotential.rst} (100%) delete mode 100644 examples/madelung.py rename src/meshlode/metatensor/{calculators.py => base.py} (63%) create mode 100644 src/meshlode/metatensor/directpotential.py create mode 100644 src/meshlode/metatensor/ewaldpotential.py create mode 100644 src/meshlode/metatensor/pmepotential.py rename tests/{init.py => test_init.py} (100%) diff --git a/docs/src/references/calculators/index.rst b/docs/src/references/calculators/index.rst index fc7113ea..9a40532d 100644 --- a/docs/src/references/calculators/index.rst +++ b/docs/src/references/calculators/index.rst @@ -14,8 +14,6 @@ Calculators return the representations as a :py:obj:`List` of :py:class:`torch.T We also provide a return values as a :py:class:`metatensor.TensorMap` in :ref:`metatensor`. -.. automodule:: meshlode.calculators - .. toctree:: :maxdepth: 1 :glob: diff --git a/docs/src/references/lib/index.rst b/docs/src/references/lib/index.rst index fb2dc6c3..84d99650 100644 --- a/docs/src/references/lib/index.rst +++ b/docs/src/references/lib/index.rst @@ -6,6 +6,6 @@ are used for the meshLODE calculators. .. toctree:: :maxdepth: 1 + :glob: - fourier_convolution - mesh_interpolator + ./* diff --git a/docs/src/references/lib/kvectors.rst b/docs/src/references/lib/kvectors.rst new file mode 100644 index 00000000..8cd041b2 --- /dev/null +++ b/docs/src/references/lib/kvectors.rst @@ -0,0 +1,6 @@ +Kvectors +======== + +.. automodule:: meshlode.lib.kvectors + :members: + :undoc-members: diff --git a/docs/src/references/metatensor/directpotential.rst b/docs/src/references/metatensor/directpotential.rst new file mode 100644 index 00000000..f0d52897 --- /dev/null +++ b/docs/src/references/metatensor/directpotential.rst @@ -0,0 +1,6 @@ +DirectPotential +############### + +.. autoclass:: meshlode.metatensor.DirectPotential + :members: + :undoc-members: diff --git a/docs/src/references/metatensor/ewaldpotential.rst b/docs/src/references/metatensor/ewaldpotential.rst new file mode 100644 index 00000000..12926e9a --- /dev/null +++ b/docs/src/references/metatensor/ewaldpotential.rst @@ -0,0 +1,6 @@ +EwaldPotential +############## + +.. autoclass:: meshlode.metatensor.EwaldPotential + :members: + :undoc-members: diff --git a/docs/src/references/metatensor/meshpotential.rst b/docs/src/references/metatensor/pmepotential.rst similarity index 100% rename from docs/src/references/metatensor/meshpotential.rst rename to docs/src/references/metatensor/pmepotential.rst diff --git a/examples/madelung.py b/examples/madelung.py deleted file mode 100644 index b9785f4c..00000000 --- a/examples/madelung.py +++ /dev/null @@ -1,152 +0,0 @@ -""" -Compute Madelung Constants -========================== -In this tutorial we show how to calculate the Madelung constants and total electrostatic -energy of atomic structures using the :py:class:`meshlode.PMEPotential` and -:py:class:`meshlode.metatensor.PMEPotential` calculator. -""" - -# %% -import math - -import torch -from metatensor.torch import Labels, TensorBlock -from metatensor.torch.atomistic import System - -import meshlode - - -# %% -# Define simple example structure having the CsCl structure and compute the reference -# values. PMEPotential by default outputs the types sorted according to the atomic -# number. Thus, we input the compound "CsCl" and "ClCs" since Cl and Cs have atomic -# numbers 17 and 55, respectively. -types = torch.tensor([17, 55]) # Cl and Cs -positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) -charges = torch.tensor([-1.0, 1.0]).reshape(-1, 1) -cell = torch.eye(3) - -# %% -# Define the expected values of the energy -n_atoms = len(positions) -madelung = 2 * 1.7626 / math.sqrt(3) -energies_ref = -madelung * torch.ones((n_atoms, 1)) - -# %% -# We first define general parameters for our calculation MeshLODE - -atomic_smearing = 0.1 -cell = torch.eye(3) -mesh_spacing = atomic_smearing / 4 -interpolation_order = 2 - -# %% -# Computation using ``meshlode`` -# ------------------------------ -# Compute features using - -pme = meshlode.PMEPotential( - atomic_smearing=atomic_smearing, - mesh_spacing=mesh_spacing, - interpolation_order=interpolation_order, - subtract_self=True, -) -potentials_torch: torch.Tensor = pme.compute( - positions=positions, charges=charges, cell=cell -) - -# %% -# The "potentials" that have been computed so far are not the actual electrostatic -# potentials. For instance, for the Cs atom, we are separately storing the contributions -# to the potential (at the location of the Cs atom) from the Cs atoms and Cl atoms -# separately. Thus, to get the Madelung constant, we need to take a linear combination -# of these "potentials" weighted by the charges of the atoms. - -atomic_energies_torch = torch.zeros((n_atoms, 1)) -for idx_c in range(n_atoms): - for idx_n in range(n_atoms): - # The coulomb potential between atoms i and j is charge_i * charge_j / d_ij - # The features are simply computing a pure 1/r potential with no prefactors. - # Thus, to compute the energy between atoms of types i and j, we need to - # multiply by the charges of i and j. - print(charges[idx_c] * charges[idx_n], potentials_torch[idx_n, idx_c]) - atomic_energies_torch[idx_c] += ( - charges[idx_c] * charges[idx_n] * potentials_torch[idx_c, idx_n] - ) - -# %% -# The total energy is just the sum of all atomic energies -total_energy_torch = torch.sum(atomic_energies_torch) - -# %% -# Compare against reference Madelung constant and reference energy: -print("Using the torch version") -print(f"Computed energies on each atom = {atomic_energies_torch.tolist()}") -print(f"Reference Madelung constant = {madelung:.3f}") -print(f"Total energy = {total_energy_torch:.3f}\n") - - -# %% -# Computation using ``meshlode.metatensor`` -# ----------------------------------------- -# We now compute the same constants using the metatensor based calculator. To achieve -# this we first store our system parameters like the ``types``, ``positions`` and the -# ``cell`` defined above into a :py:class:`metatensor.torch.atomistic.System` class. - -system = System(types=types, positions=positions, cell=cell) - -# %% -# Attach charges to the system. - -data = TensorBlock( - values=charges, - samples=Labels.range("atom", len(system)), - components=[], - properties=Labels("charge", torch.tensor([[0]])), -) -system.add_data(name="charges", data=data) - - -# %% -# Perform the calculation. - -pme = meshlode.metatensor.PMEPotential( - atomic_smearing=atomic_smearing, - mesh_spacing=mesh_spacing, - interpolation_order=interpolation_order, - subtract_self=True, -) -potential_metatensor = pme.compute(system) - - -# %% -# To get the Madelung constant, we again need to take a linear combination -# of the "potentials" weighted by the charges of the atoms. - -atomic_energies_metatensor = torch.zeros((n_atoms, 1)) -for idx_c, c in enumerate(types): - for idx_n, n in enumerate(types): - # Take the coefficients with the correct center atom and neighbor atom types - block = potential_metatensor.block( - {"center_type": int(c), "neighbor_type": int(n)} - ) - - # The coulomb potential between atoms i and j is charge_i * charge_j / d_ij - # The features are simply computing a pure 1/r potential with no prefactors. - # Thus, to compute the energy between atoms of types i and j, we need to - # multiply by the charges of i and j. - print(c, n, charges[idx_c] * charges[idx_n], block.values[0, 0]) - atomic_energies_metatensor[idx_c] += ( - charges[idx_c] * charges[idx_n] * block.values[0, 0] - ) - -# %% -# The total energy is just the sum of all atomic energies -total_energy_metatensor = torch.sum(atomic_energies_metatensor) - -# %% -# Compare against reference Madelung constant and reference energy: -print("Using the metatensor version") -print(f"Computed energies on each atom = {atomic_energies_metatensor.tolist()}") -print(f"Reference Madelung constant = {madelung:.3f}") -print(f"Total energy = {total_energy_metatensor:.3f}") diff --git a/examples/neighborlist_example.py b/examples/neighborlist_example.py index 6a6d4c7c..e4992f01 100644 --- a/examples/neighborlist_example.py +++ b/examples/neighborlist_example.py @@ -3,7 +3,7 @@ ========================================= This example will explain how to use the metatensor branch of Meshlode with an attached -neighborlist to a :py:class:`metatensor.torch.atomistic.System` object. +neighborlist to an :py:class:`metatensor.torch.atomistic.System` object. """ # %% @@ -28,7 +28,7 @@ types = torch.tensor([17, 55]) # Cl and Cs positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) -charges = torch.tensor([-1.0, 1.0]) +charges = torch.tensor([-1.0, 1.0]).reshape(-1, 1) cell = torch.eye(3) # %% @@ -77,52 +77,47 @@ # %% -# Attach ``neighbors`` to ``system`` object. +# Define the system. system = System(types=types, positions=positions, cell=cell) +# %% +# Attach charges to the system. + +data = TensorBlock( + values=charges, + samples=Labels.range("atom", len(system)), + components=[], + properties=Labels("charge", torch.tensor([[0]])), +) +system.add_data(name="charges", data=data) + +# %% +# Attach ``neighbors`` to ``system`` object. + nl_options = NeighborListOptions(cutoff=sr_cutoff, full_list=True) system.add_neighbor_list(options=nl_options, neighbors=neighbors) -MP = meshlode.metatensor.PMEPotential( +pme = meshlode.metatensor.PMEPotential( atomic_smearing=atomic_smearing, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, subtract_self=True, sr_cutoff=sr_cutoff, ) -potential_metatensor = MP.compute(system) - - -# %% -# Convert to Madelung constant and check that the value is correct - -atomic_energies_metatensor = torch.zeros((n_atoms, 1)) -for idx_c, c in enumerate(types): - for idx_n, n in enumerate(types): - # Take the coefficients with the correct center atom and neighbor atom types - block = potential_metatensor.block( - {"center_type": int(c), "neighbor_type": int(n)} - ) - - # The coulomb potential between atoms i and j is charge_i * charge_j / d_ij - # The features are simply computing a pure 1/r potential with no prefactors. - # Thus, to compute the energy between atoms of types i and j, we need to - # multiply by the charges of i and j. - print(c, n, charges[idx_c] * charges[idx_n], block.values[0, 0]) - atomic_energies_metatensor[idx_c] += ( - charges[idx_c] * charges[idx_n] * block.values[0, 0] - ) +potential = pme.compute(system) # %% # The total energy is just the sum of all atomic energies -total_energy_metatensor = torch.sum(atomic_energies_metatensor) +print(potential) -# %% -# Compare against reference Madelung constant and reference energy: +# total_energy_metatensor = torch.sum(potential[0].values) + +# # %% +# # Compare against reference Madelung constant and reference energy: -print("Using the metatensor version") -print(f"Computed energies on each atom = {atomic_energies_metatensor.tolist()}") -print(f"Reference Madelung constant = {madelung:.3f}") -print(f"Total energy = {total_energy_metatensor:.3f}") +# print("Using the metatensor version") +# print(f"Computed energies on each atom = {potential[0].values.tolist()}") +# print(f"Reference Madelung constant = {madelung:.3f}") +# print(f"Total energy = {total_energy_metatensor[0].values}") diff --git a/src/meshlode/calculators/base.py b/src/meshlode/calculators/base.py index 0201d458..830015d1 100644 --- a/src/meshlode/calculators/base.py +++ b/src/meshlode/calculators/base.py @@ -26,7 +26,7 @@ def _compute_sr( positions: torch.Tensor, charges: torch.Tensor, cell: torch.Tensor, - smearing: torch.Tensor, + smearing: float, sr_cutoff: torch.Tensor, neighbor_indices: Optional[torch.Tensor] = None, neighbor_shifts: Optional[torch.Tensor] = None, @@ -66,12 +66,10 @@ def _compute_sr( atom_is = torch.tensor(atom_is) atom_js = torch.tensor(atom_js) shifts = torch.tensor(neighbor_shifts, dtype=cell.dtype) # N x 3 - else: atom_is = neighbor_indices[0] atom_js = neighbor_indices[1] - shifts = neighbor_shifts.T - shifts.dtype = cell.dtype + shifts = neighbor_shifts.type(cell.dtype).T # Compute energy potential = torch.zeros_like(charges) @@ -100,7 +98,7 @@ class CalculatorBaseTorch(CalculatorBase): """ Base calculator for the torch interface to MeshLODE. - :param exponent: the exponent "p" in 1/r^p potentials + :param exponent: the exponent :math:`p` in :math:`1/r^p` potentials """ def __init__( diff --git a/src/meshlode/calculators/directpotential.py b/src/meshlode/calculators/directpotential.py index d242f1da..26207910 100644 --- a/src/meshlode/calculators/directpotential.py +++ b/src/meshlode/calculators/directpotential.py @@ -52,23 +52,28 @@ class DirectPotential(CalculatorBaseTorch, _DirectPotentialImpl): infinitely extended three-dimensional Euclidean space. While slow, this implementation used as a reference to test faster algorithms. - :param exponent: the exponent "p" in 1/r^p potentials + :param exponent: the exponent :math:`p` in :math:`1/r^p` potentials Example ------- + We compute the energy of two charges which are sepearated by 2 along the z-axis. + >>> import torch Define simple example structure - >>> positions = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) + >>> positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]) >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) Compute features >>> direct = DirectPotential() >>> direct.compute(positions=positions, charges=charges) - tensor([[-1.1547], - [ 1.1547]]) + tensor([[-0.5000], + [ 0.5000]]) + + Which is the expected potential since :math:`V \propto 1/r` where :math:`r` is the + distance between the particles. """ def __init__(self, exponent: float = 1.0): diff --git a/src/meshlode/calculators/ewaldpotential.py b/src/meshlode/calculators/ewaldpotential.py index be2aaf2d..20bf877e 100644 --- a/src/meshlode/calculators/ewaldpotential.py +++ b/src/meshlode/calculators/ewaldpotential.py @@ -167,7 +167,7 @@ class EwaldPotential(CalculatorBaseTorch, _EwaldPotentialImpl): Scaling as :math:`\mathcal{O}(N^2)` with respect to the number of particles :math:`N`. - :param exponent: the exponent "p" in 1/r^p potentials + :param exponent: the exponent :math:`p` in :math:`1/r^p` potentials :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If not set to a global value, it will be set to be half of the shortest lattice vector defining the cell (separately for each structure). @@ -191,20 +191,25 @@ class EwaldPotential(CalculatorBaseTorch, _EwaldPotentialImpl): Example ------- + We calculate the Madelung constant of a CsCl (Cesium-Chloride) crystal. The + reference value is :math:`2 \cdot 1.7626 / \sqrt{3} \approx 2.0354`. + >>> import torch - Define simple example structure having the CsCl (Cesium-Chloride) structure + Define crystal structure >>> positions = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> cell = torch.eye(3) - Compute features + Compute the potential >>> ewald = EwaldPotential() >>> ewald.compute(positions=positions, charges=charges, cell=cell) tensor([[-2.0354], [ 2.0354]]) + + Which is the same as the reference value given above. """ def __init__( diff --git a/src/meshlode/calculators/pmepotential.py b/src/meshlode/calculators/pmepotential.py index 637dfe8f..22d7a06b 100644 --- a/src/meshlode/calculators/pmepotential.py +++ b/src/meshlode/calculators/pmepotential.py @@ -173,7 +173,7 @@ class PMEPotential(CalculatorBaseTorch, _PMEPotentialImpl): subset of a whole dataset and it required to keep the shape of the output consistent. If this is not set the possible atomic types will be determined when calling the :meth:`compute()`. - :param exponent: the exponent "p" in 1/r^p potentials + :param exponent: the exponent :math:`p` in :math:`1/r^p` potentials :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If not set to a global value, it will be set to be half of the shortest lattice vector defining the cell (separately for each structure). @@ -198,20 +198,25 @@ class PMEPotential(CalculatorBaseTorch, _PMEPotentialImpl): Example ------- + We calculate the Madelung constant of a CsCl (Cesium-Chloride) crystal. The + reference value is :math:`2 \cdot 1.7626 / \sqrt{3} \approx 2.0354`. + >>> import torch - Define simple example structure having the CsCl (Cesium-Chloride) structure + Define crystal structure >>> positions = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> cell = torch.eye(3) - Compute features + Compute the potential >>> pme = PMEPotential() >>> pme.compute(positions=positions, charges=charges, cell=cell) tensor([[-2.0384], [ 2.0384]]) + + Which is the close the reference value given above. """ def __init__( diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py index 83789261..63be3e05 100644 --- a/src/meshlode/lib/potentials.py +++ b/src/meshlode/lib/potentials.py @@ -1,3 +1,5 @@ +from typing import Union + import math import torch @@ -22,11 +24,14 @@ class InversePowerLawPotential: length-scale parameter (called "smearing" in the code) 3. the Fourier transform of the LR part - :param exponent: the exponent "p" in 1/r^p potentials + :param exponent: the exponent :math:`p` in :math:`1/r^p` potentials """ - def __init__(self, exponent: float): - self.exponent = torch.tensor(exponent) + def __init__(self, exponent: Union[float, torch.Tensor]): + if type(exponent) is float: + self.exponent = torch.tensor(exponent) + else: + self.exponent = exponent def potential_from_dist(self, dist: torch.Tensor) -> torch.Tensor: """ diff --git a/src/meshlode/metatensor/__init__.py b/src/meshlode/metatensor/__init__.py index f52bf854..c83447c8 100644 --- a/src/meshlode/metatensor/__init__.py +++ b/src/meshlode/metatensor/__init__.py @@ -1,3 +1,5 @@ -from .calculators import PMEPotential, EwaldPotential, DirectPotential +from .ewaldpotential import EwaldPotential +from .directpotential import DirectPotential +from .pmepotential import PMEPotential -__all__ = ["DirectPotential", "EwaldPotential", "PMEPotential"] +__all__ = ["EwaldPotential", "DirectPotential", "PMEPotential"] diff --git a/src/meshlode/metatensor/calculators.py b/src/meshlode/metatensor/base.py similarity index 63% rename from src/meshlode/metatensor/calculators.py rename to src/meshlode/metatensor/base.py index 35d70767..6690caf1 100644 --- a/src/meshlode/metatensor/calculators.py +++ b/src/meshlode/metatensor/base.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Union +import warnings +from typing import List, Union import torch @@ -13,9 +14,6 @@ ) from ..calculators.base import CalculatorBase -from ..calculators.directpotential import _DirectPotentialImpl -from ..calculators.ewaldpotential import _EwaldPotentialImpl -from ..calculators.pmepotential import _PMEPotentialImpl class CalculatorBaseMetatensor(CalculatorBase): @@ -52,7 +50,12 @@ def _validate_compute_parameters( if not torch.all(has_charges): raise ValueError("`systems` do not consistently contain `charges` data") - self._n_charges_channels = systems[0].get_data("charges").values.shape[1] + # Metatensor will issue a warning because `charges` are not a default member of + # a System object + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self._n_charges_channels = systems[0].get_data("charges").values.shape[1] + for i_system, system in enumerate(systems): n_channels = system.get_data("charges").values.shape[1] if n_channels != self._n_charges_channels: @@ -86,7 +89,9 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: potentials: List[torch.Tensor] = [] for system in systems: - charges = system.get_data("charges").values + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + charges = system.get_data("charges").values # try to extract neighbor list from system object neighbor_indices = None @@ -98,8 +103,8 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: ): neighbor_list = system.get_neighbor_list(neighbor_list_options) - neighbor_indices = neighbor_list.samples.values[:, :2] - neighbor_shifts = neighbor_list.samples.values[:, 2:] + neighbor_indices = neighbor_list.samples.values[:, :2].T + neighbor_shifts = neighbor_list.samples.values[:, 2:].T break @@ -128,70 +133,3 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: ) return TensorMap(keys=Labels.single(), blocks=[block]) - - -class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl): - """Specie-wise long-range potential using a direct summation over all atoms. - - Refer to :class:`meshlode.DirectPotential` for parameter documentation. - """ - - def __init__(self, exponent: float = 1.0): - _DirectPotentialImpl.__init__(self, exponent=exponent) - CalculatorBaseMetatensor.__init__(self, exponent=exponent) - - -class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl): - """Specie-wise long-range potential computed using the Ewald sum. - - Refer to :class:`meshlode.EwaldPotential` for parameter documentation. - """ - - def __init__( - self, - exponent: float = 1.0, - sr_cutoff: Optional[torch.Tensor] = None, - atomic_smearing: Optional[float] = None, - lr_wavelength: Optional[float] = None, - subtract_self: Optional[bool] = True, - subtract_interior: Optional[bool] = False, - ): - _EwaldPotentialImpl.__init__( - self, - exponent=exponent, - sr_cutoff=sr_cutoff, - atomic_smearing=atomic_smearing, - lr_wavelength=lr_wavelength, - subtract_self=subtract_self, - subtract_interior=subtract_interior, - ) - CalculatorBaseMetatensor.__init__(self, exponent=exponent) - - -class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): - """Specie-wise long-range potential using a particle mesh-based Ewald (PME). - - Refer to :class:`meshlode.PMEPotential` for parameter documentation. - """ - - def __init__( - self, - exponent: float = 1.0, - sr_cutoff: Optional[torch.Tensor] = None, - atomic_smearing: Optional[float] = None, - mesh_spacing: Optional[float] = None, - interpolation_order: Optional[int] = 3, - subtract_self: Optional[bool] = True, - subtract_interior: Optional[bool] = False, - ): - _PMEPotentialImpl.__init__( - self, - exponent=exponent, - sr_cutoff=sr_cutoff, - atomic_smearing=atomic_smearing, - mesh_spacing=mesh_spacing, - interpolation_order=interpolation_order, - subtract_self=subtract_self, - subtract_interior=subtract_interior, - ) - CalculatorBaseMetatensor.__init__(self, exponent=exponent) diff --git a/src/meshlode/metatensor/directpotential.py b/src/meshlode/metatensor/directpotential.py new file mode 100644 index 00000000..7d989062 --- /dev/null +++ b/src/meshlode/metatensor/directpotential.py @@ -0,0 +1,55 @@ +from ..calculators.directpotential import _DirectPotentialImpl +from .base import CalculatorBaseMetatensor + + +class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl): + """Specie-wise long-range potential using a direct summation over all atoms. + + Refer to :class:`meshlode.DirectPotential` for parameter documentation. + + Example + ------- + We compute the energy of two charges which are sepearated by 2 along the z-axis. + + >>> import torch + >>> from metatensor.torch import Labels, TensorBlock + >>> from metatensor.torch.atomistic import System + + Define simple example structure + + >>> system = System( + ... types=torch.tensor([1, 1]), + ... positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + ... cell=torch.zeros([3, 3]), + ... ) + + Next we attach the charges to our ``system`` + + >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + >>> data = TensorBlock( + ... values=charges, + ... samples=Labels.range("atom", len(system)), + ... components=[], + ... properties=Labels("charge", torch.tensor([[0]])), + ... ) + >>> system.add_data(name="charges", data=data) + + and compute the potenial + + >>> direct = DirectPotential() + >>> potential = direct.compute(system) + + The results are stored inside the ``values`` property inside the first + :py:class:`TensorBlock ` of the ``potential``. + + >>> potential[0].values + tensor([[-0.5000], + [ 0.5000]]) + + Which is the expected potential since :math:`V \propto 1/r` where :math:`r` is the + distance between the particles. + """ + + def __init__(self, exponent: float = 1.0): + _DirectPotentialImpl.__init__(self, exponent=exponent) + CalculatorBaseMetatensor.__init__(self, exponent=exponent) diff --git a/src/meshlode/metatensor/ewaldpotential.py b/src/meshlode/metatensor/ewaldpotential.py new file mode 100644 index 00000000..a831c954 --- /dev/null +++ b/src/meshlode/metatensor/ewaldpotential.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +from ..calculators.ewaldpotential import _EwaldPotentialImpl +from .base import CalculatorBaseMetatensor + + +class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl): + """Specie-wise long-range potential computed using the Ewald sum. + + Refer to :class:`meshlode.EwaldPotential` for parameter documentation. + + Example + ------- + We calculate the Madelung constant of a CsCl (Cesium-Chloride) crystal. The + reference value is :math:`2 \cdot 1.7626 / \sqrt{3} \approx 2.0354`. + + >>> import torch + >>> from metatensor.torch import Labels, TensorBlock + >>> from metatensor.torch.atomistic import System + + Define simple example structure + + >>> system = System( + ... types=torch.tensor([55, 17]), + ... positions=torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]), + ... cell=torch.eye(3), + ... ) + + Next we attach the charges to our ``system`` + + >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + >>> data = TensorBlock( + ... values=charges, + ... samples=Labels.range("atom", len(system)), + ... components=[], + ... properties=Labels("charge", torch.tensor([[0]])), + ... ) + >>> system.add_data(name="charges", data=data) + + and compute the potenial + + >>> ewald = EwaldPotential() + >>> potential = ewald.compute(system) + + The results are stored inside the ``values`` property inside the first + :py:class:`TensorBlock ` of the ``potential``. + + >>> potential[0].values + tensor([[-2.0354], + [ 2.0354]]) + + Which is the same as the reference value given above. + """ + + def __init__( + self, + exponent: float = 1.0, + sr_cutoff: Optional[torch.Tensor] = None, + atomic_smearing: Optional[float] = None, + lr_wavelength: Optional[float] = None, + subtract_self: Optional[bool] = True, + subtract_interior: Optional[bool] = False, + ): + _EwaldPotentialImpl.__init__( + self, + exponent=exponent, + sr_cutoff=sr_cutoff, + atomic_smearing=atomic_smearing, + lr_wavelength=lr_wavelength, + subtract_self=subtract_self, + subtract_interior=subtract_interior, + ) + CalculatorBaseMetatensor.__init__(self, exponent=exponent) diff --git a/src/meshlode/metatensor/pmepotential.py b/src/meshlode/metatensor/pmepotential.py new file mode 100644 index 00000000..f4a25b4d --- /dev/null +++ b/src/meshlode/metatensor/pmepotential.py @@ -0,0 +1,77 @@ +from typing import Optional + +import torch + +from ..calculators.pmepotential import _PMEPotentialImpl +from .base import CalculatorBaseMetatensor + + +class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): + """Specie-wise long-range potential using a particle mesh-based Ewald (PME). + + Refer to :class:`meshlode.PMEPotential` for parameter documentation. + + Example + ------- + We calculate the Madelung constant of a CsCl (Cesium-Chloride) crystal. The + reference value is :math:`2 \cdot 1.7626 / \sqrt{3} \approx 2.0354`. + + >>> import torch + >>> from metatensor.torch import Labels, TensorBlock + >>> from metatensor.torch.atomistic import System + + Define simple example structure + + >>> system = System( + ... types=torch.tensor([55, 17]), + ... positions=torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]), + ... cell=torch.eye(3), + ... ) + + Next we attach the charges to our ``system`` + + >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + >>> data = TensorBlock( + ... values=charges, + ... samples=Labels.range("atom", len(system)), + ... components=[], + ... properties=Labels("charge", torch.tensor([[0]])), + ... ) + >>> system.add_data(name="charges", data=data) + + and compute the potenial + + >>> pme = PMEPotential() + >>> potential = pme.compute(system) + + The results are stored inside the ``values`` property inside the first + :py:class:`TensorBlock ` of the ``potential``. + + >>> potential[0].values + tensor([[-2.0384], + [ 2.0384]]) + + Which is the same as the reference value given above. + """ + + def __init__( + self, + exponent: float = 1.0, + sr_cutoff: Optional[torch.Tensor] = None, + atomic_smearing: Optional[float] = None, + mesh_spacing: Optional[float] = None, + interpolation_order: Optional[int] = 3, + subtract_self: Optional[bool] = True, + subtract_interior: Optional[bool] = False, + ): + _PMEPotentialImpl.__init__( + self, + exponent=exponent, + sr_cutoff=sr_cutoff, + atomic_smearing=atomic_smearing, + mesh_spacing=mesh_spacing, + interpolation_order=interpolation_order, + subtract_self=subtract_self, + subtract_interior=subtract_interior, + ) + CalculatorBaseMetatensor.__init__(self, exponent=exponent) diff --git a/tests/calculators/test_base.py b/tests/calculators/test_base.py index 09d07fd1..7eb3641a 100644 --- a/tests/calculators/test_base.py +++ b/tests/calculators/test_base.py @@ -1,7 +1,7 @@ import pytest import torch -from meshlode.calculators.base import CalculatorBaseTorch +from meshlode.calculators.base import CalculatorBase, CalculatorBaseTorch # Define some example parameters @@ -15,6 +15,11 @@ CELL_2 = torch.arange(9, dtype=DTYPE, device=DEVICE).reshape((3, 3)) +def test_CalculatorBase(): + calculator = CalculatorBase(exponent=5.0) + assert calculator.exponent == 5.0 + + class CalculatorTest(CalculatorBaseTorch): def compute(self, positions, charges, cell, neighbor_indices, neighbor_shifts): return self._compute_impl( diff --git a/tests/init.py b/tests/test_init.py similarity index 100% rename from tests/init.py rename to tests/test_init.py diff --git a/tox.ini b/tox.ini index 6f87f465..9f533953 100644 --- a/tox.ini +++ b/tox.ini @@ -4,9 +4,6 @@ envlist = build tests - - - [testenv] passenv = * lint_folders = @@ -14,7 +11,7 @@ lint_folders = "{toxinidir}/tests" \ "{toxinidir}/docs/src/" \ "{toxinidir}/examples" -warning_options = \ +warning_options = -W error \ -W "ignore:ast.Str is deprecated and will be removed in Python 3.14:DeprecationWarning" \ -W "ignore:Attribute s is deprecated and will be removed in Python 3.14:DeprecationWarning" \ -W "ignore:ast.NameConstant is deprecated and will be removed in Python 3.14:DeprecationWarning" @@ -58,7 +55,7 @@ commands = pytest {[testenv]test_options} {posargs} # Run documentation tests - pytest --doctest-modules --pyargs meshlode {posargs} + pytest {[testenv]warning_options} --doctest-modules --pyargs meshlode {posargs} [testenv:tests-min] description = Run the minimal core tests with pytest and {basepython}. From 22fd41f0269cf1462227baf6e86c373a559d007d Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Wed, 10 Jul 2024 13:39:22 +0200 Subject: [PATCH 32/35] finish tests --- README.rst | 5 +- docs/src/conf.py | 2 +- examples/neighborlist_example.py | 4 +- src/meshlode/lib/potentials.py | 3 +- src/meshlode/metatensor/directpotential.py | 6 +- src/meshlode/metatensor/ewaldpotential.py | 6 +- src/meshlode/metatensor/pmepotential.py | 6 +- tests/__init__.py | 5 - tests/metatensor/test_base_metatensor.py | 195 +++++++++++++++++++ tests/metatensor/test_workflow_metatensor.py | 99 ++++++++++ 10 files changed, 309 insertions(+), 22 deletions(-) delete mode 100644 tests/__init__.py create mode 100644 tests/metatensor/test_base_metatensor.py create mode 100644 tests/metatensor/test_workflow_metatensor.py diff --git a/README.rst b/README.rst index 3ad71d1e..035876dc 100644 --- a/README.rst +++ b/README.rst @@ -22,9 +22,8 @@ You can install *MeshLode* using pip with You can then ``import meshlode`` and use it in your projects! -We also provide bindings to `metatensor -`_ which can optionally be installed -together and used as ``meshlode.metatensor`` via +We also provide bindings to `metatensor `_ which can +optionally be installed together and used as ``meshlode.metatensor`` via .. code-block:: bash diff --git a/docs/src/conf.py b/docs/src/conf.py index 2e7071cf..1ebbdbcc 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -59,7 +59,7 @@ "python": ("https://docs.python.org/3", None), "numpy": ("https://numpy.org/doc/stable/", None), "torch": ("https://pytorch.org/docs/stable/", None), - "metatensor": ("https://lab-cosmo.github.io/metatensor/latest/", None), + "metatensor": ("https://docs.metatensor.org/latest/", None), } # -- Options for HTML output ------------------------------------------------- diff --git a/examples/neighborlist_example.py b/examples/neighborlist_example.py index e4992f01..5ca0647a 100644 --- a/examples/neighborlist_example.py +++ b/examples/neighborlist_example.py @@ -86,9 +86,9 @@ data = TensorBlock( values=charges, - samples=Labels.range("atom", len(system)), + samples=Labels.range("atom", charges.shape[0]), components=[], - properties=Labels("charge", torch.tensor([[0]])), + properties=Labels.range("charge", charges.shape[1]), ) system.add_data(name="charges", data=data) diff --git a/src/meshlode/lib/potentials.py b/src/meshlode/lib/potentials.py index 63be3e05..c9bbcb15 100644 --- a/src/meshlode/lib/potentials.py +++ b/src/meshlode/lib/potentials.py @@ -1,6 +1,5 @@ -from typing import Union - import math +from typing import Union import torch from torch.special import gammainc, gammaincc, gammaln diff --git a/src/meshlode/metatensor/directpotential.py b/src/meshlode/metatensor/directpotential.py index 7d989062..232e4c13 100644 --- a/src/meshlode/metatensor/directpotential.py +++ b/src/meshlode/metatensor/directpotential.py @@ -3,7 +3,7 @@ class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl): - """Specie-wise long-range potential using a direct summation over all atoms. + r"""Specie-wise long-range potential using a direct summation over all atoms. Refer to :class:`meshlode.DirectPotential` for parameter documentation. @@ -28,9 +28,9 @@ class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl): >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> data = TensorBlock( ... values=charges, - ... samples=Labels.range("atom", len(system)), + ... samples=Labels.range("atom", charges.shape[0]), ... components=[], - ... properties=Labels("charge", torch.tensor([[0]])), + ... properties=Labels.range("charge", charges.shape[1]), ... ) >>> system.add_data(name="charges", data=data) diff --git a/src/meshlode/metatensor/ewaldpotential.py b/src/meshlode/metatensor/ewaldpotential.py index a831c954..0eac3b6b 100644 --- a/src/meshlode/metatensor/ewaldpotential.py +++ b/src/meshlode/metatensor/ewaldpotential.py @@ -7,7 +7,7 @@ class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl): - """Specie-wise long-range potential computed using the Ewald sum. + r"""Specie-wise long-range potential computed using the Ewald sum. Refer to :class:`meshlode.EwaldPotential` for parameter documentation. @@ -33,9 +33,9 @@ class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl): >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> data = TensorBlock( ... values=charges, - ... samples=Labels.range("atom", len(system)), + ... samples=Labels.range("atom", charges.shape[0]), ... components=[], - ... properties=Labels("charge", torch.tensor([[0]])), + ... properties=Labels.range("charge", charges.shape[1]), ... ) >>> system.add_data(name="charges", data=data) diff --git a/src/meshlode/metatensor/pmepotential.py b/src/meshlode/metatensor/pmepotential.py index f4a25b4d..d2e14983 100644 --- a/src/meshlode/metatensor/pmepotential.py +++ b/src/meshlode/metatensor/pmepotential.py @@ -7,7 +7,7 @@ class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): - """Specie-wise long-range potential using a particle mesh-based Ewald (PME). + r"""Specie-wise long-range potential using a particle mesh-based Ewald (PME). Refer to :class:`meshlode.PMEPotential` for parameter documentation. @@ -33,9 +33,9 @@ class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> data = TensorBlock( ... values=charges, - ... samples=Labels.range("atom", len(system)), + ... samples=Labels.range("atom", charges.shape[0]), ... components=[], - ... properties=Labels("charge", torch.tensor([[0]])), + ... properties=Labels.range("charge", charges.shape[1]), ... ) >>> system.add_data(name="charges", data=data) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 1c2cd789..00000000 --- a/tests/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -import meshlode - - -def test_version_exist(): - meshlode.__version__ diff --git a/tests/metatensor/test_base_metatensor.py b/tests/metatensor/test_base_metatensor.py new file mode 100644 index 00000000..4091b144 --- /dev/null +++ b/tests/metatensor/test_base_metatensor.py @@ -0,0 +1,195 @@ +import pytest +import torch +from metatensor.torch import Labels, TensorBlock +from metatensor.torch.atomistic import System +from packaging import version + +from meshlode.metatensor.base import CalculatorBaseMetatensor + + +class CalculatorTest(CalculatorBaseMetatensor): + def _compute_single_system( + self, positions, charges, cell, neighbor_indices, neighbor_shifts + ): + return charges + + +@pytest.mark.parametrize("method_name", ["compute", "forward"]) +def test_compute_output_shapes_single(method_name): + system = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + data = TensorBlock( + values=charges, + samples=Labels.range("atom", charges.shape[0]), + components=[], + properties=Labels.range("charge", charges.shape[1]), + ) + + system.add_data(name="charges", data=data) + + calculator = CalculatorTest(exponent=1.0) + method = getattr(calculator, method_name) + result = method(system) + + assert isinstance(result, torch.ScriptObject) + if version.parse(torch.__version__) >= version.parse("2.1"): + assert result._type().name() == "TensorMap" + + assert len(result) == 1 + assert result[0].samples.names == ["system", "atom"] + assert result[0].components == [] + assert result[0].properties.names == ["charges_channel"] + + assert tuple(result[0].values.shape) == (len(system), 1) + + +def test_compute_output_shapes_multiple(): + + system = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + data = TensorBlock( + values=charges, + samples=Labels.range("atom", charges.shape[0]), + components=[], + properties=Labels.range("charge", charges.shape[1]), + ) + + system.add_data(name="charges", data=data) + + calculator = CalculatorTest(exponent=1.0) + result = calculator.compute([system, system]) + + assert isinstance(result, torch.ScriptObject) + if version.parse(torch.__version__) >= version.parse("2.1"): + assert result._type().name() == "TensorMap" + + assert len(result) == 1 + assert result[0].samples.names == ["system", "atom"] + assert result[0].components == [] + assert result[0].properties.names == ["charges_channel"] + + assert tuple(result[0].values.shape) == (2 * len(system), 1) + + +def test_wrong_system_dtype(): + system1 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + system2 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], dtype=torch.float64), + cell=torch.zeros([3, 3], dtype=torch.float64), + ) + + calculator = CalculatorTest(exponent=1.0) + + match = r"`dtype` of all systems must be the same, got 7 and 6" + with pytest.raises(ValueError, match=match): + calculator.compute([system1, system2]) + + +def test_wrong_system_device(): + system1 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + system2 = System( + types=torch.tensor([1, 1], device="meta"), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], device="meta"), + cell=torch.zeros([3, 3], device="meta"), + ) + + calculator = CalculatorTest(exponent=1.0) + + match = r"`device` of all systems must be the same, got meta and cpu" + with pytest.raises(ValueError, match=match): + calculator.compute([system1, system2]) + + +def test_wrong_system_not_all_charges(): + system1 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + data = TensorBlock( + values=charges, + samples=Labels.range("atom", charges.shape[0]), + components=[], + properties=Labels.range("charge", charges.shape[1]), + ) + + system1.add_data(name="charges", data=data) + + system2 = System( + types=torch.tensor( + [1, 1], + ), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + calculator = CalculatorTest(exponent=1.0) + + match = r"`systems` do not consistently contain `charges` data" + with pytest.raises(ValueError, match=match): + calculator.compute([system1, system2]) + + +def test_different_number_charge_channles(): + system1 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges1 = torch.tensor([1.0, -1.0]).reshape(-1, 1) + data1 = TensorBlock( + values=charges1, + samples=Labels.range("atom", charges1.shape[0]), + components=[], + properties=Labels.range("charge", charges1.shape[1]), + ) + + system1.add_data(name="charges", data=data1) + + system2 = System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges2 = torch.tensor([[1.0, 2.0], [-1.0, -2.0]]) + data2 = TensorBlock( + values=charges2, + samples=Labels.range("atom", charges2.shape[0]), + components=[], + properties=Labels.range("charge", charges2.shape[1]), + ) + system2.add_data(name="charges", data=data2) + + calculator = CalculatorTest(exponent=1.0) + + match = ( + r"number of charges-channels in system index 1 \(2\) is inconsistent with " + r"first system \(1\)" + ) + with pytest.raises(ValueError, match=match): + calculator.compute([system1, system2]) diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py new file mode 100644 index 00000000..09d95fa2 --- /dev/null +++ b/tests/metatensor/test_workflow_metatensor.py @@ -0,0 +1,99 @@ +""" +Madelung tests +""" + +import pytest +import torch +from packaging import version + + +meshlode_metatensor = pytest.importorskip("meshlode.metatensor") +mts_torch = pytest.importorskip("metatensor.torch") +mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") + + +ATOMIC_SMEARING = 0.1 +LR_WAVELENGTH = ATOMIC_SMEARING / 4 +MESH_SPACING = ATOMIC_SMEARING / 4 +INTERPOLATION_ORDER = 2 +SUBTRACT_SELF = True + + +@pytest.mark.parametrize( + "CalculatorClass, params", + [ + (meshlode_metatensor.DirectPotential, {}), + ( + meshlode_metatensor.EwaldPotential, + { + "atomic_smearing": ATOMIC_SMEARING, + "lr_wavelength": LR_WAVELENGTH, + "subtract_self": SUBTRACT_SELF, + }, + ), + ( + meshlode_metatensor.PMEPotential, + { + "atomic_smearing": ATOMIC_SMEARING, + "mesh_spacing": MESH_SPACING, + "interpolation_order": INTERPOLATION_ORDER, + "subtract_self": SUBTRACT_SELF, + }, + ), + ], +) +class TestWorkflow: + def cscl_system(self): + """CsCl crystal. Same as in the madelung test""" + + system = mts_atomistic.System( + types=torch.tensor([17, 55]), + positions=torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]), + cell=torch.eye(3), + ) + + data = mts_torch.TensorBlock( + values=torch.tensor([-1.0, 1.0]).reshape(-1, 1), + samples=mts_torch.Labels.range("atom", len(system)), + components=[], + properties=mts_torch.Labels("charge", torch.tensor([[0]])), + ) + system.add_data(name="charges", data=data) + + return system + + def calculator(self, CalculatorClass, params): + return CalculatorClass(**params) + + def test_forward(self, CalculatorClass, params): + calculator = self.calculator(CalculatorClass, params) + descriptor_compute = calculator.compute(self.cscl_system()) + descriptor_forward = calculator.forward(self.cscl_system()) + + assert isinstance(descriptor_compute, torch.ScriptObject) + assert isinstance(descriptor_forward, torch.ScriptObject) + if version.parse(torch.__version__) >= version.parse("2.1"): + assert descriptor_compute._type().name() == "TensorMap" + assert descriptor_forward._type().name() == "TensorMap" + + assert mts_torch.equal(descriptor_forward, descriptor_compute) + + # Make sure that the calculators are computing the features without raising errors, + # and returns the correct output format (TensorMap) + def check_operation(self, CalculatorClass, params): + calculator = self.calculator(CalculatorClass, params) + descriptor = calculator.compute(self.cscl_system()) + + assert isinstance(descriptor, torch.ScriptObject) + if version.parse(torch.__version__) >= version.parse("2.1"): + assert descriptor._type().name() == "TensorMap" + + # Run the above test as a normal python script + def test_operation_as_python(self, CalculatorClass, params): + self.check_operation(CalculatorClass, params) + + # Similar to the above, but also testing that the code can be compiled as a torch + # script + # def test_operation_as_torch_script(self, CalculatorClass, params): + # scripted = torch.jit.script(CalculatorClass, params) + # self.check_operation(scripted) From 9be73a6db5ce5d11ef82b7f0ae633fb8d58db2a4 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Wed, 10 Jul 2024 14:01:51 +0200 Subject: [PATCH 33/35] fix linter --- README.rst | 4 +- src/meshlode/calculators/base.py | 43 +++--------------- src/meshlode/calculators/directpotential.py | 2 +- src/meshlode/calculators/ewaldpotential.py | 23 +++++----- src/meshlode/calculators/pmepotential.py | 24 +++++----- src/meshlode/metatensor/base.py | 8 ++-- src/meshlode/metatensor/directpotential.py | 2 +- src/meshlode/metatensor/ewaldpotential.py | 6 +-- src/meshlode/metatensor/pmepotential.py | 8 ++-- tests/calculators/test_base.py | 49 +++++++++------------ tests/metatensor/test_base_metatensor.py | 12 ++--- 11 files changed, 74 insertions(+), 107 deletions(-) diff --git a/README.rst b/README.rst index 035876dc..8e9f4b4b 100644 --- a/README.rst +++ b/README.rst @@ -22,8 +22,8 @@ You can install *MeshLode* using pip with You can then ``import meshlode`` and use it in your projects! -We also provide bindings to `metatensor `_ which can -optionally be installed together and used as ``meshlode.metatensor`` via +We also provide bindings to `metatensor `_ which +can optionally be installed together and used as ``meshlode.metatensor`` via .. code-block:: bash diff --git a/src/meshlode/calculators/base.py b/src/meshlode/calculators/base.py index 830015d1..0dac2f78 100644 --- a/src/meshlode/calculators/base.py +++ b/src/meshlode/calculators/base.py @@ -7,20 +7,16 @@ from meshlode.lib import InversePowerLawPotential -class CalculatorBase(torch.nn.Module): - """Base class providing general funtionality.""" +class _ShortRange: + """Base class providing general funtionality for short range interactions.""" - def __init__( - self, - exponent: float, - ): + def __init__(self, exponent: float, subtract_interior: bool): # Attach the function handling all computations related to the # power-law potential for later convenience self.exponent = exponent + self.subtract_interior = subtract_interior self.potential = InversePowerLawPotential(exponent=exponent) - super().__init__() - def _compute_sr( self, positions: torch.Tensor, @@ -31,32 +27,6 @@ def _compute_sr( neighbor_indices: Optional[torch.Tensor] = None, neighbor_shifts: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ - Compute the short-range part of the Ewald sum in realspace - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - :param smearing: torch.Tensor smearing paramter determining the splitting - between the SR and LR parts. - :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ if neighbor_indices is None or neighbor_shifts is None: # Get list of neighbors struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) @@ -94,7 +64,7 @@ def _compute_sr( return potential -class CalculatorBaseTorch(CalculatorBase): +class CalculatorBaseTorch(torch.nn.Module): """ Base calculator for the torch interface to MeshLODE. @@ -103,9 +73,8 @@ class CalculatorBaseTorch(CalculatorBase): def __init__( self, - exponent: float, ): - super().__init__(exponent=exponent) + super().__init__() def _validate_compute_parameters( self, diff --git a/src/meshlode/calculators/directpotential.py b/src/meshlode/calculators/directpotential.py index 26207910..802f795d 100644 --- a/src/meshlode/calculators/directpotential.py +++ b/src/meshlode/calculators/directpotential.py @@ -78,7 +78,7 @@ class DirectPotential(CalculatorBaseTorch, _DirectPotentialImpl): def __init__(self, exponent: float = 1.0): _DirectPotentialImpl.__init__(self, exponent=exponent) - CalculatorBaseTorch.__init__(self, exponent=exponent) + CalculatorBaseTorch.__init__(self) def compute( self, diff --git a/src/meshlode/calculators/ewaldpotential.py b/src/meshlode/calculators/ewaldpotential.py index 20bf877e..ef6e6640 100644 --- a/src/meshlode/calculators/ewaldpotential.py +++ b/src/meshlode/calculators/ewaldpotential.py @@ -3,33 +3,35 @@ import torch from ..lib import generate_kvectors_squeezed -from .base import CalculatorBaseTorch +from .base import CalculatorBaseTorch, _ShortRange -class _EwaldPotentialImpl: +class _EwaldPotentialImpl(_ShortRange): def __init__( self, exponent: float, sr_cutoff: Union[None, torch.Tensor], atomic_smearing: Union[None, float], lr_wavelength: Union[None, float], - subtract_self: Union[None, bool], - subtract_interior: Union[None, bool], + subtract_self: bool, + subtract_interior: bool, ): if exponent < 0.0 or exponent > 3.0: raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p < 3") if atomic_smearing is not None and atomic_smearing <= 0: raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") + _ShortRange.__init__( + self, exponent=exponent, subtract_interior=subtract_interior + ) self.atomic_smearing = atomic_smearing self.sr_cutoff = sr_cutoff self.lr_wavelength = lr_wavelength # If interior contributions are to be subtracted, also do so for self term - if subtract_interior: + if self.subtract_interior: subtract_self = True self.subtract_self = subtract_self - self.subtract_interior = subtract_interior def _compute_single_system( self, @@ -154,7 +156,8 @@ def _compute_lr( # TODO: modify to expression for general p if subtract_self: self_contrib = ( - torch.sqrt(torch.tensor(2.0 / torch.pi, device=self._device)) / smearing + torch.sqrt(torch.tensor(2.0 / torch.pi, device=positions.device)) + / smearing ) energy -= charges * self_contrib @@ -218,8 +221,8 @@ def __init__( sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, lr_wavelength: Optional[float] = None, - subtract_self: Optional[bool] = True, - subtract_interior: Optional[bool] = False, + subtract_self: bool = True, + subtract_interior: bool = False, ): _EwaldPotentialImpl.__init__( self, @@ -230,7 +233,7 @@ def __init__( subtract_self=subtract_self, subtract_interior=subtract_interior, ) - CalculatorBaseTorch.__init__(self, exponent=exponent) + CalculatorBaseTorch.__init__(self) def compute( self, diff --git a/src/meshlode/calculators/pmepotential.py b/src/meshlode/calculators/pmepotential.py index 22d7a06b..4de06b47 100644 --- a/src/meshlode/calculators/pmepotential.py +++ b/src/meshlode/calculators/pmepotential.py @@ -4,19 +4,19 @@ from ..lib import generate_kvectors_for_mesh from ..lib.mesh_interpolator import MeshInterpolator -from .base import CalculatorBaseTorch +from .base import CalculatorBaseTorch, _ShortRange -class _PMEPotentialImpl: +class _PMEPotentialImpl(_ShortRange): def __init__( self, exponent: float, sr_cutoff: Union[None, torch.Tensor], atomic_smearing: Union[None, float], mesh_spacing: Union[None, float], - interpolation_order: Union[None, int], - subtract_self: Union[None, bool], - subtract_interior: Union[None, bool], + interpolation_order: int, + subtract_self: bool, + subtract_interior: bool, ): # Check that all provided values are correct if exponent < 0.0 or exponent > 3.0: @@ -26,16 +26,18 @@ def __init__( if atomic_smearing is not None and atomic_smearing <= 0: raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") + _ShortRange.__init__( + self, exponent=exponent, subtract_interior=subtract_interior + ) self.atomic_smearing = atomic_smearing self.mesh_spacing = mesh_spacing self.interpolation_order = interpolation_order self.sr_cutoff = sr_cutoff # If interior contributions are to be subtracted, also do so for self term - if subtract_interior: + if self.subtract_interior: subtract_self = True self.subtract_self = subtract_self - self.subtract_interior = subtract_interior self.atomic_smearing = atomic_smearing self.mesh_spacing = mesh_spacing @@ -225,9 +227,9 @@ def __init__( sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, mesh_spacing: Optional[float] = None, - interpolation_order: Optional[int] = 3, - subtract_self: Optional[bool] = True, - subtract_interior: Optional[bool] = False, + interpolation_order: int = 3, + subtract_self: bool = True, + subtract_interior: bool = False, ): _PMEPotentialImpl.__init__( self, @@ -239,7 +241,7 @@ def __init__( subtract_self=subtract_self, subtract_interior=subtract_interior, ) - CalculatorBaseTorch.__init__(self, exponent=exponent) + CalculatorBaseTorch.__init__(self) def compute( self, diff --git a/src/meshlode/metatensor/base.py b/src/meshlode/metatensor/base.py index 6690caf1..0bd3280d 100644 --- a/src/meshlode/metatensor/base.py +++ b/src/meshlode/metatensor/base.py @@ -13,12 +13,10 @@ "Try installing it with:\npip install metatensor[torch]" ) -from ..calculators.base import CalculatorBase - -class CalculatorBaseMetatensor(CalculatorBase): - def __init__(self, exponent: float): - super().__init__(exponent) +class CalculatorBaseMetatensor(torch.nn.Module): + def __init__(self): + super().__init__() def forward(self, systems: Union[List[System], System]) -> TensorMap: """Forward just calls :py:meth:`compute`.""" diff --git a/src/meshlode/metatensor/directpotential.py b/src/meshlode/metatensor/directpotential.py index 232e4c13..368fa116 100644 --- a/src/meshlode/metatensor/directpotential.py +++ b/src/meshlode/metatensor/directpotential.py @@ -52,4 +52,4 @@ class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl): def __init__(self, exponent: float = 1.0): _DirectPotentialImpl.__init__(self, exponent=exponent) - CalculatorBaseMetatensor.__init__(self, exponent=exponent) + CalculatorBaseMetatensor.__init__(self) diff --git a/src/meshlode/metatensor/ewaldpotential.py b/src/meshlode/metatensor/ewaldpotential.py index 0eac3b6b..f16efcee 100644 --- a/src/meshlode/metatensor/ewaldpotential.py +++ b/src/meshlode/metatensor/ewaldpotential.py @@ -60,8 +60,8 @@ def __init__( sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, lr_wavelength: Optional[float] = None, - subtract_self: Optional[bool] = True, - subtract_interior: Optional[bool] = False, + subtract_self: bool = True, + subtract_interior: bool = False, ): _EwaldPotentialImpl.__init__( self, @@ -72,4 +72,4 @@ def __init__( subtract_self=subtract_self, subtract_interior=subtract_interior, ) - CalculatorBaseMetatensor.__init__(self, exponent=exponent) + CalculatorBaseMetatensor.__init__(self) diff --git a/src/meshlode/metatensor/pmepotential.py b/src/meshlode/metatensor/pmepotential.py index d2e14983..af4da3ad 100644 --- a/src/meshlode/metatensor/pmepotential.py +++ b/src/meshlode/metatensor/pmepotential.py @@ -60,9 +60,9 @@ def __init__( sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, mesh_spacing: Optional[float] = None, - interpolation_order: Optional[int] = 3, - subtract_self: Optional[bool] = True, - subtract_interior: Optional[bool] = False, + interpolation_order: int = 3, + subtract_self: bool = True, + subtract_interior: bool = False, ): _PMEPotentialImpl.__init__( self, @@ -74,4 +74,4 @@ def __init__( subtract_self=subtract_self, subtract_interior=subtract_interior, ) - CalculatorBaseMetatensor.__init__(self, exponent=exponent) + CalculatorBaseMetatensor.__init__(self) diff --git a/tests/calculators/test_base.py b/tests/calculators/test_base.py index 7eb3641a..efdbc4aa 100644 --- a/tests/calculators/test_base.py +++ b/tests/calculators/test_base.py @@ -1,7 +1,7 @@ import pytest import torch -from meshlode.calculators.base import CalculatorBase, CalculatorBaseTorch +from meshlode.calculators.base import CalculatorBaseTorch # Define some example parameters @@ -15,11 +15,6 @@ CELL_2 = torch.arange(9, dtype=DTYPE, device=DEVICE).reshape((3, 3)) -def test_CalculatorBase(): - calculator = CalculatorBase(exponent=5.0) - assert calculator.exponent == 5.0 - - class CalculatorTest(CalculatorBaseTorch): def compute(self, positions, charges, cell, neighbor_indices, neighbor_shifts): return self._compute_impl( @@ -58,7 +53,7 @@ def _compute_single_system( ], ) def test_compute_output_shapes(method_name, positions, charges): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() method = getattr(calculator, method_name) result = method( @@ -79,7 +74,7 @@ def test_compute_output_shapes(method_name, positions, charges): # Tests for a mismatch in the number of provided inputs for different variables def test_mismatched_numbers_cell(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = r"Got inconsistent numbers of positions \(2\) and cell \(3\)" with pytest.raises(ValueError, match=match): calculator.compute( @@ -92,7 +87,7 @@ def test_mismatched_numbers_cell(): def test_mismatched_numbers_charges(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = r"Got inconsistent numbers of positions \(2\) and charges \(3\)" with pytest.raises(ValueError, match=match): calculator.compute( @@ -105,7 +100,7 @@ def test_mismatched_numbers_charges(): def test_mismatched_numbers_neighbor_indices(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = r"Got inconsistent numbers of positions \(2\) and neighbor_indices \(3\)" with pytest.raises(ValueError, match=match): calculator.compute( @@ -118,7 +113,7 @@ def test_mismatched_numbers_neighbor_indices(): def test_mismatched_numbers_neighbor_shiftss(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = r"Got inconsistent numbers of positions \(2\) and neighbor_shifts \(3\)" with pytest.raises(ValueError, match=match): calculator.compute( @@ -132,7 +127,7 @@ def test_mismatched_numbers_neighbor_shiftss(): # Tests for invalid shape, dtype and device of positions def test_invalid_shape_positions(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `positions` must be a \(n_atoms x 3\) tensor, got at least " r"one tensor with shape \(4, 5\)" @@ -148,7 +143,7 @@ def test_invalid_shape_positions(): def test_invalid_dtype_positions(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `positions` must have the same type torch.float32 as the " r"first provided one. Got at least one tensor of type " @@ -166,7 +161,7 @@ def test_invalid_dtype_positions(): def test_invalid_device_positions(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `positions` must be on the same device cpu as the " r"first provided one. Got at least one tensor on device " @@ -185,7 +180,7 @@ def test_invalid_device_positions(): # Tests for invalid shape, dtype and device of cell def test_invalid_shape_cell(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `cell` must be a \(3 x 3\) tensor, got at least one tensor with " r"shape \(2, 2\)" @@ -201,7 +196,7 @@ def test_invalid_shape_cell(): def test_invalid_dtype_cell(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `cell` must have the same type torch.float32 as positions, " r"got at least one tensor of type torch.float64" @@ -217,7 +212,7 @@ def test_invalid_dtype_cell(): def test_invalid_device_cell(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `cell` must be on the same device cpu as positions, " r"got at least one tensor with device meta" @@ -234,7 +229,7 @@ def test_invalid_device_cell(): # Tests for invalid shape, dtype and device of charges def test_invalid_dim_charges(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `charges` needs to be a 2-dimensional tensor, got at least " r"one tensor with 1 dimension\(s\) and shape " @@ -251,7 +246,7 @@ def test_invalid_dim_charges(): def test_invalid_shape_charges(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `charges` must be a \(n_atoms x n_channels\) tensor, with" r"`n_atoms` being the same as the variable `positions`. Got at " @@ -269,7 +264,7 @@ def test_invalid_shape_charges(): def test_invalid_dtype_charges(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `charges` must have the same type torch.float32 as positions, " r"got at least one tensor of type torch.float64" @@ -285,7 +280,7 @@ def test_invalid_dtype_charges(): def test_invalid_device_charges(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `charges` must be on the same device cpu as positions, " r"got at least one tensor with device meta" @@ -302,7 +297,7 @@ def test_invalid_device_charges(): # Tests for invalid shape, dtype and device of neighbor_indices and neighbor_shifts def test_need_both_neighbor_indices_and_shifts(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = r"Need to provide both `neighbor_indices` and `neighbor_shifts` together." with pytest.raises(ValueError, match=match): calculator.compute( @@ -315,7 +310,7 @@ def test_need_both_neighbor_indices_and_shifts(): def test_invalid_shape_neighbor_indices(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"neighbor_indices is expected to have shape \(2, num_neighbors\)" r", but got \(4, 10\) for one structure" @@ -331,7 +326,7 @@ def test_invalid_shape_neighbor_indices(): def test_invalid_shape_neighbor_shifts(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"neighbor_shifts is expected to have shape \(num_neighbors, 3\)" r", but got \(10, 2\) for one structure" @@ -347,7 +342,7 @@ def test_invalid_shape_neighbor_shifts(): def test_invalid_shape_neighbor_indices_neighbor_shifts(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"`neighbor_indices` and `neighbor_shifts` need to have shapes " r"\(2, num_neighbors\) and \(num_neighbors, 3\). For at least one" @@ -365,7 +360,7 @@ def test_invalid_shape_neighbor_indices_neighbor_shifts(): def test_invalid_device_neighbor_indices(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `neighbor_indices` must be on the same device cpu as positions, " r"got at least one tensor with device meta" @@ -381,7 +376,7 @@ def test_invalid_device_neighbor_indices(): def test_invalid_device_neighbor_shifts(): - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"each `neighbor_shifts` must be on the same device cpu as positions, " r"got at least one tensor with device meta" diff --git a/tests/metatensor/test_base_metatensor.py b/tests/metatensor/test_base_metatensor.py index 4091b144..239bec22 100644 --- a/tests/metatensor/test_base_metatensor.py +++ b/tests/metatensor/test_base_metatensor.py @@ -32,7 +32,7 @@ def test_compute_output_shapes_single(method_name): system.add_data(name="charges", data=data) - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() method = getattr(calculator, method_name) result = method(system) @@ -66,7 +66,7 @@ def test_compute_output_shapes_multiple(): system.add_data(name="charges", data=data) - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() result = calculator.compute([system, system]) assert isinstance(result, torch.ScriptObject) @@ -94,7 +94,7 @@ def test_wrong_system_dtype(): cell=torch.zeros([3, 3], dtype=torch.float64), ) - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = r"`dtype` of all systems must be the same, got 7 and 6" with pytest.raises(ValueError, match=match): @@ -114,7 +114,7 @@ def test_wrong_system_device(): cell=torch.zeros([3, 3], device="meta"), ) - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = r"`device` of all systems must be the same, got meta and cpu" with pytest.raises(ValueError, match=match): @@ -146,7 +146,7 @@ def test_wrong_system_not_all_charges(): cell=torch.zeros([3, 3]), ) - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = r"`systems` do not consistently contain `charges` data" with pytest.raises(ValueError, match=match): @@ -185,7 +185,7 @@ def test_different_number_charge_channles(): ) system2.add_data(name="charges", data=data2) - calculator = CalculatorTest(exponent=1.0) + calculator = CalculatorTest() match = ( r"number of charges-channels in system index 1 \(2\) is inconsistent with " From 19134e548676f93328eecaf4dd097d909e266dea Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Wed, 10 Jul 2024 14:19:19 +0200 Subject: [PATCH 34/35] fix tests-min --- tests/metatensor/test_base_metatensor.py | 59 ++++++++++---------- tests/metatensor/test_calculators.py | 2 +- tests/metatensor/test_workflow_metatensor.py | 2 +- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/tests/metatensor/test_base_metatensor.py b/tests/metatensor/test_base_metatensor.py index 239bec22..dbaa985b 100644 --- a/tests/metatensor/test_base_metatensor.py +++ b/tests/metatensor/test_base_metatensor.py @@ -1,13 +1,14 @@ import pytest import torch -from metatensor.torch import Labels, TensorBlock -from metatensor.torch.atomistic import System from packaging import version -from meshlode.metatensor.base import CalculatorBaseMetatensor +mts_torch = pytest.importorskip("metatensor.torch") +mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") +meshlode_metatensor = pytest.importorskip("meshlode.metatensor") -class CalculatorTest(CalculatorBaseMetatensor): + +class CalculatorTest(meshlode_metatensor.base.CalculatorBaseMetatensor): def _compute_single_system( self, positions, charges, cell, neighbor_indices, neighbor_shifts ): @@ -16,18 +17,18 @@ def _compute_single_system( @pytest.mark.parametrize("method_name", ["compute", "forward"]) def test_compute_output_shapes_single(method_name): - system = System( + system = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) - data = TensorBlock( + data = mts_torch.TensorBlock( values=charges, - samples=Labels.range("atom", charges.shape[0]), + samples=mts_torch.Labels.range("atom", charges.shape[0]), components=[], - properties=Labels.range("charge", charges.shape[1]), + properties=mts_torch.Labels.range("charge", charges.shape[1]), ) system.add_data(name="charges", data=data) @@ -50,18 +51,18 @@ def test_compute_output_shapes_single(method_name): def test_compute_output_shapes_multiple(): - system = System( + system = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) - data = TensorBlock( + data = mts_torch.TensorBlock( values=charges, - samples=Labels.range("atom", charges.shape[0]), + samples=mts_torch.Labels.range("atom", charges.shape[0]), components=[], - properties=Labels.range("charge", charges.shape[1]), + properties=mts_torch.Labels.range("charge", charges.shape[1]), ) system.add_data(name="charges", data=data) @@ -82,13 +83,13 @@ def test_compute_output_shapes_multiple(): def test_wrong_system_dtype(): - system1 = System( + system1 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) - system2 = System( + system2 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], dtype=torch.float64), cell=torch.zeros([3, 3], dtype=torch.float64), @@ -102,13 +103,13 @@ def test_wrong_system_dtype(): def test_wrong_system_device(): - system1 = System( + system1 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) - system2 = System( + system2 = mts_atomistic.System( types=torch.tensor([1, 1], device="meta"), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], device="meta"), cell=torch.zeros([3, 3], device="meta"), @@ -122,23 +123,23 @@ def test_wrong_system_device(): def test_wrong_system_not_all_charges(): - system1 = System( + system1 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) - data = TensorBlock( + data = mts_torch.TensorBlock( values=charges, - samples=Labels.range("atom", charges.shape[0]), + samples=mts_torch.Labels.range("atom", charges.shape[0]), components=[], - properties=Labels.range("charge", charges.shape[1]), + properties=mts_torch.Labels.range("charge", charges.shape[1]), ) system1.add_data(name="charges", data=data) - system2 = System( + system2 = mts_atomistic.System( types=torch.tensor( [1, 1], ), @@ -154,34 +155,34 @@ def test_wrong_system_not_all_charges(): def test_different_number_charge_channles(): - system1 = System( + system1 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) charges1 = torch.tensor([1.0, -1.0]).reshape(-1, 1) - data1 = TensorBlock( + data1 = mts_torch.TensorBlock( values=charges1, - samples=Labels.range("atom", charges1.shape[0]), + samples=mts_torch.Labels.range("atom", charges1.shape[0]), components=[], - properties=Labels.range("charge", charges1.shape[1]), + properties=mts_torch.Labels.range("charge", charges1.shape[1]), ) system1.add_data(name="charges", data=data1) - system2 = System( + system2 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) charges2 = torch.tensor([[1.0, 2.0], [-1.0, -2.0]]) - data2 = TensorBlock( + data2 = mts_torch.TensorBlock( values=charges2, - samples=Labels.range("atom", charges2.shape[0]), + samples=mts_torch.Labels.range("atom", charges2.shape[0]), components=[], - properties=Labels.range("charge", charges2.shape[1]), + properties=mts_torch.Labels.range("charge", charges2.shape[1]), ) system2.add_data(name="charges", data=data2) diff --git a/tests/metatensor/test_calculators.py b/tests/metatensor/test_calculators.py index 09d95fa2..137d158f 100644 --- a/tests/metatensor/test_calculators.py +++ b/tests/metatensor/test_calculators.py @@ -7,9 +7,9 @@ from packaging import version -meshlode_metatensor = pytest.importorskip("meshlode.metatensor") mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") +meshlode_metatensor = pytest.importorskip("meshlode.metatensor") ATOMIC_SMEARING = 0.1 diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 09d95fa2..137d158f 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -7,9 +7,9 @@ from packaging import version -meshlode_metatensor = pytest.importorskip("meshlode.metatensor") mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") +meshlode_metatensor = pytest.importorskip("meshlode.metatensor") ATOMIC_SMEARING = 0.1 From a237de3e47d1f122da4fa9fc6d7c3e6c05be8ec9 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Wed, 10 Jul 2024 14:29:43 +0200 Subject: [PATCH 35/35] more linting --- tests/metatensor/test_base_metatensor.py | 5 +- tests/metatensor/test_calculators.py | 99 -------------------- tests/metatensor/test_workflow_metatensor.py | 9 +- 3 files changed, 8 insertions(+), 105 deletions(-) delete mode 100644 tests/metatensor/test_calculators.py diff --git a/tests/metatensor/test_base_metatensor.py b/tests/metatensor/test_base_metatensor.py index dbaa985b..cca6dd43 100644 --- a/tests/metatensor/test_base_metatensor.py +++ b/tests/metatensor/test_base_metatensor.py @@ -2,13 +2,14 @@ import torch from packaging import version +import meshlode + mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") -meshlode_metatensor = pytest.importorskip("meshlode.metatensor") -class CalculatorTest(meshlode_metatensor.base.CalculatorBaseMetatensor): +class CalculatorTest(meshlode.metatensor.base.CalculatorBaseMetatensor): def _compute_single_system( self, positions, charges, cell, neighbor_indices, neighbor_shifts ): diff --git a/tests/metatensor/test_calculators.py b/tests/metatensor/test_calculators.py deleted file mode 100644 index 137d158f..00000000 --- a/tests/metatensor/test_calculators.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -Madelung tests -""" - -import pytest -import torch -from packaging import version - - -mts_torch = pytest.importorskip("metatensor.torch") -mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") -meshlode_metatensor = pytest.importorskip("meshlode.metatensor") - - -ATOMIC_SMEARING = 0.1 -LR_WAVELENGTH = ATOMIC_SMEARING / 4 -MESH_SPACING = ATOMIC_SMEARING / 4 -INTERPOLATION_ORDER = 2 -SUBTRACT_SELF = True - - -@pytest.mark.parametrize( - "CalculatorClass, params", - [ - (meshlode_metatensor.DirectPotential, {}), - ( - meshlode_metatensor.EwaldPotential, - { - "atomic_smearing": ATOMIC_SMEARING, - "lr_wavelength": LR_WAVELENGTH, - "subtract_self": SUBTRACT_SELF, - }, - ), - ( - meshlode_metatensor.PMEPotential, - { - "atomic_smearing": ATOMIC_SMEARING, - "mesh_spacing": MESH_SPACING, - "interpolation_order": INTERPOLATION_ORDER, - "subtract_self": SUBTRACT_SELF, - }, - ), - ], -) -class TestWorkflow: - def cscl_system(self): - """CsCl crystal. Same as in the madelung test""" - - system = mts_atomistic.System( - types=torch.tensor([17, 55]), - positions=torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]), - cell=torch.eye(3), - ) - - data = mts_torch.TensorBlock( - values=torch.tensor([-1.0, 1.0]).reshape(-1, 1), - samples=mts_torch.Labels.range("atom", len(system)), - components=[], - properties=mts_torch.Labels("charge", torch.tensor([[0]])), - ) - system.add_data(name="charges", data=data) - - return system - - def calculator(self, CalculatorClass, params): - return CalculatorClass(**params) - - def test_forward(self, CalculatorClass, params): - calculator = self.calculator(CalculatorClass, params) - descriptor_compute = calculator.compute(self.cscl_system()) - descriptor_forward = calculator.forward(self.cscl_system()) - - assert isinstance(descriptor_compute, torch.ScriptObject) - assert isinstance(descriptor_forward, torch.ScriptObject) - if version.parse(torch.__version__) >= version.parse("2.1"): - assert descriptor_compute._type().name() == "TensorMap" - assert descriptor_forward._type().name() == "TensorMap" - - assert mts_torch.equal(descriptor_forward, descriptor_compute) - - # Make sure that the calculators are computing the features without raising errors, - # and returns the correct output format (TensorMap) - def check_operation(self, CalculatorClass, params): - calculator = self.calculator(CalculatorClass, params) - descriptor = calculator.compute(self.cscl_system()) - - assert isinstance(descriptor, torch.ScriptObject) - if version.parse(torch.__version__) >= version.parse("2.1"): - assert descriptor._type().name() == "TensorMap" - - # Run the above test as a normal python script - def test_operation_as_python(self, CalculatorClass, params): - self.check_operation(CalculatorClass, params) - - # Similar to the above, but also testing that the code can be compiled as a torch - # script - # def test_operation_as_torch_script(self, CalculatorClass, params): - # scripted = torch.jit.script(CalculatorClass, params) - # self.check_operation(scripted) diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 137d158f..f5a28680 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -6,10 +6,11 @@ import torch from packaging import version +import meshlode + mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") -meshlode_metatensor = pytest.importorskip("meshlode.metatensor") ATOMIC_SMEARING = 0.1 @@ -22,9 +23,9 @@ @pytest.mark.parametrize( "CalculatorClass, params", [ - (meshlode_metatensor.DirectPotential, {}), + (meshlode.metatensor.DirectPotential, {}), ( - meshlode_metatensor.EwaldPotential, + meshlode.metatensor.EwaldPotential, { "atomic_smearing": ATOMIC_SMEARING, "lr_wavelength": LR_WAVELENGTH, @@ -32,7 +33,7 @@ }, ), ( - meshlode_metatensor.PMEPotential, + meshlode.metatensor.PMEPotential, { "atomic_smearing": ATOMIC_SMEARING, "mesh_spacing": MESH_SPACING,