
Commit

prettify pair potentials
jla-gardner committed Mar 6, 2024
1 parent 18e9f98 commit 549599b
Showing 5 changed files with 114 additions and 138 deletions.
4 changes: 2 additions & 2 deletions src/graph_pes/core.py
@@ -14,7 +14,7 @@
keys,
sum_per_structure,
)
from graph_pes.transform import Identity, PerAtomStandardScaler, Transform
from graph_pes.transform import PerAtomStandardScaler, Transform
from graph_pes.util import differentiate, require_grad


@@ -61,7 +61,7 @@ def __init__(self):
super().__init__()
# assigned here to appease torchscript and mypy:
# this gets overridden in pre_fit below
self.energy_transform: Transform = Identity()
self.energy_transform: Transform = PerAtomStandardScaler()

def pre_fit(self, graphs: AtomicGraphBatch):
"""
54 changes: 48 additions & 6 deletions src/graph_pes/models/pairwise.py
@@ -12,6 +12,7 @@
sum_over_neighbours,
)
from graph_pes.transform import PerAtomShift
from graph_pes.util import pytorch_repr, to_significant_figures
from jaxtyping import Float
from torch import Tensor

@@ -112,6 +113,17 @@ def __init__(self, epsilon: float = 0.1, sigma: float = 1.0):
self._log_epsilon = torch.nn.Parameter(torch.tensor(epsilon).log())
self._log_sigma = torch.nn.Parameter(torch.tensor(sigma).log())

# epsilon is a scaling term, so only need to learn a shift
self.energy_transform = PerAtomShift()

@property
def epsilon(self):
return self._log_epsilon.exp()

@property
def sigma(self):
return self._log_sigma.exp()

# don't use Z_i and Z_j, but include them for consistency with the
# abstract method
def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None):
@@ -123,11 +135,9 @@ def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None):
r
The pair-wise distances between the atoms.
"""
epsilon = self._log_epsilon.exp()
sigma = self._log_sigma.exp()

x = sigma / r
return 4 * epsilon * (x**12 - x**6)
x = self.sigma / r
return 4 * self.epsilon * (x**12 - x**6)

def pre_fit(self, graph: AtomicGraphBatch):
assert "energy" in graph
@@ -147,6 +157,16 @@ def pre_fit(self, graph: AtomicGraphBatch):
d = torch.quantile(neighbour_distances(graph), 0.01)
self._log_sigma = torch.nn.Parameter(d.log())

def __repr__(self):
return pytorch_repr(
"LennardJones",
_modules={
"epsilon": to_significant_figures(self.epsilon.item(), 3),
"sigma": to_significant_figures(self.sigma.item(), 3),
"energy_transform": self.energy_transform,
},
)


class Morse(PairPotential):
r"""
@@ -170,6 +190,18 @@ def __init__(self, D: float = 0.1, a: float = 3.0, r0: float = 1.0):
# parameter (rather than a shift and scale)
self.energy_transform = PerAtomShift()

@property
def D(self):
return self._log_D.exp()

@property
def a(self):
return self._log_a.exp()

@property
def r0(self):
return self._log_r0.exp()

def interaction(
self, r: torch.Tensor, Z_i: torch.Tensor, Z_j: torch.Tensor
):
@@ -185,8 +217,7 @@ def interaction(
Z_j : torch.Tensor
The atomic numbers of the neighbours. (unused)
"""
D, a, r0 = self._log_D.exp(), self._log_a.exp(), self._log_r0.exp()
return D * (1 - torch.exp(-a * (r - r0))) ** 2
return self.D * (1 - torch.exp(-self.a * (r - self.r0))) ** 2
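
With this parameterisation the Morse term is zero at r = r0 and plateaus at D for large separations, i.e. the well of depth D is shifted so its minimum sits at zero energy. A standalone sketch for illustration, not the package's API:

```python
import torch


def morse(r, D=0.1, a=3.0, r0=1.0):
    # D * (1 - exp(-a * (r - r0)))**2, mirroring interaction() above
    return D * (1 - torch.exp(-a * (r - r0))) ** 2


r = torch.tensor([1.0, 10.0])
print(morse(r))  # ~[0.0, 0.1]: minimum at r0, approaches D as r grows
```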

def pre_fit(self, graph: AtomicGraphBatch):
assert "energy" in graph
@@ -203,3 +234,14 @@ def pre_fit(self, graph: AtomicGraphBatch):

# set the width to be "reasonable"
self._log_a = torch.nn.Parameter(torch.tensor(3.0).log())

def __repr__(self):
return pytorch_repr(
"Morse",
_modules={
"D": to_significant_figures(self.D.item(), 3),
"a": to_significant_figures(self.a.item(), 3),
"r0": to_significant_figures(self.r0.item(), 3),
"energy_transform": self.energy_transform,
},
)
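
The pattern used by both potentials above replaces the PositiveParameter helper removed from nn.py below: each physical constant is stored as its logarithm in a plain torch.nn.Parameter and exponentiated on access, so it stays strictly positive during optimisation without a custom wrapper class. A minimal, self-contained sketch of the idea (the class and names are illustrative, not part of the package):

```python
import torch
from torch import nn


class PositiveScalar(nn.Module):
    """Illustrative only: keep a trainable scalar strictly positive."""

    def __init__(self, value: float):
        super().__init__()
        # optimise in log-space; gradients flow back through .exp()
        self._log_value = nn.Parameter(torch.tensor(value).log())

    @property
    def value(self) -> torch.Tensor:
        return self._log_value.exp()


sigma = PositiveScalar(2.0)
print(sigma.value)  # ~2.0, and always > 0 whatever the optimiser does
```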
131 changes: 15 additions & 116 deletions src/graph_pes/nn.py
@@ -1,6 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable

import torch
@@ -10,9 +9,6 @@

from .util import MAX_Z, pairs

# TODO support access to .data via property and setter on ConstrainedParameter
# / cleanup the ConstrainedParameter class


class MLP(nn.Module):
"""
@@ -250,10 +246,10 @@ def numel(self) -> int:

def __repr__(self) -> str:
if len(self._accessed_Zs) == 0:
return (
f"PerSpeciesParameter(dim={tuple(self.shape[1:])}, "
f"requires_grad={self.requires_grad})"
)
dim = tuple(self.shape[1:])
if dim == (1,):
return "PerSpeciesParameter()"
return f"PerSpeciesParameter(dim={dim})"

torch.set_printoptions(threshold=3)
Zs = sorted(self._accessed_Zs)
@@ -271,15 +267,22 @@ def __repr__(self) -> str:

torch.set_printoptions(profile="default")

dim = tuple(self.shape[1:])
if dim == (1,):
return f"""\
PerSpeciesParameter(
{lines},
)"""
return f"""\
PerSpeciesParameter({{
{lines}
}}, dim={tuple(self.shape[1:])}, requires_grad={self.requires_grad})"""
PerSpeciesParameter(
{lines},
dim={dim},
)"""


class PerSpeciesEmbedding(torch.nn.Module):
"""
A per-speices equivalent of `torch.nn.Embedding`.
A per-species equivalent of `torch.nn.Embedding`.
Parameters
----------
@@ -304,110 +307,6 @@ def __repr__(self) -> str:
return f"PerSpeciesEmbedding(dim={self._embeddings.shape[1]})"


class ConstrainedParameter(nn.Module, ABC):
"""
Abstract base class for constrained parameters.
Implementations should override the `_constrained_value` property.
Parameters
----------
x
The initial value of the parameter.
requires_grad
Whether the parameter should be trainable.
"""

def __init__(self, x: torch.Tensor, requires_grad: bool = True):
super().__init__()
self._parameter = nn.Parameter(x, requires_grad)

@property
@abstractmethod
def constrained_value(self) -> torch.Tensor:
"""generate the constrained value"""

def _do_math(self, other, function, rev=False):
if isinstance(other, ConstrainedParameter):
other_value = other.constrained_value
else:
other_value = other

if rev:
return function(other_value, self.constrained_value)
else:
return function(self.constrained_value, other_value)

def __add__(self, other):
return self._do_math(other, torch.add)

def __radd__(self, other):
return self._do_math(other, torch.add, rev=True)

def __sub__(self, other):
return self._do_math(other, torch.sub)

def __rsub__(self, other):
return self._do_math(other, torch.sub, rev=True)

def __mul__(self, other):
return self._do_math(other, torch.mul)

def __rmul__(self, other):
return self._do_math(other, torch.mul, rev=True)

def __truediv__(self, other):
return self._do_math(other, torch.true_divide)

def __rtruediv__(self, other):
return self._do_math(other, torch.true_divide, rev=True)

def __pow__(self, other):
return self._do_math(other, torch.pow)

def __rpow__(self, other):
return self._do_math(other, torch.pow, rev=True)

def log(self):
return torch.log(self.constrained_value)

def sqrt(self):
return torch.sqrt(self.constrained_value)

def __repr__(self):
t = self.constrained_value
if t.numel() == 1:
return f"{self.__class__.__name__}({t.item():.4f})"
return f"{self.__class__.__name__}({t})"

def __neg__(self):
return self._do_math(0, torch.sub, rev=True)


# TODO: make this work with torchscript
class PositiveParameter(ConstrainedParameter):
"""
Drop-in replacement for :class:`torch.nn.Parameter`. An internal
exponentiation ensures that the parameter is always positive.
Parameters
----------
x
The initial value of the parameter. Must be positive.
requires_grad
Whether the parameter should be trainable.
"""

def __init__(self, x: torch.Tensor | float, requires_grad: bool = True):
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
super().__init__(torch.log(x), requires_grad)

@property
def constrained_value(self):
return torch.exp(self._parameter)


class HaddamardProduct(nn.Module):
def __init__(self, *components: nn.Module, left_aligned: bool = False):
super().__init__()
49 changes: 49 additions & 0 deletions src/graph_pes/util.py
@@ -95,3 +95,52 @@ def require_grad(tensor: torch.Tensor):
tensor.requires_grad_(True)
yield
tensor.requires_grad_(req_grad)


def to_significant_figures(x: float | int, sf: int = 3) -> float:
"""
Round `x` to `sf` significant figures, returning the result as a float.
"""

# do the actual rounding
possibly_scientific = f"{x:.{sf}g}"

# this might be in e.g. 1.23e+02 format,
# so parse it back into a plain float
return float(possibly_scientific)


def pytorch_repr(
name: str,
_modules: dict | None = None,
extra_repr: str = "",
) -> str:
# lifted from torch.nn.Module.__repr__
from torch.nn.modules.module import _addindent

if _modules is None:
_modules = {}

# We treat the extra repr like the sub-module, one item per line
extra_lines = []
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split("\n")
child_lines = []
for key, module in _modules.items():
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append("(" + key + "): " + mod_str)
lines = extra_lines + child_lines

main_str = name + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += "\n " + "\n ".join(lines) + "\n"

main_str += ")"
return main_str
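
Taken together, these two helpers produce the prettified reprs used by LennardJones and Morse above. A quick usage sketch (the values are illustrative, and the exact indentation of the output depends on _addindent and the indent literals above):

```python
from graph_pes.util import pytorch_repr, to_significant_figures

print(to_significant_figures(0.123456, 3))  # 0.123
print(to_significant_figures(12345.6, 3))   # 12300.0 (parsed back from "1.23e+04")

print(pytorch_repr("LennardJones", _modules={"epsilon": 0.1, "sigma": 1.0}))
# LennardJones(
#   (epsilon): 0.1
#   (sigma): 1.0
# )
```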
14 changes: 0 additions & 14 deletions tests/test_nn.py
@@ -5,7 +5,6 @@
MLP,
PerSpeciesEmbedding,
PerSpeciesParameter,
PositiveParameter,
)
from graph_pes.util import MAX_Z

@@ -50,16 +49,3 @@ def test_mlp():

# test nice repr
assert "MLP(10 → 20 → 1" in str(mlp)


def test_positive_parameter():
x = torch.tensor([1, 2, 3]).float()
positive_x = PositiveParameter(x)

a = torch.tensor([-1, 0, 1]).float()

assert torch.allclose(positive_x + a, x + a)
assert torch.allclose(positive_x - a, x - a)
assert torch.allclose(positive_x * a, x * a)
assert torch.allclose(positive_x / a, x / a)
assert torch.allclose(positive_x.log(), x.log())
