
Commit

unify filtering
AlexandraVolokhova committed Jul 13, 2024
1 parent 4adc1a5 commit 361abca
Showing 2 changed files with 92 additions and 78 deletions.
49 changes: 32 additions & 17 deletions gflownet/envs/crystals/crystal.py
@@ -8,16 +8,16 @@

from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from torchtyping import TensorType
from tqdm import tqdm

from gflownet.envs.crystals.composition import Composition
from gflownet.envs.crystals.lattice_parameters import PARAMETER_NAMES, LatticeParameters
from gflownet.envs.crystals.spacegroup import SpaceGroup
from gflownet.envs.stack import Stack
from gflownet.utils.crystals.constants import TRICLINIC
from torchtyping import TensorType
from tqdm import tqdm


class Crystal(Stack):
@@ -181,22 +181,37 @@ def process_data_set(self, df: pd.DataFrame, progress=False) -> List[List]:
        for row in tqdm(df.iterrows(), total=len(df), disable=not progress):
            # Index 0 is the row index; index 1 is the remaining columns
            row = row[1]
            state = {}
            state[self.stage_composition] = self.subenvs[
                self.stage_composition
            ].readable2state(row["Formulae"])
            state[self.stage_spacegroup] = self.subenvs[
                self.stage_spacegroup
            ]._set_constrained_properties([0, 0, row["Space Group"]])
            state[self.stage_latticeparameters] = self.subenvs[
                self.stage_latticeparameters
            ].parameters2state(tuple(row[list(PARAMETER_NAMES)]))
            is_valid_subenvs = [
                subenv.is_valid(state[stage]) for stage, subenv in self.subenvs.items()
            ]
            if all(is_valid_subenvs):
            if self._is_valid_datarow(row):
                # TODO: Consider making stack state a dict which would avoid having to
                # do this, among other advantages
                state = self._state_from_datarow(row)
                state_stack = [2] + [state[stage] for stage in self.subenvs]
                data_valid.append(state_stack)
        return data_valid

    def _state_from_datarow(self, row):
        state = {}
        state[self.stage_composition] = self.subenvs[
            self.stage_composition
        ].readable2state(row["Formulae"])
        state[self.stage_spacegroup] = self.subenvs[
            self.stage_spacegroup
        ]._set_constrained_properties([0, 0, row["Space Group"]])
        state[self.stage_latticeparameters] = self.subenvs[
            self.stage_latticeparameters
        ].parameters2state(tuple(row[list(PARAMETER_NAMES)]))
        return state

    def _is_valid_datarow(self, row):
        state = self._state_from_datarow(row)
        is_valid_subenvs = [
            subenv.is_valid(state[stage]) for stage, subenv in self.subenvs.items()
        ]
        return all(is_valid_subenvs)

    def filter_dataset(self, df: pd.DataFrame) -> pd.DataFrame:
        is_valid = []
        for row in df.iterrows():
            row = row[1]
            is_valid.append(self._is_valid_datarow(row))
        return df[np.array(is_valid)]
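
For orientation, here is a minimal sketch of how the new filter_dataset helper might be called. Everything except env.filter_dataset itself is an assumption for illustration: the config file name, the way the environment is built (mirroring the plotting script further down), and the toy column values. The code above only requires a "Formulae" column, a "Space Group" column and the lattice-parameter columns listed in PARAMETER_NAMES.

import pandas as pd
import yaml
from omegaconf import OmegaConf

from gflownet.envs.crystals.crystal import Crystal
from gflownet.envs.crystals.lattice_parameters import PARAMETER_NAMES

# Build the environment from an experiment config, as the plotting script below does.
config = OmegaConf.create(yaml.safe_load(open("config/experiments/crystals/starling_fe.yaml", "r")))
env = Crystal(config.env)

# Toy data frame; the values are placeholders, not real crystals.
df = pd.DataFrame(
    {
        "Formulae": ["Ag2Te1", "Si1O2"],
        "Space Group": [225, 154],
        **{name: [5.0, 4.9] for name in PARAMETER_NAMES},
    }
)
valid_df = env.filter_dataset(df)  # keeps only the rows that every sub-environment accepts
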
121 changes: 60 additions & 61 deletions scripts/crystal/plots_iclm24.py
@@ -1,6 +1,7 @@
import argparse
import datetime
import pickle
import re
import sys
import warnings
from pathlib import Path
@@ -12,9 +13,11 @@
import seaborn as sns
import torch
import yaml
from gflownet.envs.crystals.crystal import Crystal
from gflownet.utils.common import load_gflow_net_from_run_path
from gflownet.utils.crystals.constants import ELEMENT_NAMES
from mendeleev.fetch import fetch_table
from omegaconf import OmegaConf
from tqdm import tqdm

warnings.filterwarnings("ignore")
@@ -45,31 +48,62 @@
sys.path.append(str(ROOT))


def load_mb_eform(energy_key="energy"):
    """Load the Materials Project eform dataset and return train and val dataframes."""
    train_df_path = "/network/scratch/s/schmidtv/crystals-proxys/data/materials_dataset_v3/data/matbench_mp_e_form/train_data.csv"
    val_df_path = "/network/scratch/s/schmidtv/crystals-proxys/data/materials_dataset_v3/data/matbench_mp_e_form/val_data.csv"
    tdf = pd.read_csv(train_df_path)
    vdf = pd.read_csv(val_df_path)
    tdf[energy_key] = tdf["Eform"]
    vdf[energy_key] = vdf["Eform"]
    tdf = tdf.drop(columns=["Eform", "cif"])
    vdf = vdf.drop(columns=["Eform", "cif"])
    return tdf, vdf
def parse_formula(x):
    element_pattern = r"([A-Z][a-z]?)(\d*)"
    matches = re.findall(element_pattern, x["Formulae"])
    for element, count in matches:
        x[element] = int(count)
    return x
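
A quick illustration of what parse_formula extracts (the formula string is hypothetical; the remark about bare element symbols is an observation about int(""), not something this commit asserts):

import re

element_pattern = r"([A-Z][a-z]?)(\d*)"
print(re.findall(element_pattern, "Ag2Te1"))  # [('Ag', '2'), ('Te', '1')]
# A formula written without explicit counts, e.g. "AgTe", would yield empty
# count strings and int("") would raise, so the "Formulae" column is assumed
# to always spell out the counts.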


def load_mb_bandgap(energy_key="energy"):
    """
    Load the Materials Project bandgap dataset and return train and val dataframes.
    """
    train_df_path = "/network/scratch/s/schmidtv/crystals-proxys/data/materials_dataset_v3/data/matbench_mp_gap/train_data.csv"
    val_df_path = "/network/scratch/s/schmidtv/crystals-proxys/data/materials_dataset_v3/data/matbench_mp_gap/val_data.csv"
    tdf = pd.read_csv(train_df_path)
    vdf = pd.read_csv(val_df_path)
    tdf[energy_key] = tdf["Band Gap"]
    vdf[energy_key] = vdf["Band Gap"]
    tdf = tdf.drop(columns=["Band Gap", "cif"])
    vdf = vdf.drop(columns=["Band Gap", "cif"])
def add_elements_columns(df):
    for el in ELEMENT_NAMES.values():
        if el not in df.columns:
            df[el] = 0
    df = df.apply(lambda x: parse_formula(x), axis=1)
    return df
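
And a sketch of add_elements_columns on a toy frame (hypothetical values; assumes the module's imports, in particular that ELEMENT_NAMES maps atomic numbers to element symbols):

import pandas as pd

df = pd.DataFrame({"Formulae": ["Ag2Te1", "Si1O2"], "Space Group": [225, 154]})
df = add_elements_columns(df)
# Every symbol in ELEMENT_NAMES is now a column, 0 by default, with the parsed
# counts filled in per row, e.g. df.loc[0, "Ag"] == 2 and df.loc[1, "O"] == 2.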


def load_mb_data(env, target, energy_key="energy"):
    paths = {
        "eform": {
            "train": "/network/projects/crystalgfn/data/eform/train.csv",
            "val": "/network/projects/crystalgfn/data/eform/val.csv",
        },
        "bandgap": {
            "train": "/network/projects/crystalgfn/data/bandgap/train.csv",
            "val": "/network/projects/crystalgfn/data/bandgap/val.csv",
        },
        "density": {
            # TODO: incorrect "energies" here, need to change it once we have
            # datasets with computed densities
            "train": "/network/projects/crystalgfn/data/eform/train.csv",
            "val": "/network/projects/crystalgfn/data/eform/val.csv",
        },
    }
    names = {
        "eform": "Eform",
        "bandgap": "Band Gap",
        "density": "Eform",
    }

    tdf = pd.read_csv(paths[target]["train"])
    vdf = pd.read_csv(paths[target]["val"])
    print("Initial full data sets:")
    print(f"Train: {tdf.shape}")
    print(f"Val: {vdf.shape}")
    tdf = add_elements_columns(tdf)
    vdf = add_elements_columns(vdf)
    tdf[energy_key] = tdf[names[target]]
    vdf[energy_key] = vdf[names[target]]
    tdf = env.filter_dataset(tdf)
    vdf = env.filter_dataset(vdf)
    tdf = tdf.drop(columns=[names[target], "Formulae"])
    vdf = vdf.drop(columns=[names[target], "Formulae"])
    print("Filtered data sets:")
    print(f"Train: {tdf.shape}")
    print(f"Val: {vdf.shape}")
    print()
    return tdf, vdf
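
Put together, a hedged sketch of calling the new loader (the config path is the formation-energy config referenced further down in the script; the resulting columns in the comment are inferred from the code above rather than verified output):

import yaml
from omegaconf import OmegaConf
from gflownet.envs.crystals.crystal import Crystal

config = OmegaConf.create(yaml.safe_load(open("config/experiments/crystals/starling_fe.yaml", "r")))
env = Crystal(config.env)
tdf, vdf = load_mb_data(env, target="eform")
# tdf / vdf now carry one column per element, "Space Group", the lattice
# parameters and the "energy" column, restricted to rows that pass
# env.filter_dataset, with "Eform" and "Formulae" dropped.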


@@ -1013,24 +1047,18 @@ def make_plots(
print(f"Saving plots to {output_path}")

if args.target == "eform":
ftdf, fvdf = load_mb_eform(energy_key=args.energy_key)
config_path = ROOT / "config/experiments/crystals/starling_fe.yaml"
elif args.target == "bandgap":
ftdf, fvdf = load_mb_bandgap(energy_key=args.energy_key)
config_path = ROOT / "config/experiments/crystals/starling_bg.yaml"
elif args.target == "density":
# TODO: incorrect "energies" here, need to change it once we have
# datasets with computed densities
ftdf, fvdf = load_mb_eform(energy_key=args.energy_key)
config_path = ROOT / "config/experiments/crystals/starling_density.yaml"
else:
raise ValueError("Unknown target")

print("Initial full data sets:")
print(f"Train: {ftdf.shape}")
print(f"Val: {fvdf.shape}")
config = OmegaConf.create(yaml.safe_load(open(config_path, "r")))

config = yaml.safe_load(open(config_path, "r"))
env = Crystal(config.env)
tdf, vdf = load_mb_data(env, args.target, energy_key=args.energy_key)

# List atomic numbers of the utilised elements
elements_anums = config["env"]["composition_kwargs"]["elements"]
@@ -1056,35 +1084,6 @@
assert len(sg_subset) > 0
print(f"Using {len(sg_subset)} SGs: ", ", ".join(map(str, sg_subset)))

comp_cols = [c for c in ftdf.columns if c in set(ELS_TABLE)]

tdf = filter_df(
    ftdf,
    elements_names,
    comp_cols,
    sg_subset,
    min_length=args.min_length,
    max_length=args.max_length,
    min_angle=args.min_angle,
    max_angle=args.max_angle,
)

vdf = filter_df(
    fvdf,
    elements_names,
    comp_cols,
    sg_subset,
    min_length=args.min_length,
    max_length=args.max_length,
    min_angle=args.min_angle,
    max_angle=args.max_angle,
)

print("Filtered data sets:")
print(f"Train: {tdf.shape}")
print(f"Val: {vdf.shape}")
print()

make_plots(
    train_df=tdf,
    val_df=vdf,