Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Jan 12, 2024
1 parent b337eba commit 564b6b3
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 125 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Home <self>
data
nn
transforms

########
GraphPES
Expand Down
13 changes: 13 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
###########
Transforms
###########

.. autoclass :: graph_pes.transform.Transform
:members:
:private-members:
.. autoclass :: graph_pes.transform.Identity()
.. autoclass :: graph_pes.transform.Chain
.. autoclass :: graph_pes.transform.PerSpeciesOffset
11 changes: 4 additions & 7 deletions src/graph_pes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Identity,
PerSpeciesOffset,
PerSpeciesScale,
Transform,
)
from torch import nn

Expand Down Expand Up @@ -74,7 +75,7 @@ def forward(self, graph: AtomicGraph | AtomicGraphBatch):

def __init__(self):
super().__init__()
self._energy_transforms = nn.ModuleDict(
self._energy_transforms: dict[str, Transform] = nn.ModuleDict( # type: ignore
{
"local": Chain(
[PerSpeciesScale(), PerSpeciesOffset()], trainable=True
Expand Down Expand Up @@ -176,8 +177,8 @@ def get_predictions(pes: GraphPESModel, structure: AtomicGraph) -> Prediction:
# an infinitesimal change in the cell parameters
actual_cell = structure.cell
change_to_cell = torch.zeros_like(actual_cell, requires_grad=True)
# symmetric_change = 0.5 * (change_to_cell + change_to_cell.transpose(-1, -2))
structure.cell = actual_cell + change_to_cell
symmetric_change = 0.5 * (change_to_cell + change_to_cell.transpose(-1, -2))
structure.cell = actual_cell + symmetric_change
with require_grad(structure._positions):
total_energy = pes(structure)
(dE_dR, dCell_dR) = get_gradient(
Expand Down Expand Up @@ -209,10 +210,6 @@ def energy_and_forces(pes: GraphPESModel, structure: AtomicGraph):
The energy of the structure and forces on each atom.
"""

# TODO handle the case where isolated atoms are present
# such that the gradient of energy wrt their positions
# is zero.

# use the autograd machinery to auto-magically
# calculate forces for (almost) free
structure._positions.requires_grad_(True)
Expand Down
25 changes: 12 additions & 13 deletions src/graph_pes/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ def of_dim(
cls,
dim: int,
requires_grad: bool = True,
generator: Callable[[tuple[int, int]], Tensor] | None = None,
generator: Callable[[tuple[int, int]], Tensor]
| int
| float
| None = None,
):
"""
Create a `PerSpeciesParameter` of the given dimension.
Expand All @@ -193,9 +196,12 @@ def of_dim(
requires_grad
Whether the parameter should be trainable.
"""
if generator is None:
generator = torch.randn
data = generator((MAX_Z, dim))
if isinstance(generator, (int, float)):
data = torch.full((MAX_Z, dim), generator).float()
elif generator is None:
data = torch.randn(MAX_Z, dim)
else:
data = generator((MAX_Z, dim))
return PerSpeciesParameter(data=data, requires_grad=requires_grad)

def __getitem__(self, Z: int | Tensor) -> Tensor:
Expand All @@ -206,7 +212,6 @@ def __getitem__(self, Z: int | Tensor) -> Tensor:
----------
Z
The atomic number/s of the parameter to get.
"""

if isinstance(Z, int):
Expand All @@ -217,13 +222,7 @@ def __getitem__(self, Z: int | Tensor) -> Tensor:
return super().__getitem__(Z)

def numel(self) -> int:
"""
Get the number of trainable parameters.
Returns
-------
The number of trainable parameters.
"""
"""Get the number of trainable parameters."""

return sum(self[Z].numel() for Z in self._accessed_Zs)

Expand Down Expand Up @@ -367,5 +366,5 @@ def __init__(self, x: torch.Tensor | float, requires_grad: bool = True):
super().__init__(torch.log(x), requires_grad)

@property
def _constrained_value(self):
def constrained_value(self):
return torch.exp(self._parameter)
Loading

0 comments on commit 564b6b3

Please sign in to comment.