diff --git a/tests/test_matbench_genmetrics.py b/tests/test_matbench_genmetrics.py index 4acdaa0..230e6d4 100644 --- a/tests/test_matbench_genmetrics.py +++ b/tests/test_matbench_genmetrics.py @@ -4,6 +4,7 @@ import numpy as np import numpy.typing as npt import pytest +from mp_time_split.utils.gen import DummyGenerator from numpy.testing import assert_array_equal from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure @@ -12,7 +13,6 @@ # from pytest_cases import fixture, parametrize, parametrize_with_cases - coords = [[0, 0, 0], [0.75, 0.5, 0.75]] lattice = Lattice.from_parameters(a=3.84, b=3.84, c=3.84, alpha=120, beta=90, gamma=60) dummy_structures = [ @@ -38,7 +38,13 @@ def dummy_gen_metrics(): # @fixture def dummy_mpts_metrics(): """Get MPTSMetrics() with dummy MPTS as train/test, dummy_structures as pred/gen.""" - return MPTSMetrics(dummy_structures, dummy_structures, dummy=True) + mptm = MPTSMetrics(dummy=True) + + fold = 0 + mptm.get_train_and_val_data(fold) + mptm.record(fold, dummy_structures) + + return mptm dummy_matcher_expected = { @@ -102,6 +108,21 @@ def test_numerical_attributes(fixture: object, checkitem: Tuple[str, npt.ArrayLi ) +def test_mpts_metrics(): + mptm = MPTSMetrics(dummy=True) + for fold in mptm.folds: + train_val_inputs = mptm.get_train_and_val_data(fold) + + np.random.seed(10) + dg = DummyGenerator() + dg.fit(train_val_inputs) + gen_structures = dg.gen(n=100) + + mptm.record(fold, gen_structures) + + print(mptm.recorded_metrics) + + # %% Code Graveyard # def flatten_params( # fixtures: List[Callable], expecteds: List[dict]