Skip to content

Commit

Permalink
Merge pull request #8 from sparks-baird/genmetrics-class
Browse files Browse the repository at this point in the history
Top-level class for interacting with metrics and test functions
  • Loading branch information
sgbaird authored Jul 7, 2022
2 parents 481932c + ce78161 commit 5f26ff3
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 14 deletions.
44 changes: 42 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,49 @@

> Generative materials benchmarking metrics, inspired by CDVAE.
A longer description of your project goes here...
This repository provides standardized benchmarks for benchmarking generative models for
crystal structure. Each benchmark has a fixed dataset, a predefined split, and a notion
of best (i.e. metric) associated with it.

## Installation
## Getting Started

### Installation

Create a conda environment with the `matbench-genmetrics` package installed from the
`conda-forge` channel. Then activate the environment.

```bash
conda create --name matbench-genmetrics --channel conda-forge python==3.9.* matbench-genmetrics
conda activate matbench-genmetrics
```

> NOTE: It doesn't have to be Python 3.9; you can remove `python==3.9.*` altogether or
change this to e.g. `python==3.8.*`. See [Advanced Installation](##Advanced-Installation)

### Basic Usage

```python
from mp_time_split.utils.gen import DummyGenerator
from matbench_genmetrics.core import MPTSMetrics

mptm = MPTSMetrics(dummy=False)
for fold in mptm.folds:
train_val_inputs = mptm.get_train_and_val_data(fold)

dg = DummyGenerator()
dg.fit(train_val_inputs)
gen_structures = dg.gen(n=10000)

mptm.record(fold, gen_structures)

print(mptm.recorded_metrics)
```

> ```python
>
> ```
## Advanced Installation
In order to set up the necessary environment:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ install_requires =
hydra-core
torch
torch-geometric
mp-time-split
mp-time-split[pyxtal]


[options.packages.find]
Expand Down
38 changes: 29 additions & 9 deletions src/matbench_genmetrics/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
train_structures,
test_structures,
gen_structures,
test_pred_structures,
test_pred_structures=None,
):
self.train_structures = train_structures
self.test_structures = test_structures
Expand Down Expand Up @@ -186,17 +186,37 @@ def metrics(self):


class MPTSMetrics(GenMetrics):
def __init__(self, gen_structures, test_pred_structures, dummy=False):
mpt = MPTimeSplit(target="energy_above_hull")
mpt.load(dummy=dummy)
train_inputs, val_inputs, _, _ = mpt.get_train_and_val_data(0)
super().__init__(
train_inputs.tolist(),
val_inputs.tolist(),
def __init__(self, dummy=False):
self.dummy = dummy
self.mpt = MPTimeSplit(target="energy_above_hull")
self.folds = self.mpt.folds
self.recorded_metrics = [None] * len(self.folds)

def get_train_and_val_data(self, fold, include_val=False):
self.mpt.load(dummy=self.dummy)
(
self.train_inputs,
self.val_inputs,
self.train_outputs,
self.val_outputs,
) = self.mpt.get_train_and_val_data(fold)

if include_val:
return self.train_inputs, self.val_inputs

return self.train_inputs

def record(self, fold, gen_structures, test_pred_structures=None):
GenMetrics.__init__(
self,
self.train_inputs.tolist(),
self.val_inputs.tolist(),
gen_structures,
test_pred_structures,
test_pred_structures=test_pred_structures,
)

self.recorded_metrics[fold] = self.metrics


# def get_rms_dist(gen_structures, test_structures):
# rms_dist = np.zeros((len(gen_structures), len(test_structures)))
Expand Down
25 changes: 23 additions & 2 deletions tests/test_matbench_genmetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import numpy.typing as npt
import pytest
from mp_time_split.utils.gen import DummyGenerator
from numpy.testing import assert_array_equal
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
Expand All @@ -12,7 +13,6 @@

# from pytest_cases import fixture, parametrize, parametrize_with_cases


coords = [[0, 0, 0], [0.75, 0.5, 0.75]]
lattice = Lattice.from_parameters(a=3.84, b=3.84, c=3.84, alpha=120, beta=90, gamma=60)
dummy_structures = [
Expand All @@ -38,7 +38,13 @@ def dummy_gen_metrics():
# @fixture
def dummy_mpts_metrics():
"""Get MPTSMetrics() with dummy MPTS as train/test, dummy_structures as pred/gen."""
return MPTSMetrics(dummy_structures, dummy_structures, dummy=True)
mptm = MPTSMetrics(dummy=True)

fold = 0
mptm.get_train_and_val_data(fold)
mptm.record(fold, dummy_structures)

return mptm


dummy_matcher_expected = {
Expand Down Expand Up @@ -102,6 +108,21 @@ def test_numerical_attributes(fixture: object, checkitem: Tuple[str, npt.ArrayLi
)


def test_mpts_metrics():
mptm = MPTSMetrics(dummy=True)
for fold in mptm.folds:
train_val_inputs = mptm.get_train_and_val_data(fold)

np.random.seed(10)
dg = DummyGenerator()
dg.fit(train_val_inputs)
gen_structures = dg.gen(n=100)

mptm.record(fold, gen_structures)

print(mptm.recorded_metrics)


# %% Code Graveyard
# def flatten_params(
# fixtures: List[Callable], expecteds: List[dict]
Expand Down

0 comments on commit 5f26ff3

Please sign in to comment.