From 564b6b3dcda14dae07ddf956fb1c3a0945331881 Mon Sep 17 00:00:00 2001
From: John Gardner
Date: Fri, 12 Jan 2024 11:32:15 +0000
Subject: [PATCH] more

---
 docs/source/index.rst      |   1 +
 docs/source/transforms.rst |  13 ++
 src/graph_pes/core.py      |  11 +-
 src/graph_pes/nn.py        |  25 ++--
 src/graph_pes/transform.py | 250 ++++++++++++++++++++++---------------
 tests/test_nn.py           |  25 +++-
 tests/test_transform.py    |  17 +++
 7 files changed, 217 insertions(+), 125 deletions(-)
 create mode 100644 docs/source/transforms.rst
 create mode 100644 tests/test_transform.py

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9ba2c397..079b6522 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -7,6 +7,7 @@ Home
 
    data
    nn
+   transforms
 
 ########
 GraphPES
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
new file mode 100644
index 00000000..78317c7d
--- /dev/null
+++ b/docs/source/transforms.rst
@@ -0,0 +1,13 @@
+###########
+Transforms
+###########
+
+.. autoclass:: graph_pes.transform.Transform
+   :members:
+   :private-members:
+
+.. autoclass:: graph_pes.transform.Identity()
+
+.. autoclass:: graph_pes.transform.Chain
+
+.. autoclass:: graph_pes.transform.PerSpeciesOffset
diff --git a/src/graph_pes/core.py b/src/graph_pes/core.py
index 500e5489..e071ec52 100644
--- a/src/graph_pes/core.py
+++ b/src/graph_pes/core.py
@@ -12,6 +12,7 @@
     Identity,
     PerSpeciesOffset,
     PerSpeciesScale,
+    Transform,
 )
 from torch import nn
 
@@ -74,7 +75,7 @@ def forward(self, graph: AtomicGraph | AtomicGraphBatch):
 
     def __init__(self):
         super().__init__()
-        self._energy_transforms = nn.ModuleDict(
+        self._energy_transforms: dict[str, Transform] = nn.ModuleDict(  # type: ignore
             {
                 "local": Chain(
                     [PerSpeciesScale(), PerSpeciesOffset()], trainable=True
@@ -176,8 +177,8 @@ def get_predictions(pes: GraphPESModel, structure: AtomicGraph) -> Prediction:
     # an infinitesimal change in the cell parameters
     actual_cell = structure.cell
     change_to_cell = torch.zeros_like(actual_cell, requires_grad=True)
-    # symmetric_change = 0.5 * (change_to_cell + change_to_cell.transpose(-1, -2))
-    structure.cell = actual_cell + change_to_cell
+    symmetric_change = 0.5 * (change_to_cell + change_to_cell.transpose(-1, -2))
+    structure.cell = actual_cell + symmetric_change
     with require_grad(structure._positions):
         total_energy = pes(structure)
         (dE_dR, dCell_dR) = get_gradient(
@@ -209,10 +210,6 @@ def energy_and_forces(pes: GraphPESModel, structure: AtomicGraph):
     The energy of the structure and forces on each atom.
     """
 
-    # TODO handle the case where isolated atoms are present
-    # such that the gradient of energy wrt their positions
-    # is zero.
-
     # use the autograd machinery to auto-magically
     # calculate forces for (almost) free
     structure._positions.requires_grad_(True)
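Why route the cell change through `symmetric_change`? Only the symmetric part
of dE/dcell is physically meaningful: the stress is the derivative of the
energy with respect to a symmetric strain, so the antisymmetric part of the
raw gradient should be discarded by construction. A minimal, self-contained
sketch of the trick in plain PyTorch (the "energy" below is a deliberately
asymmetric stand-in, not the graph_pes energy):

import torch

cell = torch.eye(3)
change = torch.zeros(3, 3, requires_grad=True)
symmetric_change = 0.5 * (change + change.transpose(-1, -2))
strained_cell = cell + symmetric_change

# an asymmetric stand-in for a real energy function of the cell
energy = (strained_cell * torch.arange(9.0).view(3, 3)).sum()
(grad,) = torch.autograd.grad(energy, change)

assert torch.allclose(grad, grad.T)  # symmetric by construction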
""" - if generator is None: - generator = torch.randn - data = generator((MAX_Z, dim)) + if isinstance(generator, (int, float)): + data = torch.full((MAX_Z, dim), generator).float() + elif generator is None: + data = torch.randn(MAX_Z, dim) + else: + data = generator((MAX_Z, dim)) return PerSpeciesParameter(data=data, requires_grad=requires_grad) def __getitem__(self, Z: int | Tensor) -> Tensor: @@ -206,7 +212,6 @@ def __getitem__(self, Z: int | Tensor) -> Tensor: ---------- Z The atomic number/s of the parameter to get. - """ if isinstance(Z, int): @@ -217,13 +222,7 @@ def __getitem__(self, Z: int | Tensor) -> Tensor: return super().__getitem__(Z) def numel(self) -> int: - """ - Get the number of trainable parameters. - - Returns - ------- - The number of trainable parameters. - """ + """Get the number of trainable parameters.""" return sum(self[Z].numel() for Z in self._accessed_Zs) @@ -367,5 +366,5 @@ def __init__(self, x: torch.Tensor | float, requires_grad: bool = True): super().__init__(torch.log(x), requires_grad) @property - def _constrained_value(self): + def constrained_value(self): return torch.exp(self._parameter) diff --git a/src/graph_pes/transform.py b/src/graph_pes/transform.py index 15668e3d..e201158e 100644 --- a/src/graph_pes/transform.py +++ b/src/graph_pes/transform.py @@ -6,52 +6,115 @@ import torch from graph_pes.data import AtomicGraph, AtomicGraphBatch, sum_per_structure from graph_pes.nn import PerSpeciesParameter -from torch import nn +from torch import Tensor, nn class Transform(nn.Module, ABC): r""" - Transforms data. + :math:`T: \mathbb{R}^n \rightarrow_{\mathcal{G}} \mathbb{R}^n` - .. math:: - \mathbf{x}^\prime = T(\mathbf{x}) + Abstract base class for shape-preserving transformations of + data, conditioned on an :class:`AtomicGraph `, + :math:`\mathcal{G}`. + + Subclasses should implement :meth:`forward`, :meth:`inverse`, + and optionally :meth:`fit_to_source` and :meth:`fit_to_target`. + + :meth:`_parameter` and :meth:`_per_species_parameter` are provided + as convenience methods for creating parameters that respect the + `trainable` flag. + + Parameters + ---------- + trainable + Whether the transform should be trainable. """ def __init__(self, trainable: bool = True): super().__init__() self.trainable = trainable - def _parameter(self, x) -> nn.Parameter: - """get an optionally trainable parameter""" + @abstractmethod + def forward(self, x: Tensor, graph: AtomicGraph) -> Tensor: + r""" + Implements the forward transformation, :math:`y = T(x, \mathcal{G})`. + + Parameters + ---------- + x + The input data. + graph + The graph to condition the transformation on. + + Returns + ------- + y: Tensor + The transformed data. + """ + + @abstractmethod + def inverse(self, x: Tensor, graph: AtomicGraph) -> Tensor: + r""" + Implements the inverse transformation, + :math:`x = T^{-1}(y, \mathcal{G})`. + + Parameters + ---------- + x + The input data. + graph + The graph to condition the inverse transformation on. + + Returns + ------- + x: Tensor + The inversely-transformed data. + """ + + def fit_to_source(self, data: Tensor, graphs: AtomicGraphBatch): + """ + Fits the transform to data in the source space, :math:`x`. + + Parameters + ---------- + data + The data, :math:`x`, to fit to. + graphs + The graphs to condition the transformation on. + """ + + def fit_to_target(self, data: Tensor, graphs: AtomicGraphBatch): + """ + Fits the transform to data in the target space, :math:`y`. + + Parameters + ---------- + data + The data, :math:`y`, to fit to. 
diff --git a/src/graph_pes/transform.py b/src/graph_pes/transform.py
index 15668e3d..e201158e 100644
--- a/src/graph_pes/transform.py
+++ b/src/graph_pes/transform.py
@@ -6,52 +6,115 @@
 import torch
 from graph_pes.data import AtomicGraph, AtomicGraphBatch, sum_per_structure
 from graph_pes.nn import PerSpeciesParameter
-from torch import nn
+from torch import Tensor, nn
 
 
 class Transform(nn.Module, ABC):
     r"""
-    Transforms data.
+    :math:`T: \mathbb{R}^n \rightarrow_{\mathcal{G}} \mathbb{R}^n`
 
-    .. math::
-        \mathbf{x}^\prime = T(\mathbf{x})
+    Abstract base class for shape-preserving transformations of
+    data, conditioned on an :class:`AtomicGraph <graph_pes.data.AtomicGraph>`,
+    :math:`\mathcal{G}`.
+
+    Subclasses should implement :meth:`forward`, :meth:`inverse`,
+    and optionally :meth:`fit_to_source` and :meth:`fit_to_target`.
+
+    :meth:`_parameter` and :meth:`_per_species_parameter` are provided
+    as convenience methods for creating parameters that respect the
+    `trainable` flag.
+
+    Parameters
+    ----------
+    trainable
+        Whether the transform should be trainable.
     """
 
     def __init__(self, trainable: bool = True):
         super().__init__()
         self.trainable = trainable
 
-    def _parameter(self, x) -> nn.Parameter:
-        """get an optionally trainable parameter"""
+    @abstractmethod
+    def forward(self, x: Tensor, graph: AtomicGraph) -> Tensor:
+        r"""
+        Implements the forward transformation, :math:`y = T(x, \mathcal{G})`.
+
+        Parameters
+        ----------
+        x
+            The input data.
+        graph
+            The graph to condition the transformation on.
+
+        Returns
+        -------
+        y: Tensor
+            The transformed data.
+        """
+
+    @abstractmethod
+    def inverse(self, x: Tensor, graph: AtomicGraph) -> Tensor:
+        r"""
+        Implements the inverse transformation,
+        :math:`x = T^{-1}(y, \mathcal{G})`.
+
+        Parameters
+        ----------
+        x
+            The input data.
+        graph
+            The graph to condition the inverse transformation on.
+
+        Returns
+        -------
+        x: Tensor
+            The inversely-transformed data.
+        """
+
+    def fit_to_source(self, data: Tensor, graphs: AtomicGraphBatch):
+        """
+        Fits the transform to data in the source space, :math:`x`.
+
+        Parameters
+        ----------
+        data
+            The data, :math:`x`, to fit to.
+        graphs
+            The graphs to condition the transformation on.
+        """
+
+    def fit_to_target(self, data: Tensor, graphs: AtomicGraphBatch):
+        """
+        Fits the transform to data in the target space, :math:`y`.
+
+        Parameters
+        ----------
+        data
+            The data, :math:`y`, to fit to.
+        graphs
+            The graphs to condition the inverse transformation on.
+        """
+
+    def _parameter(self, x: Tensor) -> nn.Parameter:
+        """Wrap `x` in an optionally trainable parameter."""
         return nn.Parameter(x, requires_grad=self.trainable)
 
     def _per_species_parameter(
         self,
-        zs: torch.Tensor | None = None,
-        values: torch.Tensor | None = None,
-        default: float = 0.0,
+        generator: Callable[[tuple[int, int]], Tensor] | float = 0.0,
     ) -> PerSpeciesParameter:
-        """get an optionally trainable per-species parameter"""
-        return PerSpeciesParameter(zs, values, default, self.trainable)
-
-    @abstractmethod
-    def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
-        """implements the forward transformation"""
-
-    @abstractmethod
-    def inverse(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
-        """implements the inverse transformation"""
+        """Generate an (optionally trainable) per-species parameter."""
+        return PerSpeciesParameter.of_dim(
+            1, requires_grad=self.trainable, generator=generator
+        )
 
-    def fit_to_source(self, data: torch.Tensor, graphs: AtomicGraphBatch):
-        """fits the transform to data in the source space"""
-        pass
 
-    def fit_to_target(self, data: torch.Tensor, graphs: AtomicGraphBatch):
-        """fits the transform to data in the target space"""
-        pass
+class Identity(Transform):
+    """The identity transform, provided for convenience."""
 
+    def __init__(self):
+        super().__init__(trainable=False)
 
-class Identity(Transform):
     def forward(self, x, graph):
         return x
 
     def inverse(self, x, graph):
         return x
 
 
 class Chain(Transform):
+    r"""
+    A chain of transformations, :math:`T_n \circ \dots \circ T_2 \circ T_1`.
+
+    The forward transformation is applied sequentially from left to right,
+    :math:`y = T_n \circ \dots \circ T_2 \circ T_1(x, \mathcal{G})`.
+
+    The inverse transformation is applied sequentially from right to left,
+    :math:`x = T_1^{-1} \circ T_2^{-1} \circ \dots
+    \circ T_n^{-1}(y, \mathcal{G})`.
+
+    Parameters
+    ----------
+    transforms
+        The transformations to chain together.
+    trainable
+        Whether the chain should be trainable.
+    """
+
     def __init__(self, transforms: list[Transform], trainable: bool = True):
         super().__init__(trainable)
         for t in transforms:
             t.trainable = trainable
-        self.transforms = nn.ModuleList(transforms)
+        self.transforms: list[Transform] = nn.ModuleList(transforms)  # type: ignore
 
     def forward(self, x, graph):
         for t in self.transforms:
@@ -73,19 +154,16 @@ def forward(self, x, graph):
 
     def inverse(self, x, graph):
         for t in reversed(self.transforms):
-            t: Transform
             x = t.inverse(x, graph)
         return x
 
-    def fit_to_source(self, data: torch.Tensor, graphs: AtomicGraphBatch):
-        for t in self.transforms:  # type: ignore
-            t: Transform
+    def fit_to_source(self, data: Tensor, graphs: AtomicGraphBatch):
+        for t in self.transforms:
             t.fit_to_source(data, graphs)
             data = t(data, graphs)
 
-    def fit_to_target(self, data: torch.Tensor, graphs: AtomicGraphBatch):
-        for t in reversed(self.transforms):  # type: ignore
-            t: Transform
+    def fit_to_target(self, data: Tensor, graphs: AtomicGraphBatch):
+        for t in reversed(self.transforms):
             t.fit_to_target(data, graphs)
             data = t.inverse(data, graphs)
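For context, this is what a minimal subclass of the new ABC looks like, and
how `Chain` composes it. `AddOne` is hypothetical, purely for illustration:

from torch import Tensor
from graph_pes.data import AtomicGraph
from graph_pes.transform import Chain, Identity, Transform

class AddOne(Transform):
    """y = x + 1, ignoring the conditioning graph."""

    def forward(self, x: Tensor, graph: AtomicGraph) -> Tensor:
        return x + 1

    def inverse(self, x: Tensor, graph: AtomicGraph) -> Tensor:
        return x - 1

chain = Chain([AddOne(), Identity()], trainable=False)
# for any x and graph: chain.inverse(chain(x, graph), graph) == x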
@@ -94,52 +172,17 @@ def is_local_property(x, graph):
     return len(x.shape) and x.shape[0] == graph.n_atoms
 
 
-# class PerSpeciesTransform(Transform, ABC):
-#     r"""
-#     Uses a per-species parameter to transform a per-structure/per-atom property.
-#     """
-
-#     def __init__(
-#         self,
-#         trainable: bool = True,
-#         values: PerSpeciesParameter | None = None,
-#         default: float = 0.0,
-#     ):
-#         super().__init__(trainable=trainable)
-
-#         if values is not None:
-#             values.values.requires_grad = trainable
-#         else:
-#             values = self._per_species_parameter(default=default)
-#         self.values = values
-
-#     @staticmethod
-#     def is_local_property(x, graph):
-#         return len(x.shape) and x.shape[0] == graph.n_atoms
-
-#     def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
-#         values = self.values[graph.Z]
-#         if self.is_local_property(x, graph):
-#             return self.per_species_op(x, values)
-#         else:
-#             return self.per_structure_op(x, values, graph)
-
-#     @abstractmethod
-#     def per_species_op(
-#         self, x: torch.Tensor, values: torch.Tensor
-#     ) -> torch.Tensor:
-#         """implements the per-species forward transformation"""
-
-#     @abstractmethod
-#     def per_structure_op(
-#         self, x: torch.Tensor, values: torch.Tensor, graph: AtomicGraph
-#     ) -> torch.Tensor:
-#         """implements the per-structure forward transformation"""
-
-
 class PerSpeciesOffset(Transform):
     r"""
-    adds a per-species offset to the input
+    Adds a per-species offset to a tensor of per-atom, or per-structure,
+    properties.
+
+    Parameters
+    ----------
+    trainable
+        Whether the offset should be trainable.
+    offsets
+        The offsets to use. If `None`, the offsets are initialized to zero.
     """
 
     def __init__(
         self,
         trainable: bool = True,
         offsets: PerSpeciesParameter | None = None,
     ):
         super().__init__(trainable=trainable)
         if offsets is not None:
-            offsets.values.requires_grad = trainable
+            offsets.requires_grad = trainable
         else:
-            offsets = self._per_species_parameter(default=0.0)
+            offsets = self._per_species_parameter(0.0)
+
         self.offsets = offsets
         self.op = torch.add
 
     def _perform_op(
-        self, x: torch.Tensor, graph: AtomicGraph, op: Callable
-    ) -> torch.Tensor:
+        self, x: Tensor, graph: AtomicGraph, op: Callable
+    ) -> Tensor:
         offsets = self.offsets[graph.Z]
 
         # if we have a total property, we need to sum offsets over the structure
         if not is_local_property(x, graph):
             offsets = sum_per_structure(offsets, graph)
 
         return op(x, offsets)
 
-    def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
+    def forward(self, x: Tensor, graph: AtomicGraph) -> Tensor:
         return self._perform_op(x, graph, self.op)
 
-    def inverse(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
+    def inverse(self, x: Tensor, graph: AtomicGraph) -> Tensor:
         return self._perform_op(x, graph, self.inverse_op)
 
     @property
     def inverse_op(self):
@@ -178,7 +222,7 @@ def inverse_op(self):
         raise NotImplementedError
 
     def guess_offsets(
-        self, x: torch.Tensor, graphs: AtomicGraphBatch
+        self, x: Tensor, graphs: AtomicGraphBatch
     ) -> PerSpeciesParameter:
         """guesses the offsets from the data"""
         if is_local_property(x, graphs):
@@ -199,11 +243,11 @@ def guess_offsets(
 
-        return self._per_species_parameter(zs, offsets)
+        param = self._per_species_parameter(0.0)
+        param.data[zs] = offsets.view(-1, 1).float()
+        return param
 
-    def fit_to_source(self, data: torch.Tensor, graphs: AtomicGraphBatch):
+    def fit_to_source(self, data: Tensor, graphs: AtomicGraphBatch):
        self.offsets = self.guess_offsets(data, graphs)
         self.op = torch.sub
 
-    def fit_to_target(self, data: torch.Tensor, graphs: AtomicGraphBatch):
+    def fit_to_target(self, data: Tensor, graphs: AtomicGraphBatch):
         self.offsets = self.guess_offsets(data, graphs)
         self.op = torch.add
 
     def __repr__(self):
         return (
         )
 
 
-def sum_scale_per_structure(
-    scale: torch.Tensor, graph: AtomicGraph
-) -> torch.Tensor:
+def sum_scale_per_structure(scale: Tensor, graph: AtomicGraph) -> Tensor:
     sum = sum_per_structure(scale**2, graph)
     return torch.sqrt(sum)
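The intended calling convention for the `fit_to_*` pair, sketched for
`PerSpeciesOffset` (`energies` is assumed to be a tensor of per-structure
labels and `graphs` an `AtomicGraphBatch`; neither is defined here):

from graph_pes.transform import PerSpeciesOffset

offset = PerSpeciesOffset(trainable=False)
offset.fit_to_target(energies, graphs)  # estimate per-species offsets from labels
centred = offset.inverse(energies, graphs)  # target -> source: offsets subtracted
restored = offset(centred, graphs)  # source -> target: offsets added back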
@@ -237,8 +279,8 @@ def __init__(
         self.op = torch.mul
 
     def _perform_op(
-        self, x: torch.Tensor, graph: AtomicGraph, op: Callable
-    ) -> torch.Tensor:
+        self, x: Tensor, graph: AtomicGraph, op: Callable
+    ) -> Tensor:
         scales = self.scales[graph.Z]
 
         # if we have a total property, we need to sum scales over the structure
@@ -251,10 +293,10 @@ def _perform_op(
             scales = scales.view(-1, 1)
         return op(x, scales)
 
-    def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
+    def forward(self, x: Tensor, graph: AtomicGraph) -> Tensor:
         return self._perform_op(x, graph, self.op)
 
-    def inverse(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
+    def inverse(self, x: Tensor, graph: AtomicGraph) -> Tensor:
         return self._perform_op(x, graph, self.inverse_op)
 
     @property
@@ -266,7 +308,7 @@ def inverse_op(self):
         else:
             raise NotImplementedError
 
-    def guess_scales(self, x: torch.Tensor, graphs: AtomicGraphBatch):
+    def guess_scales(self, x: Tensor, graphs: AtomicGraphBatch):
         """guesses the scales from the data"""
         if is_local_property(x, graphs):
             # fit to mean per species
@@ -285,11 +327,11 @@ def guess_scales(self, x: Tensor, graphs: AtomicGraphBatch):
 
-        return self._per_species_parameter(zs, scales, default=1.0)
+        param = self._per_species_parameter(1.0)
+        param.data[zs] = scales.view(-1, 1).float()
+        return param
 
-    def fit_to_source(self, data: torch.Tensor, graphs: AtomicGraphBatch):
+    def fit_to_source(self, data: Tensor, graphs: AtomicGraphBatch):
         self.scales = self.guess_scales(data, graphs)
         self.op = torch.div
 
-    def fit_to_target(self, data: torch.Tensor, graphs: AtomicGraphBatch):
+    def fit_to_target(self, data: Tensor, graphs: AtomicGraphBatch):
         self.scales = self.guess_scales(data, graphs)
         self.op = torch.mul
 
@@ -302,18 +344,18 @@ def __repr__(self):
 
 
 class Scale(Transform):
     def __init__(self, trainable: bool = True, scale: float = 1.0):
         super().__init__(trainable=trainable)
         self.scale = self._parameter(torch.tensor(scale))
 
-    def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
+    def forward(self, x: Tensor, graph: AtomicGraph) -> Tensor:
         return x * self.scale
 
-    def inverse(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
+    def inverse(self, x: Tensor, graph: AtomicGraph) -> Tensor:
         return x / self.scale
 
-    def fit_to_source(self, data: torch.Tensor, graphs: AtomicGraphBatch):
+    def fit_to_source(self, data: Tensor, graphs: AtomicGraphBatch):
         self.scale = self._parameter(1 / data.std())
 
-    def fit_to_target(self, data: torch.Tensor, graphs: AtomicGraphBatch):
+    def fit_to_target(self, data: Tensor, graphs: AtomicGraphBatch):
         self.scale = self._parameter(data.std())
 
 
@@ -322,8 +364,8 @@ def __init__(self, scale: float):
         super().__init__(trainable=False)
         self.scale = scale
 
-    def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
+    def forward(self, x: Tensor, graph: AtomicGraph) -> Tensor:
         return x * self.scale
 
-    def inverse(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
+    def inverse(self, x: Tensor, graph: AtomicGraph) -> Tensor:
         return x / self.scale
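`Scale`'s two fitting modes point in opposite directions, which is easy to
trip over: after `fit_to_target`, `forward` maps unit-scale model output onto
the spread of the labels; after `fit_to_source`, `forward` normalises data
living in the source space. A sketch (`graphs` is assumed to be any
`AtomicGraphBatch`; `Scale` does not actually inspect it when fitting):

import torch
from graph_pes.transform import Scale

data = 5.0 * torch.randn(1000)

scale = Scale(trainable=False)
scale.fit_to_target(data, graphs)  # scale.scale ~ data.std() ~ 5
scale.fit_to_source(data, graphs)  # scale.scale ~ 1 / data.std() ~ 0.2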
diff --git a/tests/test_nn.py b/tests/test_nn.py
index cca338ed..06bc36a4 100644
--- a/tests/test_nn.py
+++ b/tests/test_nn.py
@@ -1,5 +1,10 @@
 import torch
-from graph_pes.nn import MLP, PerSpeciesEmbedding, PerSpeciesParameter
+from graph_pes.nn import (
+    MLP,
+    PerSpeciesEmbedding,
+    PerSpeciesParameter,
+    PositiveParameter,
+)
 from graph_pes.util import MAX_Z
 
 
@@ -22,6 +27,11 @@ def test_per_species_parameter():
     assert embedding(Z).shape == (5, 10)
     assert embedding.parameters().__next__().numel() == 50
 
+    # test default value init
+    assert PerSpeciesParameter.of_dim(1, generator=1.0).data.allclose(
+        torch.ones(MAX_Z, 1)
+    )
+
 
 def test_mlp():
     mlp = MLP([10, 20, 1])
@@ -38,3 +48,16 @@ def test_mlp():
 
     # test nice repr
     assert "MLP(10 → 20 → 1" in str(mlp)
+
+
+def test_positive_parameter():
+    x = torch.tensor([1, 2, 3]).float()
+    positive_x = PositiveParameter(x)
+
+    a = torch.tensor([-1, 0, 1]).float()
+
+    assert torch.allclose(positive_x + a, x + a)
+    assert torch.allclose(positive_x - a, x - a)
+    assert torch.allclose(positive_x * a, x * a)
+    assert torch.allclose(positive_x / a, x / a)
+    assert torch.allclose(positive_x.log(), x.log())
diff --git a/tests/test_transform.py b/tests/test_transform.py
new file mode 100644
index 00000000..8b0b43a5
--- /dev/null
+++ b/tests/test_transform.py
@@ -0,0 +1,17 @@
+import torch
+from ase import Atoms
+from graph_pes.data import convert_to_atomic_graph
+from graph_pes.transform import Identity
+
+graph = convert_to_atomic_graph(
+    Atoms("H2", positions=[(0, 0, 0), (0, 0, 1)]),
+    cutoff=1.5,
+)
+
+
+def test_identity():
+    transform = Identity()
+    x = torch.arange(10).float()
+
+    assert transform.forward(x, graph).equal(x)
+    assert transform.inverse(x, graph).equal(x)
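A natural follow-up test, reusing the module-level `graph` from
tests/test_transform.py: every invertible `Transform` should satisfy the
round-trip property below. This is a template rather than part of the patch,
and it relies on `PerSpeciesOffset` starting from zero-initialised offsets:

def test_round_trip():
    from graph_pes.transform import PerSpeciesOffset

    transform = PerSpeciesOffset(trainable=False)
    x = torch.arange(10).float()

    y = transform(x, graph)
    assert torch.allclose(transform.inverse(y, graph), x)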