
Update pre_fit method to include energy_label parameter
jla-gardner committed Jan 15, 2024
1 parent cd0346f commit 73b17f2
Showing 5 changed files with 26 additions and 27 deletions.
8 changes: 4 additions & 4 deletions src/graph_pes/core.py
@@ -118,7 +118,7 @@ def __add__(self, other: GraphPESModel) -> Ensemble:
 
         return Ensemble([self, other], mean=False)
 
-    def pre_fit(self, graphs: AtomicGraphBatch):
+    def pre_fit(self, graphs: AtomicGraphBatch, energy_label: str = "energy"):
         """
         Perform optional pre-processing of the training data.
 
@@ -133,7 +133,7 @@ def pre_fit(self, graphs: AtomicGraphBatch):
         graphs
             The training data.
         """
-        self._energy_summation.fit_to_graphs(graphs)
+        self._energy_summation.fit_to_graphs(graphs, energy_label)
 
 
 class EnergySummation(nn.Module):
@@ -161,13 +161,13 @@ def forward(self, local_energies: torch.Tensor, graph: AtomicGraphBatch):
     def fit_to_graphs(
         self,
         graphs: AtomicGraphBatch | list[AtomicGraph],
-        energy_key: str = "energy",
+        energy_label: str = "energy",
     ):
         if not isinstance(graphs, AtomicGraphBatch):
             graphs = AtomicGraphBatch.from_graphs(graphs)
 
         for transform in [self.local_transform, self.total_transform]:
-            transform.fit(graphs[energy_key], graphs)
+            transform.fit(graphs[energy_label], graphs)
 
 
 class Ensemble(GraphPESModel):
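Note on usage: the new energy_label keyword lets pre_fit target datasets whose total energies are stored under a non-default key. A minimal sketch, assuming model is any GraphPESModel and train_batch is an AtomicGraphBatch with labels under "dft_energy" (both names are illustrative, not from this commit):

    # fit the energy transforms to labels stored under a custom key
    model.pre_fit(train_batch, energy_label="dft_energy")

    # the default keeps the old behaviour
    model.pre_fit(train_batch)  # equivalent to energy_label="energy"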
8 changes: 4 additions & 4 deletions src/graph_pes/models/pairwise.py
@@ -8,7 +8,7 @@
 from graph_pes.data import AtomicGraph
 from graph_pes.data.batching import AtomicGraphBatch
 from graph_pes.nn import MLP, PositiveParameter
-from graph_pes.transform import PerAtomScale
+from graph_pes.transform import PerAtomScale, PerAtomShift
 from jaxtyping import Float
 from torch import Tensor, nn
 from torch_geometric.utils import scatter
@@ -104,7 +104,7 @@ def __init__(self):
 
         # epsilon is a scaling term, so only need to learn a shift
        # parameter (rather than a shift and scale)
-        self._energy_summation = EnergySummation(local_transform=PerAtomScale())
+        self._energy_summation = EnergySummation(local_transform=PerAtomShift())
 
     def interaction(
         self, r: torch.Tensor, Z_i: torch.Tensor, Z_j: torch.Tensor
@@ -124,8 +124,8 @@ def interaction(
         x = self.sigma / r
         return 4 * self.epsilon * (x**12 - x**6)
 
-    def pre_fit(self, graph: AtomicGraphBatch):
-        super().pre_fit(graph)
+    def pre_fit(self, graph: AtomicGraphBatch, energy_label: str = "energy"):
+        super().pre_fit(graph, energy_label)
 
         # set the potential depth to be shallow
         self.epsilon = PositiveParameter(0.01)
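Note on usage: any model that overrides pre_fit must now accept the label and forward it, exactly as LennardJones does above. A hedged sketch of the pattern for a hypothetical subclass (MyModel is not part of this codebase):

    class MyModel(GraphPESModel):
        def pre_fit(self, graphs: AtomicGraphBatch, energy_label: str = "energy"):
            # let the base class fit the energy summation transforms first
            super().pre_fit(graphs, energy_label)
            # ...any model-specific setup would go here...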
3 changes: 3 additions & 0 deletions src/graph_pes/nn.py
@@ -8,6 +8,9 @@
 from graph_pes.util import MAX_Z, pairs
 from torch import Tensor
 
+# TODO support access to .data via property and setter on ConstrainedParameter
+# / cleanup the ConstrainedParameter class
+
 
 class MLP(nn.Module):
     """
24 changes: 13 additions & 11 deletions src/graph_pes/training.py
@@ -1,11 +1,11 @@
 from __future__ import annotations
 
 from pathlib import Path
+from typing import Callable, TypeVar
 
 import pytorch_lightning as pl
+import torch
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from torch import optim
 
 from .core import GraphPESModel, get_predictions
 from .data import AtomicGraph
@@ -14,12 +14,14 @@
 from .transform import Chain, PerAtomScale, Scale
 from .util import Keys
 
+Model = TypeVar("Model", bound=GraphPESModel)
+
 
 def train_model(
-    model: GraphPESModel,
+    model: Model,
     train_data: list[AtomicGraph],
     val_data: list[AtomicGraph] | None = None,
-    optimizer: optim.Optimizer | None = None,
+    optimizer: Callable[[Model], torch.optim.Optimizer] | None = None,
     loss: WeightedLoss | Loss | None = None,
     property_labels: dict[Keys, str] | None = None,
     *,
@@ -73,18 +75,21 @@ def train_model(
     )
 
     # deal with fitting transforms
-    if pre_fit_model:
-        model.pre_fit(batch)
+    # TODO: what if not training on energy?
+    if pre_fit_model and Keys.ENERGY in property_labels:
+        model.pre_fit(batch, property_labels[Keys.ENERGY])
 
     actual_loss = get_loss(loss, property_labels)
     actual_loss.fit_transform(batch)
 
     # deal with the optimizer
     if optimizer is None:
-        optimizer = optim.Adam(model.parameters())
+        opt = torch.optim.Adam(model.parameters(), lr=3e-4)
+    else:
+        opt = optimizer(model)
 
     # create the task (a pytorch lightning module)
-    task = LearnThePES(model, optimizer, actual_loss, property_labels)
+    task = LearnThePES(model, opt, actual_loss, property_labels)
 
     # create the trainer
     kwargs = default_trainer_kwargs()
@@ -99,9 +104,6 @@
 
 
 def get_existing_keys(batch: AtomicGraphBatch) -> dict[Keys, str]:
-    # return {
-    #     value: key for key, value in Keys.__members__.items() if key in batch
-    # }
     return {
         key: key.value
         for key in Keys.__members__.values()
@@ -113,7 +115,7 @@ class LearnThePES(pl.LightningModule):
     def __init__(
         self,
         model: GraphPESModel,
-        optimizer: optim.Optimizer,
+        optimizer: torch.optim.Optimizer,
         loss: WeightedLoss,
         property_labels: dict[Keys, str],
     ):
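Note on usage: passing a factory instead of an optimizer instance means the optimizer is only built once the model's parameters are final (i.e. after pre_fit). A sketch, assuming model and train_data are already defined; AdamW and the learning rate are illustrative choices, not library defaults:

    import torch

    train_model(
        model,
        train_data,
        optimizer=lambda m: torch.optim.AdamW(m.parameters(), lr=1e-3),
    )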
10 changes: 2 additions & 8 deletions src/graph_pes/transform.py
@@ -208,9 +208,6 @@ def fit(self, x: LocalProperty | GlobalProperty, graphs: AtomicGraphBatch):
             The atomic graphs that x originates from.
         """
         # reset the shift
-        self.shift = PerSpeciesParameter.of_dim(
-            1, requires_grad=self.trainable, generator=0
-        )
         zs = torch.unique(graphs.Z)
 
         if graphs.is_local_property(x):
@@ -327,7 +324,7 @@ class PerAtomScale(Transform):
     scale is fixed.
     """
 
-    def __init__(self, trainable: bool = True, act_on_norms: bool = True):
+    def __init__(self, trainable: bool = True, act_on_norms: bool = False):
         super().__init__(trainable=trainable)
         self.scales = PerSpeciesParameter.of_dim(
             dim=1, requires_grad=trainable, generator=1
@@ -356,9 +353,6 @@ def fit(self, x: LocalProperty | GlobalProperty, graphs: AtomicGraphBatch):
             The atomic graphs that x originates from.
         """
         # reset the scale
-        self.scales = PerSpeciesParameter.of_dim(
-            1, requires_grad=self.trainable, generator=1
-        )
         zs = torch.unique(graphs.Z)
 
         if self.act_on_norms:
@@ -466,5 +460,5 @@ def inverse(self, x: Tensor, graph: AtomicGraph) -> Tensor:
 
     @torch.no_grad()
     def fit(self, x: Tensor, graphs: AtomicGraphBatch) -> Transform:
-        self.scale = nn.Parameter(x.var(), requires_grad=self.trainable)
+        self.scale.data = x.var()
         return self
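Note: the final change updates the existing parameter's storage in place rather than rebinding the attribute, so anything already holding a reference to self.scale (e.g. an optimizer built earlier) keeps tracking the fitted value. A standalone PyTorch illustration of the difference, not code from this repository:

    import torch
    from torch import nn

    p = nn.Parameter(torch.zeros(1))
    opt = torch.optim.Adam([p])

    p.data = torch.ones(1)  # in place: opt still optimises this parameter
    # p = nn.Parameter(torch.ones(1))  # rebinding: opt would keep the stale tensor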
