From 549599bc8079a6ef67b7f3a6a0f694a19fb2130e Mon Sep 17 00:00:00 2001
From: John Gardner
Date: Wed, 6 Mar 2024 08:59:17 +0000
Subject: [PATCH] prettify pair potentials

---
 src/graph_pes/core.py            |   4 +-
 src/graph_pes/models/pairwise.py |  54 +++++++++++--
 src/graph_pes/nn.py              | 131 ++++--------------------
 src/graph_pes/util.py            |  49 ++++++++++++
 tests/test_nn.py                 |  14 ----
 5 files changed, 114 insertions(+), 138 deletions(-)

diff --git a/src/graph_pes/core.py b/src/graph_pes/core.py
index 89887b71..fd373c39 100644
--- a/src/graph_pes/core.py
+++ b/src/graph_pes/core.py
@@ -14,7 +14,7 @@
     keys,
     sum_per_structure,
 )
-from graph_pes.transform import Identity, PerAtomStandardScaler, Transform
+from graph_pes.transform import PerAtomStandardScaler, Transform
 from graph_pes.util import differentiate, require_grad
 
 
@@ -61,7 +61,7 @@ def __init__(self):
         super().__init__()
         # assigned here to appease torchscript and mypy:
         # this gets overridden in pre_fit below
-        self.energy_transform: Transform = Identity()
+        self.energy_transform: Transform = PerAtomStandardScaler()
 
     def pre_fit(self, graphs: AtomicGraphBatch):
         """
diff --git a/src/graph_pes/models/pairwise.py b/src/graph_pes/models/pairwise.py
index 3c418d88..2d776701 100644
--- a/src/graph_pes/models/pairwise.py
+++ b/src/graph_pes/models/pairwise.py
@@ -12,6 +12,7 @@
     sum_over_neighbours,
 )
 from graph_pes.transform import PerAtomShift
+from graph_pes.util import pytorch_repr, to_significant_figures
 from jaxtyping import Float
 from torch import Tensor
 
@@ -112,6 +113,17 @@ def __init__(self, epsilon: float = 0.1, sigma: float = 1.0):
         self._log_epsilon = torch.nn.Parameter(torch.tensor(epsilon).log())
         self._log_sigma = torch.nn.Parameter(torch.tensor(sigma).log())
 
+        # epsilon is a scaling term, so only need to learn a shift
+        self.energy_transform = PerAtomShift()
+
+    @property
+    def epsilon(self):
+        return self._log_epsilon.exp()
+
+    @property
+    def sigma(self):
+        return self._log_sigma.exp()
+
     # don't use Z_i and Z_j, but include them for consistency with the
     # abstract method
     def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None):
@@ -123,11 +135,9 @@ def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None):
         """
         Evaluate the pair potential.
 
         Parameters
         ----------
         r
             The pair-wise distances between the atoms.
         """
-        epsilon = self._log_epsilon.exp()
-        sigma = self._log_sigma.exp()
-
-        x = sigma / r
-        return 4 * epsilon * (x**12 - x**6)
+        x = self.sigma / r
+        return 4 * self.epsilon * (x**12 - x**6)
 
     def pre_fit(self, graph: AtomicGraphBatch):
         assert "energy" in graph
@@ -147,6 +157,16 @@ def pre_fit(self, graph: AtomicGraphBatch):
         d = torch.quantile(neighbour_distances(graph), 0.01)
         self._log_sigma = torch.nn.Parameter(d.log())
 
+    def __repr__(self):
+        return pytorch_repr(
+            "LennardJones",
+            _modules={
+                "epsilon": to_significant_figures(self.epsilon.item(), 3),
+                "sigma": to_significant_figures(self.sigma.item(), 3),
+                "energy_transform": self.energy_transform,
+            },
+        )
+
 
 class Morse(PairPotential):
     r"""
@@ -170,6 +190,18 @@ def __init__(self, D: float = 0.1, a: float = 3.0, r0: float = 1.0):
         # parameter (rather than a shift and scale)
         self.energy_transform = PerAtomShift()
 
+    @property
+    def D(self):
+        return self._log_D.exp()
+
+    @property
+    def a(self):
+        return self._log_a.exp()
+
+    @property
+    def r0(self):
+        return self._log_r0.exp()
+
     def interaction(
         self, r: torch.Tensor, Z_i: torch.Tensor, Z_j: torch.Tensor
     ):
@@ -185,8 +217,7 @@ def interaction(
         Z_j : torch.Tensor
             The atomic numbers of the neighbours. (unused)
         """
-        D, a, r0 = self._log_D.exp(), self._log_a.exp(), self._log_r0.exp()
-        return D * (1 - torch.exp(-a * (r - r0))) ** 2
+        return self.D * (1 - torch.exp(-self.a * (r - self.r0))) ** 2
 
     def pre_fit(self, graph: AtomicGraphBatch):
         assert "energy" in graph
@@ -203,3 +234,14 @@ def pre_fit(self, graph: AtomicGraphBatch):
 
         # set the width to be "reasonable"
         self._log_a = torch.nn.Parameter(torch.tensor(3.0).log())
+
+    def __repr__(self):
+        return pytorch_repr(
+            "Morse",
+            _modules={
+                "D": to_significant_figures(self.D.item(), 3),
+                "a": to_significant_figures(self.a.item(), 3),
+                "r0": to_significant_figures(self.r0.item(), 3),
+                "energy_transform": self.energy_transform,
+            },
+        )
diff --git a/src/graph_pes/nn.py b/src/graph_pes/nn.py
index be12edd4..1fc9175c 100644
--- a/src/graph_pes/nn.py
+++ b/src/graph_pes/nn.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-from abc import ABC, abstractmethod
 from typing import Callable
 
 import torch
@@ -10,9 +9,6 @@
 
 from .util import MAX_Z, pairs
 
-# TODO support access to .data via property and setter on ConstrainedParameter
-# / cleanup the ConstrainedParameter class
-
 
 class MLP(nn.Module):
     """
@@ -250,10 +246,10 @@ def numel(self) -> int:
 
     def __repr__(self) -> str:
         if len(self._accessed_Zs) == 0:
-            return (
-                f"PerSpeciesParameter(dim={tuple(self.shape[1:])}, "
-                f"requires_grad={self.requires_grad})"
-            )
+            dim = tuple(self.shape[1:])
+            if dim == (1,):
+                return "PerSpeciesParameter()"
+            return f"PerSpeciesParameter(dim={dim})"
 
         torch.set_printoptions(threshold=3)
         Zs = sorted(self._accessed_Zs)
@@ -271,15 +267,22 @@ def __repr__(self) -> str:
 
         torch.set_printoptions(profile="default")
 
+        dim = tuple(self.shape[1:])
+        if dim == (1,):
+            return f"""\
+PerSpeciesParameter(
+    {lines},
+)"""
         return f"""\
-PerSpeciesParameter({{
-    {lines}
-}}, dim={tuple(self.shape[1:])}, requires_grad={self.requires_grad})"""
+PerSpeciesParameter(
+    {lines},
+    dim={dim},
+)"""
 
 
 class PerSpeciesEmbedding(torch.nn.Module):
     """
-    A per-speices equivalent of `torch.nn.Embedding`.
+    A per-species equivalent of `torch.nn.Embedding`.
 
     Parameters
     ----------
@@ -304,110 +307,6 @@ def __repr__(self) -> str:
         return f"PerSpeciesEmbedding(dim={self._embeddings.shape[1]})"
 
 
-class ConstrainedParameter(nn.Module, ABC):
-    """
-    Abstract base class for constrained parameters.
-
-    Implementations should override the `_constrained_value` property.
-
-    Parameters
-    ----------
-    x
-        The initial value of the parameter.
-    requires_grad
-        Whether the parameter should be trainable.
-    """
-
-    def __init__(self, x: torch.Tensor, requires_grad: bool = True):
-        super().__init__()
-        self._parameter = nn.Parameter(x, requires_grad)
-
-    @property
-    @abstractmethod
-    def constrained_value(self) -> torch.Tensor:
-        """generate the constrained value"""
-
-    def _do_math(self, other, function, rev=False):
-        if isinstance(other, ConstrainedParameter):
-            other_value = other.constrained_value
-        else:
-            other_value = other
-
-        if rev:
-            return function(other_value, self.constrained_value)
-        else:
-            return function(self.constrained_value, other_value)
-
-    def __add__(self, other):
-        return self._do_math(other, torch.add)
-
-    def __radd__(self, other):
-        return self._do_math(other, torch.add, rev=True)
-
-    def __sub__(self, other):
-        return self._do_math(other, torch.sub)
-
-    def __rsub__(self, other):
-        return self._do_math(other, torch.sub, rev=True)
-
-    def __mul__(self, other):
-        return self._do_math(other, torch.mul)
-
-    def __rmul__(self, other):
-        return self._do_math(other, torch.mul, rev=True)
-
-    def __truediv__(self, other):
-        return self._do_math(other, torch.true_divide)
-
-    def __rtruediv__(self, other):
-        return self._do_math(other, torch.true_divide, rev=True)
-
-    def __pow__(self, other):
-        return self._do_math(other, torch.pow)
-
-    def __rpow__(self, other):
-        return self._do_math(other, torch.pow, rev=True)
-
-    def log(self):
-        return torch.log(self.constrained_value)
-
-    def sqrt(self):
-        return torch.sqrt(self.constrained_value)
-
-    def __repr__(self):
-        t = self.constrained_value
-        if t.numel() == 1:
-            return f"{self.__class__.__name__}({t.item():.4f})"
-        return f"{self.__class__.__name__}({t})"
-
-    def __neg__(self):
-        return self._do_math(0, torch.sub, rev=True)
-
-
-# TODO: make this work with torchscript
-class PositiveParameter(ConstrainedParameter):
-    """
-    Drop-in replacement for :class:`torch.nn.Parameter`. An internal
-    exponentiation ensures that the parameter is always positive.
-
-    Parameters
-    ----------
-    x
-        The initial value of the parameter. Must be positive.
-    requires_grad
-        Whether the parameter should be trainable.
-    """
-
-    def __init__(self, x: torch.Tensor | float, requires_grad: bool = True):
-        if not isinstance(x, torch.Tensor):
-            x = torch.tensor(x)
-        super().__init__(torch.log(x), requires_grad)
-
-    @property
-    def constrained_value(self):
-        return torch.exp(self._parameter)
-
-
 class HaddamardProduct(nn.Module):
     def __init__(self, *components: nn.Module, left_aligned: bool = False):
         super().__init__()
diff --git a/src/graph_pes/util.py b/src/graph_pes/util.py
index 557d11f5..5b61b675 100644
--- a/src/graph_pes/util.py
+++ b/src/graph_pes/util.py
@@ -95,3 +95,52 @@ def require_grad(tensor: torch.Tensor):
     tensor.requires_grad_(True)
     yield
     tensor.requires_grad_(req_grad)
+
+
+def to_significant_figures(x: float | int, sf: int = 3) -> float:
+    """
+    Round `x` to `sf` significant figures
+    (e.g. 123.456 -> 123.0 with sf=3).
+    """
+
+    # do the actual rounding
+    possibly_scientific = f"{x:.{sf}g}"
+
+    # this might be in e.g. 1.23e+02 format,
+    # so convert it back to a plain float
+    return float(possibly_scientific)
+
+
+def pytorch_repr(
+    name: str,
+    _modules: dict | None = None,
+    extra_repr: str = "",
+) -> str:
+    # lifted from torch.nn.Module.__repr__
+    from torch.nn.modules.module import _addindent
+
+    if _modules is None:
+        _modules = {}
+
+    # We treat the extra repr like the sub-module, one item per line
+    extra_lines = []
+    # empty string will be split into list ['']
+    if extra_repr:
+        extra_lines = extra_repr.split("\n")
+    child_lines = []
+    for key, module in _modules.items():
+        mod_str = repr(module)
+        mod_str = _addindent(mod_str, 2)
+        child_lines.append("(" + key + "): " + mod_str)
+    lines = extra_lines + child_lines
+
+    main_str = name + "("
+    if lines:
+        # simple one-liner info, which most builtin Modules will use
+        if len(extra_lines) == 1 and not child_lines:
+            main_str += extra_lines[0]
+        else:
+            main_str += "\n  " + "\n  ".join(lines) + "\n"
+
+    main_str += ")"
+    return main_str
diff --git a/tests/test_nn.py b/tests/test_nn.py
index 713ed81a..60abdb65 100644
--- a/tests/test_nn.py
+++ b/tests/test_nn.py
@@ -5,7 +5,6 @@
     MLP,
     PerSpeciesEmbedding,
     PerSpeciesParameter,
-    PositiveParameter,
 )
 
 from graph_pes.util import MAX_Z
@@ -50,16 +49,3 @@ def test_mlp():
 
     # test nice repr
     assert "MLP(10 → 20 → 1" in str(mlp)
-
-
-def test_positive_parameter():
-    x = torch.tensor([1, 2, 3]).float()
-    positive_x = PositiveParameter(x)
-
-    a = torch.tensor([-1, 0, 1]).float()
-
-    assert torch.allclose(positive_x + a, x + a)
-    assert torch.allclose(positive_x - a, x - a)
-    assert torch.allclose(positive_x * a, x * a)
-    assert torch.allclose(positive_x / a, x / a)
-    assert torch.allclose(positive_x.log(), x.log())
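
Notes
-----

The deleted ConstrainedParameter / PositiveParameter machinery is replaced
throughout by a lighter idiom: each potential stores the log of its strictly
positive physical parameters (epsilon, sigma, D, a, r0) as the trainable
tensor and exponentiates on access, so no optimiser step can drive them
negative. A minimal sketch of that pattern, with an illustrative class name
that is not part of the codebase:

    import torch

    class PositiveScalar(torch.nn.Module):
        # store log(value) as the trainable parameter; taking exp() on
        # access guarantees a positive value whatever the optimiser does
        def __init__(self, value: float):
            super().__init__()
            self._log_value = torch.nn.Parameter(torch.tensor(value).log())

        @property
        def value(self) -> torch.Tensor:
            return self._log_value.exp()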
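
to_significant_figures leans on Python's `g` format specifier, which switches
to scientific notation for large or small magnitudes; the final float() call
normalises that back to a plain number. These values follow directly from the
format spec:

    f"{12345.6:.3g}"                      # '1.23e+04' -- hence the round trip
    to_significant_figures(123.456, 3)    # 123.0
    to_significant_figures(0.0012345, 3)  # 0.00123
    to_significant_figures(12345.6, 3)    # 12300.0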
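
With __repr__ wired through pytorch_repr, a freshly constructed model prints
along these lines (assuming PerAtomShift renders as "PerAtomShift()"; its
actual repr is defined outside this patch):

    >>> LennardJones()
    LennardJones(
      (epsilon): 0.1
      (sigma): 1.0
      (energy_transform): PerAtomShift()
    )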