Skip to content

Commit

Permalink
black, isort
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandraVolokhova committed Jun 27, 2024
1 parent 565d1bf commit 2d6fba5
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 47 deletions.
3 changes: 2 additions & 1 deletion scripts/crystal/eval_crystalgflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@

sys.path.append(str(Path(__file__).resolve().parent.parent))

from crystalrandom import generate_random_crystals
from gflownet.gflownet import GFlowNetAgent
from gflownet.utils.common import load_gflow_net_from_run_path
from gflownet.utils.policy import parse_policy_config

from crystalrandom import generate_random_crystals


def add_args(parser):
"""
Expand Down
56 changes: 28 additions & 28 deletions scripts/crystal/eval_gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
from argparse import ArgumentParser
from pathlib import Path

import pandas as pd
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

sys.path.append(str(Path(__file__).resolve().parent.parent.parent))

from crystalrandom import generate_random_crystals_uniform
from hydra.utils import instantiate

from gflownet.gflownet import GFlowNetAgent
from gflownet.utils.common import load_gflow_net_from_run_path, read_hydra_config
from gflownet.utils.policy import parse_policy_config
from hydra.utils import instantiate

from crystalrandom import generate_random_crystals_uniform


def add_args(parser):
Expand Down Expand Up @@ -229,30 +229,30 @@ def main(args):
env.proxy.is_bandgap = False

# Test
# samples = [env.readable2state(readable) for readable in gflownet.buffer.test["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.test["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"val.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"val.pkl", "wb"))
#
# # Train
# samples = [env.readable2state(readable) for readable in gflownet.buffer.train["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.train["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"train.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"train.pkl", "wb"))
# samples = [env.readable2state(readable) for readable in gflownet.buffer.test["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.test["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"val.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"val.pkl", "wb"))
#
# # Train
# samples = [env.readable2state(readable) for readable in gflownet.buffer.train["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.train["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"train.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"train.pkl", "wb"))

if args.n_samples > 0 and args.n_samples <= 1e5 and not args.random_only:
print(
Expand Down
2 changes: 1 addition & 1 deletion scripts/crystal/plots_conditional_icml24.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,4 @@ def load_energies_only(pkl_path, energy_key="energy"):
cdfs = {k: load_energies_only(pkl_path=v) for k, v in cond_paths.items()}
dfs.update(cdfs)
plot_gfn_energies_violins(dfs, output_path, target=args.target)
plt.close("all")
plt.close("all")
27 changes: 10 additions & 17 deletions scripts/crystal/plots_iclm24.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
import seaborn as sns
import torch
import yaml
from mendeleev.fetch import fetch_table
from tqdm import tqdm

from gflownet.utils.common import load_gflow_net_from_run_path
from gflownet.utils.crystals.constants import ELEMENT_NAMES
from mendeleev.fetch import fetch_table
from tqdm import tqdm

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -634,7 +633,9 @@ def sort_names_for_z(element_names):
return sorted(element_names, key=lambda x: ELS_TABLE.tolist().index(x))


def pkl_samples_to_df(samples, elements_names, sg_key="Space Group", energy_key="energy"):
def pkl_samples_to_df(
samples, elements_names, sg_key="Space Group", energy_key="energy"
):
"""
Convert samples from a pickled file to a DataFrame.
Expand Down Expand Up @@ -687,12 +688,11 @@ def pkl_samples_to_df(samples, elements_names, sg_key="Space Group", energy_key=
)
df = df[cols]
# set zeros for elements that are not present in the samples
df[elements_names] = df[elements_names].fillna(0)
df[elements_names] = df[elements_names].fillna(0)
return df


def load_gfn_samples(
element_names, pkl_path):
def load_gfn_samples(element_names, pkl_path):
"""
Load samples from pickled data and convert them to a DataFrame.
Expand Down Expand Up @@ -978,26 +978,20 @@ def make_plots(
config = yaml.safe_load(open(config_path, "r"))

# List atomic numbers of the utilised elements
elements_anums = config['env']['composition_kwargs']['elements']
elements_anums = config["env"]["composition_kwargs"]["elements"]
elements_names = [ELEMENT_NAMES[anum] for anum in elements_anums]

sdf = load_gfn_samples(
elements_names,
pkl_path=args.pkl_path
)
sdf = load_gfn_samples(elements_names, pkl_path=args.pkl_path)
print("Loaded gfn samples: ", sdf.shape)



if args.uniform_pkl_path:
udf = load_uniform_samples(
elements_names, pkl_path=args.uniform_pkl_path, config=config
)
print("Loaded uniform samples: ", udf.shape)
if args.random_pkl_path or args.random_gfn_path:
rdf = load_gfn_samples( # random init
elements_names,
pkl_path=args.random_pkl_path
elements_names, pkl_path=args.random_pkl_path
)
print("Loaded random samples: ", rdf.shape)

Expand All @@ -1007,7 +1001,6 @@ def make_plots(
assert len(sg_subset) > 0
print(f"Using {len(sg_subset)} SGs: ", ", ".join(map(str, sg_subset)))


comp_cols = [c for c in ftdf.columns if c in set(ELS_TABLE)]

tdf = filter_df(
Expand Down

0 comments on commit 2d6fba5

Please sign in to comment.