Skip to content

Commit

Permalink
clean up plotting, modified uniform sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandraVolokhova committed Jun 27, 2024
1 parent 5281051 commit 565d1bf
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 35 deletions.
6 changes: 3 additions & 3 deletions scripts/crystal/crystalrandom.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def generate_random_crystals_uniform(
# Atoms per element
done = False
while not done:
composition = [0] * len(elements)
composition = dict()
for el in elements_selected:
n_atoms_el = np.random.randint(low=1, high=max_atom_i + 1)
composition[elements.index(el)] = n_atoms_el
if sum(composition) <= max_atoms:
composition[el] = n_atoms_el
if sum(composition.values()) <= max_atoms:
done = True

# Space group
Expand Down
28 changes: 0 additions & 28 deletions scripts/crystal/plots_iclm24.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,6 @@ def now_to_str():
return now.strftime("%Y-%m-%d/%H-%M-%S")


# def get_top_els(df, comp_cols, n=10):
# """Get the top n elements in the dataset."""
# sums = df[comp_cols].sum(axis=0)
# sums = sums.sort_values(ascending=False, inplace=False)
# if n is None or n < 0:
# return sums.index.tolist()
# return sums.index[:n].tolist()


def plot_sg_dist(
tdf, vdf, sdf, udf=None, sg_key="Space Group", output_path=None, target=None
):
Expand Down Expand Up @@ -925,18 +916,9 @@ def make_plots(

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--gfn_path",
type=str,
default=None,
help="Path to a gfn checkpoint date folder (.../$SLURM_JOB_ID/$DATE_$TIME)",
)
parser.add_argument(
"--pkl_path", type=str, default=None, help="gflownet samples path"
)
parser.add_argument(
"--random_gfn_path", type=str, default=None, help="random (init only) gfn path (checkpoint)"
)
parser.add_argument(
"--random_pkl_path",
type=str,
Expand All @@ -946,12 +928,10 @@ def make_plots(
parser.add_argument(
"--uniform_pkl_path", type=str, default=None, help="uniform samples path"
)
parser.add_argument("--n_samples", type=int, default=1e3)
# target: either eform or bandgap:
parser.add_argument(
"--target", type=str, default="eform", choices=["eform", "bandgap", "density"]
)
parser.add_argument("--batch_size", type=int, default=10)
parser.add_argument("--output_path", type=str, default=None)
parser.add_argument("--sg_key", type=str, default="Space Group")
parser.add_argument("--energy_key", type=str, default="energy")
Expand All @@ -968,17 +948,9 @@ def make_plots(

USE_SUPTITLES = not args.no_suptitles

gfn_path = (
args.gfn_path
or "/network/scratch/s/schmidtv/crystals/logs/icml24/crystalgfn/4085128/2024-01-30_04-38-24"
)

print("Arguments:")
print("\n".join(f"{k:15}: {v}" for k, v in vars(args).items()))

if args.pkl_path and args.gfn_path:
raise ValueError("Only one of pkl_path and gfn_path can be given.")

now = now_to_str()
output_path = ROOT / "external" / "plots" / "icml24" / now
output_path.mkdir(parents=True, exist_ok=True)
Expand Down
6 changes: 2 additions & 4 deletions scripts/crystal/sample_uniform_with_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

import hydra
import pandas as pd
from gflownet.utils.common import chdir_random_subdir
from gflownet.utils.policy import parse_policy_config

from crystalrandom import generate_random_crystals_uniform

Expand Down Expand Up @@ -46,8 +44,8 @@ def main(config):
)
env.reset()

energies = env.proxy(env.states2proxy(x_sampled))
rewards = env.proxy2reward(energies)
energies = proxy(env.states2proxy(x_sampled))
rewards = proxy.proxy2reward(energies)
readable = [env.state2readable(x) for x in x_sampled]
result = pd.DataFrame(
{"readable": readable, "rewards": rewards, "energies": energies, "x": x_sampled}
Expand Down

0 comments on commit 565d1bf

Please sign in to comment.