Skip to content

Commit

Permalink
test_mpts_metrics and dummy_mpts_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
sgbaird committed Jul 7, 2022
1 parent 075602e commit ce78161
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions tests/test_matbench_genmetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import numpy.typing as npt
import pytest
from mp_time_split.utils.gen import DummyGenerator
from numpy.testing import assert_array_equal
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
Expand All @@ -12,7 +13,6 @@

# from pytest_cases import fixture, parametrize, parametrize_with_cases


coords = [[0, 0, 0], [0.75, 0.5, 0.75]]
lattice = Lattice.from_parameters(a=3.84, b=3.84, c=3.84, alpha=120, beta=90, gamma=60)
dummy_structures = [
Expand All @@ -38,7 +38,13 @@ def dummy_gen_metrics():
# @fixture
def dummy_mpts_metrics():
"""Get MPTSMetrics() with dummy MPTS as train/test, dummy_structures as pred/gen."""
return MPTSMetrics(dummy_structures, dummy_structures, dummy=True)
mptm = MPTSMetrics(dummy=True)

fold = 0
mptm.get_train_and_val_data(fold)
mptm.record(fold, dummy_structures)

return mptm


dummy_matcher_expected = {
Expand Down Expand Up @@ -102,6 +108,21 @@ def test_numerical_attributes(fixture: object, checkitem: Tuple[str, npt.ArrayLi
)


def test_mpts_metrics():
mptm = MPTSMetrics(dummy=True)
for fold in mptm.folds:
train_val_inputs = mptm.get_train_and_val_data(fold)

np.random.seed(10)
dg = DummyGenerator()
dg.fit(train_val_inputs)
gen_structures = dg.gen(n=100)

mptm.record(fold, gen_structures)

print(mptm.recorded_metrics)


# %% Code Graveyard
# def flatten_params(
# fixtures: List[Callable], expecteds: List[dict]
Expand Down

0 comments on commit ce78161

Please sign in to comment.