Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
skip-checks: true
  • Loading branch information
jla-gardner committed Jul 11, 2024
1 parent 9792d46 commit 48473fe
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 6 deletions.
20 changes: 20 additions & 0 deletions cgap17.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import ase
from graph_pes.data.dataset import AseDataset
from graph_pes.data.module import FixedDatasets
from graph_pes.data.utils import random_split
from load_atoms import load_dataset


def load_data(
    batch_size: int = 32,
    n_train: int = 1_000,
    cutoff: float = 3.7,
) -> FixedDatasets:
    """Load the C-GAP-17 dataset and split it into fixed train/val/test sets.

    Parameters
    ----------
    batch_size
        Batch size passed through to :class:`FixedDatasets`.
    n_train
        Number of structures in the training split.
    cutoff
        Neighbour-list cutoff radius applied to every split.
        Defaults to 3.7, the value previously hard-coded in three places.

    Returns
    -------
    FixedDatasets
        Container holding the train and validation datasets, plus a
        ``{"test": ...}`` mapping for the 10-structure test split.
    """
    # keep only structures with more than 2 atoms
    # (len(x) on an ase.Atoms is the atom count)
    dataset: list[ase.Atoms] = load_dataset("C-GAP-17").filter_by(
        lambda x: len(x) > 2
    )  # type: ignore

    # fixed seed so the split is reproducible across runs;
    # val and test are deliberately tiny (10 structures each)
    train, val, test = random_split(dataset, [n_train, 10, 10], seed=42)

    def _as_dataset(structures: list[ase.Atoms]) -> AseDataset:
        # single place to configure dataset construction; pre_transform=True
        # presumably precomputes graphs up front — confirm against AseDataset docs
        return AseDataset(structures, cutoff=cutoff, pre_transform=True)

    return FixedDatasets(
        _as_dataset(train),
        _as_dataset(val),
        {"test": _as_dataset(test)},
        batch_size=batch_size,
        num_workers=4,
    )
40 changes: 40 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Training configuration for the C-GAP-17 experiment (graph-pes).
#
# NOTE(review): the YAML indentation appears to have been flattened when this
# file was extracted — the nesting of the mappings below must be restored
# before this will parse as intended; verify against the original file.

# The model is built from the listed components; alternatives are kept as
# commented-out options.
model:
- graph_pes.models.LearnableOffset()
- graph_pes.models.LennardJones()
# - graph_pes.models.SchNet()
# - graph_pes.models.e3nn.nequip.NequIP:
# n_elements: 1

# Data comes from cgap17.load_data (see cgap17.py in this commit).
# n_train: 10 is far below load_data's default of 1_000 — presumably a
# smoke-test setting; confirm before a real run.
data:
cgap17.load_data:
batch_size: 32
n_train: 10

# Loss is a sum of a per-atom energy term and an RMSE force term.
loss:
- component: graph_pes.training.loss.PerAtomEnergyLoss()
- component:
graph_pes.training.loss.Loss:
property_key: forces
metric: graph_pes.training.loss.RMSE()

# TODO: ladder fit etc.
fitting:
pre_fit_model: True

optimizer:
graph_pes.training.opt.Optimizer:
name: AdamW
lr: 0.001
weight_decay: 0.0

# ReduceLROnPlateau: shrink lr by `factor` after `patience` epochs without
# improvement (see torch.optim.lr_scheduler docs).
scheduler:
graph_pes.training.opt.LRScheduler:
name: ReduceLROnPlateau
factor: 0.8
patience: 10

# max_epochs: 1 and cpu accelerator — smoke-test settings, not production.
trainer_kwargs:
max_epochs: 1
accelerator: cpu

# TODO: lightning scheduler configuration
12 changes: 6 additions & 6 deletions src/graph_pes/models/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def predict_unscaled_energies(self, graph: AtomicGraph) -> torch.Tensor:
The unscaled, local energies with shape ``(n_atoms,)``.
"""

@torch.no_grad()
def pre_fit(
self,
graphs: LabelledGraphDataset | Sequence[LabelledGraph] | LabelledBatch,
Expand All @@ -77,12 +78,11 @@ def pre_fit(
# use Ridge regression to calculate standard deviations in the
# per-element contributions to the total energy
if "energy" in graph_batch:
with torch.no_grad():
_, variances = guess_per_element_mean_and_var(
graph_batch["energy"], graph_batch
)
for Z, var in variances.items():
self._per_element_scaling[Z] = var**0.5
_, variances = guess_per_element_mean_and_var(
graph_batch["energy"], graph_batch
)
for Z, var in variances.items():
self._per_element_scaling[Z] = var**0.5

else:
model_name = self.__class__.__name__
Expand Down

0 comments on commit 48473fe

Please sign in to comment.