diff --git a/src/graph_pes/models/pairwise.py b/src/graph_pes/models/pairwise.py index f6490548..ecb7c55d 100644 --- a/src/graph_pes/models/pairwise.py +++ b/src/graph_pes/models/pairwise.py @@ -10,7 +10,6 @@ keys, neighbour_distances, ) -from graph_pes.nn import PositiveParameter from graph_pes.transform import PerAtomScale, PerAtomShift from jaxtyping import Float from torch import Tensor @@ -98,10 +97,9 @@ class LennardJones(PairPotential): \frac{\sigma}{r_{ij}} \right)^{12} - \left( \frac{\sigma}{r_{ij}} \right)^{6} \right] - where :math:`r_{ij}` is the distance between atoms :math:`i` and :math:`j`. - Internally, :math:`\varepsilon` and :math:`\sigma` are stored as - :class:`PositiveParameter ` instances, - which ensures that they are kept strictly positive during training. + where :math:`r_{ij}` is the distance between atoms :math:`i` and :math:`j`, + and :math:`\varepsilon` and :math:`\sigma` are strictly positive + paramters that control the depth and width of the potential well, Parameters ---------- @@ -154,29 +152,19 @@ class Morse(PairPotential): V(r_{ij}, Z_i, Z_j) = V(r_{ij}) = D (1 - e^{-a(r_{ij} - r_0)})^2 where :math:`r_{ij}` is the distance between atoms :math:`i` and :math:`j`, - and :math:`D`, :math:`a` and :math:`r_0` control the depth, width and - center of the potential well, respectively. Internally, these are stored - as :class:`PositiveParameter` instances. - - Attributes - ---------- - D: :class:`PositiveParameter ` - The depth of the potential. - a: :class:`PositiveParameter ` - The width of the potential. - r0: :class:`PositiveParameter ` - The center of the potential. + and :math:`D`, :math:`a` and :math:`r_0` are strictly positive parameters + that control the depth, width and center of the potential well respectively. """ - def __init__(self): + def __init__(self, D: float = 0.1, a: float = 3.0, r0: float = 1.0): super().__init__() - self.D = PositiveParameter(0.1) - self.a = PositiveParameter(1.0) - self.r0 = PositiveParameter(0.5) + self._log_D = torch.nn.Parameter(torch.tensor(D).log()) + self._log_a = torch.nn.Parameter(torch.tensor(a).log()) + self._log_r0 = torch.nn.Parameter(torch.tensor(r0).log()) # D is a scaling term, so only need to learn a shift # parameter (rather than a shift and scale) - self._energy_summation = EnergySummation(local_transform=PerAtomScale()) + self.energy_summation = EnergySummation(local_transform=PerAtomScale()) def interaction( self, r: torch.Tensor, Z_i: torch.Tensor, Z_j: torch.Tensor @@ -193,18 +181,19 @@ def interaction( Z_j : torch.Tensor The atomic numbers of the neighbours. (unused) """ - return self.D * (1 - torch.exp(-self.a * (r - self.r0))) ** 2 + D, a, r0 = self._log_D.exp(), self._log_a.exp(), self._log_r0.exp() + return D * (1 - torch.exp(-a * (r - r0))) ** 2 def pre_fit(self, graph: AtomicGraphBatch): super().pre_fit(graph) # set the potential depth to be shallow - self.D = PositiveParameter(0.1) + self._log_D = torch.nn.Parameter(torch.tensor(0.1).log()) # set the center of the well to be close to the minimum pair-wise # distance d = torch.quantile(neighbour_distances(graph), 0.01) - self.r0 = PositiveParameter(d) + self._log_r0 = torch.nn.Parameter(d.log()) - # set the width to be broad - self.a = PositiveParameter(0.5) + # set the width to be "reasonable" + self._log_a = torch.nn.Parameter(torch.tensor(3.0).log()) diff --git a/src/graph_pes/models/zoo.py b/src/graph_pes/models/zoo.py index c40a1133..b7a0c5ff 100644 --- a/src/graph_pes/models/zoo.py +++ b/src/graph_pes/models/zoo.py @@ -1,5 +1,5 @@ from .painn import PaiNN -from .pairwise import LennardJones +from .pairwise import LennardJones, Morse from .schnet import SchNet from .tensornet import TensorNet @@ -8,4 +8,5 @@ "LennardJones", "SchNet", "TensorNet", + "Morse", ] diff --git a/src/graph_pes/nn.py b/src/graph_pes/nn.py index fe87aba6..e8e651c4 100644 --- a/src/graph_pes/nn.py +++ b/src/graph_pes/nn.py @@ -405,14 +405,18 @@ def constrained_value(self): class HaddamardProduct(nn.Module): - def __init__(self, *components: nn.Module): + def __init__(self, *components: nn.Module, left_aligned: bool = False): super().__init__() self.components: list[nn.Module] = nn.ModuleList(components) # type: ignore + self.left_aligned = left_aligned def forward(self, x): - out = 1 + out = torch.scalar_tensor(1) for component in self.components: - out = out * component(x) + if self.left_aligned: + out = left_aligned_mul(out, component(x)) + else: + out = out * component(x) return out diff --git a/tests/test_torchscript.py b/tests/test_torchscript.py index e2e47102..5faab000 100644 --- a/tests/test_torchscript.py +++ b/tests/test_torchscript.py @@ -4,7 +4,7 @@ import torch from ase.build import molecule from graph_pes.data import batch_graphs, convert_to_atomic_graph -from graph_pes.models.zoo import LennardJones, PaiNN, SchNet, TensorNet +from graph_pes.models.zoo import LennardJones, Morse, PaiNN, SchNet, TensorNet graph = convert_to_atomic_graph(molecule("CH3CH2OH"), cutoff=1.5) batch = batch_graphs([graph, graph]) @@ -13,6 +13,7 @@ PaiNN(), SchNet(), TensorNet(), + Morse(), ]