Skip to content

Commit

Permalink
test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Mar 4, 2024
1 parent e8ae3e6 commit 907ceff
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 32 deletions.
43 changes: 16 additions & 27 deletions src/graph_pes/models/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
keys,
neighbour_distances,
)
from graph_pes.nn import PositiveParameter
from graph_pes.transform import PerAtomScale, PerAtomShift
from jaxtyping import Float
from torch import Tensor
Expand Down Expand Up @@ -98,10 +97,9 @@ class LennardJones(PairPotential):
\frac{\sigma}{r_{ij}} \right)^{12} - \left( \frac{\sigma}{r_{ij}}
\right)^{6} \right]
where :math:`r_{ij}` is the distance between atoms :math:`i` and :math:`j`.
Internally, :math:`\varepsilon` and :math:`\sigma` are stored as
:class:`PositiveParameter <graph_pes.nn.PositiveParamerer>` instances,
which ensures that they are kept strictly positive during training.
where :math:`r_{ij}` is the distance between atoms :math:`i` and :math:`j`,
and :math:`\varepsilon` and :math:`\sigma` are strictly positive
paramters that control the depth and width of the potential well,
Parameters
----------
Expand Down Expand Up @@ -154,29 +152,19 @@ class Morse(PairPotential):
V(r_{ij}, Z_i, Z_j) = V(r_{ij}) = D (1 - e^{-a(r_{ij} - r_0)})^2
where :math:`r_{ij}` is the distance between atoms :math:`i` and :math:`j`,
and :math:`D`, :math:`a` and :math:`r_0` control the depth, width and
center of the potential well, respectively. Internally, these are stored
as :class:`PositiveParameter` instances.
Attributes
----------
D: :class:`PositiveParameter <graph_pes.nn.PositiveParameter>`
The depth of the potential.
a: :class:`PositiveParameter <graph_pes.nn.PositiveParameter>`
The width of the potential.
r0: :class:`PositiveParameter <graph_pes.nn.PositiveParameter>`
The center of the potential.
and :math:`D`, :math:`a` and :math:`r_0` are strictly positive parameters
that control the depth, width and center of the potential well respectively.
"""

def __init__(self):
def __init__(self, D: float = 0.1, a: float = 3.0, r0: float = 1.0):
super().__init__()
self.D = PositiveParameter(0.1)
self.a = PositiveParameter(1.0)
self.r0 = PositiveParameter(0.5)
self._log_D = torch.nn.Parameter(torch.tensor(D).log())
self._log_a = torch.nn.Parameter(torch.tensor(a).log())
self._log_r0 = torch.nn.Parameter(torch.tensor(r0).log())

# D is a scaling term, so only need to learn a shift
# parameter (rather than a shift and scale)
self._energy_summation = EnergySummation(local_transform=PerAtomScale())
self.energy_summation = EnergySummation(local_transform=PerAtomScale())

def interaction(
self, r: torch.Tensor, Z_i: torch.Tensor, Z_j: torch.Tensor
Expand All @@ -193,18 +181,19 @@ def interaction(
Z_j : torch.Tensor
The atomic numbers of the neighbours. (unused)
"""
return self.D * (1 - torch.exp(-self.a * (r - self.r0))) ** 2
D, a, r0 = self._log_D.exp(), self._log_a.exp(), self._log_r0.exp()
return D * (1 - torch.exp(-a * (r - r0))) ** 2

def pre_fit(self, graph: AtomicGraphBatch):
super().pre_fit(graph)

# set the potential depth to be shallow
self.D = PositiveParameter(0.1)
self._log_D = torch.nn.Parameter(torch.tensor(0.1).log())

Check warning on line 191 in src/graph_pes/models/pairwise.py

View check run for this annotation

Codecov / codecov/patch

src/graph_pes/models/pairwise.py#L191

Added line #L191 was not covered by tests

# set the center of the well to be close to the minimum pair-wise
# distance
d = torch.quantile(neighbour_distances(graph), 0.01)
self.r0 = PositiveParameter(d)
self._log_r0 = torch.nn.Parameter(d.log())

Check warning on line 196 in src/graph_pes/models/pairwise.py

View check run for this annotation

Codecov / codecov/patch

src/graph_pes/models/pairwise.py#L195-L196

Added lines #L195 - L196 were not covered by tests

# set the width to be broad
self.a = PositiveParameter(0.5)
# set the width to be "reasonable"
self._log_a = torch.nn.Parameter(torch.tensor(3.0).log())

Check warning on line 199 in src/graph_pes/models/pairwise.py

View check run for this annotation

Codecov / codecov/patch

src/graph_pes/models/pairwise.py#L199

Added line #L199 was not covered by tests
3 changes: 2 additions & 1 deletion src/graph_pes/models/zoo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .painn import PaiNN
from .pairwise import LennardJones
from .pairwise import LennardJones, Morse
from .schnet import SchNet
from .tensornet import TensorNet

Expand All @@ -8,4 +8,5 @@
"LennardJones",
"SchNet",
"TensorNet",
"Morse",
]
10 changes: 7 additions & 3 deletions src/graph_pes/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,14 +405,18 @@ def constrained_value(self):


class HaddamardProduct(nn.Module):
def __init__(self, *components: nn.Module):
def __init__(self, *components: nn.Module, left_aligned: bool = False):
super().__init__()
self.components: list[nn.Module] = nn.ModuleList(components) # type: ignore
self.left_aligned = left_aligned

def forward(self, x):
out = 1
out = torch.scalar_tensor(1)
for component in self.components:
out = out * component(x)
if self.left_aligned:
out = left_aligned_mul(out, component(x))
else:
out = out * component(x)
return out


Expand Down
3 changes: 2 additions & 1 deletion tests/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from ase.build import molecule
from graph_pes.data import batch_graphs, convert_to_atomic_graph
from graph_pes.models.zoo import LennardJones, PaiNN, SchNet, TensorNet
from graph_pes.models.zoo import LennardJones, Morse, PaiNN, SchNet, TensorNet

graph = convert_to_atomic_graph(molecule("CH3CH2OH"), cutoff=1.5)
batch = batch_graphs([graph, graph])
Expand All @@ -13,6 +13,7 @@
PaiNN(),
SchNet(),
TensorNet(),
Morse(),
]


Expand Down

0 comments on commit 907ceff

Please sign in to comment.