Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Jul 11, 2024
1 parent 9865eb8 commit c81a29c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
10 changes: 8 additions & 2 deletions src/graph_pes/models/pre_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from graph_pes.graphs.graph_typing import AtomicGraphBatch
from graph_pes.graphs.operations import number_of_structures, sum_per_structure

MIN_VARIANCE = 0.01


def guess_per_element_mean_and_var(
per_structure_quantity: torch.Tensor,
Expand Down Expand Up @@ -46,7 +48,11 @@ def guess_per_element_mean_and_var(
# variances for each atom, we can estimate these variances again
# using Ridge regression
ridge.fit(N.numpy(), residuals**2)
var_Z = ridge.coef_.clip(min=0.01) # avoid negative variances
variances = {int(Z): float(var) for Z, var in zip(unique_Zs, var_Z)}
var_Z = ridge.coef_
# avoid negative variances by clipping to min value
variances = {
int(Z): max(float(var), MIN_VARIANCE)
for Z, var in zip(unique_Zs, var_Z)
}

return means, variances
19 changes: 18 additions & 1 deletion tests/test_pre_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import torch
from graph_pes.graphs.graph_typing import LabelledBatch, LabelledGraph
from graph_pes.graphs.operations import number_of_structures, to_batch
from graph_pes.models.pre_fit import guess_per_element_mean_and_var
from graph_pes.models.pre_fit import (
MIN_VARIANCE,
guess_per_element_mean_and_var,
)


def _create_batch(
Expand Down Expand Up @@ -70,3 +73,17 @@ def test_guess_per_element_mean_and_var():
# are variances roughly right?
for Z, actual_sigma in sigma.items():
assert np.isclose(variances[Z], actual_sigma**2, atol=0.01)


def test_clamping():
    """Estimated variances must never drop below ``MIN_VARIANCE``."""
    # variances cannot be negative, so the estimator is expected to clamp
    # them; element 1 is given sigma = 0.0 here
    per_element_mu = {1: -1.0, 2: -2.0}
    per_element_sigma = {1: 0.0, 2: 1.0}
    batch = _create_batch(mu=per_element_mu, sigma=per_element_sigma)

    # only the variances matter for this check; the means are ignored
    _means, variances = guess_per_element_mean_and_var(batch["energy"], batch)

    # every reported variance has to sit at or above the clamp value
    assert all(var >= MIN_VARIANCE for var in variances.values())

0 comments on commit c81a29c

Please sign in to comment.