Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
engmubarak48 committed Jun 21, 2024
1 parent 4ecd865 commit d50aeea
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 25 deletions.
2 changes: 1 addition & 1 deletion gflownet/policy/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def make_mlp(self, activation):
def parse_config(self, config):
if config is None:
config = OmegaConf.create()
config.type = "mlp"
config.type = "mlp"
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.n_hid = config.get("n_hid", 128)
Expand Down
1 change: 1 addition & 0 deletions scripts/crystal/eval_crystalgflownet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Computes evaluation metrics and plots from a pre-trained GFlowNet model.
"""

import pickle
import shutil
import sys
Expand Down
48 changes: 24 additions & 24 deletions scripts/crystal/eval_gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,30 +229,30 @@ def main(args):
env.proxy.is_bandgap = False

# Test
# samples = [env.readable2state(readable) for readable in gflownet.buffer.test["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.test["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"val.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"val.pkl", "wb"))
#
# # Train
# samples = [env.readable2state(readable) for readable in gflownet.buffer.train["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.train["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"train.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"train.pkl", "wb"))
# samples = [env.readable2state(readable) for readable in gflownet.buffer.test["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.test["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"val.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"val.pkl", "wb"))
#
# # Train
# samples = [env.readable2state(readable) for readable in gflownet.buffer.train["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.train["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"train.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"train.pkl", "wb"))

if args.n_samples > 0 and args.n_samples <= 1e5 and not args.random_only:
print(
Expand Down
1 change: 1 addition & 0 deletions scripts/crystal/sample_uniform_with_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
should be run with the same config as main.py, e.g.
python sample_uniform_with_rewards.py +experiments=crystals/albatross_sg_first logger.do.online=False user=sasha
"""

import pickle
import sys

Expand Down
1 change: 1 addition & 0 deletions scripts/pyxtal/compatibility_sg_n_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
combinations spanned by the N_SYMMETRY_GROUPS, N_SPECIES and MAX_N_ATOMS. The results
are printed to stdout.
"""

import itertools
import time

Expand Down
1 change: 1 addition & 0 deletions scripts/pyxtal/get_n_compatible_for_sg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
spanned by the --max_n_atoms and --max_n_species. The results are written to a file in
--output_dir.
"""

import itertools
import time
from argparse import ArgumentParser
Expand Down
1 change: 1 addition & 0 deletions scripts/pyxtal/pyxtal_vs_pymatgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
A simple script to determine which space group symbols are different in pyxtal and
pymatgen.
"""

from argparse import ArgumentParser

from pymatgen.symmetry.groups import (
Expand Down

0 comments on commit d50aeea

Please sign in to comment.