diff --git a/config/env/alaninedipeptide.yaml b/config/env/alaninedipeptide.yaml index 17a3877a2..479099481 100644 --- a/config/env/alaninedipeptide.yaml +++ b/config/env/alaninedipeptide.yaml @@ -26,5 +26,3 @@ buffer: test: type: grid n: 1000 - output_csv: alaninedipeptide_test.csv - output_pkl: alaninedipeptide_test.pkl diff --git a/config/env/ccube.yaml b/config/env/ccube.yaml index 714638524..7c7c764fe 100644 --- a/config/env/ccube.yaml +++ b/config/env/ccube.yaml @@ -35,5 +35,3 @@ buffer: test: type: grid n: 900 - output_csv: ccube_test.csv - output_pkl: ccube_test.pkl diff --git a/config/env/crystals/composition.yaml b/config/env/crystals/composition.yaml index acfcfba50..0a4dc5202 100644 --- a/config/env/crystals/composition.yaml +++ b/config/env/crystals/composition.yaml @@ -20,7 +20,5 @@ buffer: data_path: null train: type: all - output_csv: composition_train.csv test: type: all - output_csv: composition_test.csv diff --git a/config/env/crystals/lattice_parameters.yaml b/config/env/crystals/lattice_parameters.yaml index 1c6cfc802..71f7c1057 100644 --- a/config/env/crystals/lattice_parameters.yaml +++ b/config/env/crystals/lattice_parameters.yaml @@ -39,5 +39,3 @@ buffer: test: type: grid n: 900 - output_csv: clp_test.csv - output_pkl: clp_test.pkl diff --git a/config/env/crystals/spacegroup.yaml b/config/env/crystals/spacegroup.yaml index 76da40e9f..5c5ed8ab9 100644 --- a/config/env/crystals/spacegroup.yaml +++ b/config/env/crystals/spacegroup.yaml @@ -15,9 +15,5 @@ buffer: data_path: null train: type: all - output_csv: spacegroup_train.csv - output_pkl: spacegroup_train.pkl test: type: all - output_csv: spacegroup_test.csv - output_pkl: spacegroup_test.pkl diff --git a/config/env/ctorus.yaml b/config/env/ctorus.yaml index fa194956f..dad08727c 100644 --- a/config/env/ctorus.yaml +++ b/config/env/ctorus.yaml @@ -27,5 +27,3 @@ buffer: test: type: grid n: 1000 - output_csv: ctorus_test.csv - output_pkl: ctorus_test.pkl diff --git a/config/env/grid.yaml b/config/env/grid.yaml index 7e2df40fb..0bec812b0 100644 --- a/config/env/grid.yaml +++ b/config/env/grid.yaml @@ -22,5 +22,3 @@ buffer: train: null test: type: all - output_csv: grid_test.csv - output_pkl: grid_test.pkl diff --git a/config/env/htorus.yaml b/config/env/htorus.yaml index c53b95acc..373e47352 100644 --- a/config/env/htorus.yaml +++ b/config/env/htorus.yaml @@ -27,5 +27,3 @@ buffer: test: type: grid n: 1000 - output_csv: htorus_test.csv - output_pkl: htorus_test.pkl diff --git a/config/env/scrabble.yaml b/config/env/scrabble.yaml index 6a2393724..a6454a0d3 100644 --- a/config/env/scrabble.yaml +++ b/config/env/scrabble.yaml @@ -11,5 +11,3 @@ buffer: test: type: uniform n: 10 - output_csv: scrabble_test.csv - output_pkl: scrabble_test.pkl diff --git a/config/env/tetris.yaml b/config/env/tetris.yaml index 67168f626..95ae0503c 100644 --- a/config/env/tetris.yaml +++ b/config/env/tetris.yaml @@ -20,6 +20,3 @@ buffer: test: type: random n: 10 - output_csv: tetris_test.csv - output_pkl: tetris_test.pkl - diff --git a/config/env/torus.yaml b/config/env/torus.yaml index a5959b79a..f05187099 100644 --- a/config/env/torus.yaml +++ b/config/env/torus.yaml @@ -20,8 +20,5 @@ buffer: data_path: null train: type: all - output_csv: torus_train.csv test: type: all - output_csv: torus_test.csv - output_pkl: torus_test.pkl diff --git a/config/env/tree.yaml b/config/env/tree.yaml index af6bec2eb..7ac98b00f 100644 --- a/config/env/tree.yaml +++ b/config/env/tree.yaml @@ -35,5 +35,3 @@ buffer: test: type: random n: 1000 - output_csv: tree_test.csv - output_pkl: tree_test.pkl diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml index 87e44bfb5..94bb4611c 100644 --- a/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml +++ b/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml @@ -24,8 +24,6 @@ shared: test: type: grid n: 1000 - output_csv: ccube_test.csv - output_pkl: ccube_test.pkl # Proxy proxy: corners # GFlowNet config diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml index 93491e3e9..d5e3c92d2 100644 --- a/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml +++ b/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml @@ -24,8 +24,6 @@ shared: test: type: grid n: 1000 - output_csv: ccube_test.csv - output_pkl: ccube_test.pkl # Proxy proxy: corners # GFlowNet config diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml index 7912af9b3..eacfcb762 100644 --- a/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml +++ b/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml @@ -24,8 +24,6 @@ shared: test: type: grid n: 1000 - output_csv: ccube_test.csv - output_pkl: ccube_test.pkl # Proxy proxy: corners # GFlowNet config diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml index cc82e322c..8bf903397 100644 --- a/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml +++ b/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml @@ -24,8 +24,6 @@ shared: test: type: grid n: 1000 - output_csv: ccube_test.csv - output_pkl: ccube_test.pkl # Proxy proxy: corners # GFlowNet config diff --git a/config/experiments/crystals/starling_bg.yaml b/config/experiments/crystals/starling_bg.yaml index cce4816c4..9fa12abfe 100644 --- a/config/experiments/crystals/starling_bg.yaml +++ b/config/experiments/crystals/starling_bg.yaml @@ -55,13 +55,9 @@ env: train: type: csv path: /network/projects/crystalgfn/data/bandgap/train.csv - output_csv: crystal_train.csv - output_pkl: crystal_train.pkl test: type: csv path: /network/projects/crystalgfn/data/bandgap/val.csv - output_csv: crystal_val.csv - output_pkl: crystal_val.pkl # GFlowNet hyperparameters gflownet: diff --git a/config/experiments/crystals/starling_bg_no_constraints.yaml b/config/experiments/crystals/starling_bg_no_constraints.yaml index 39f29672a..c4d57c700 100644 --- a/config/experiments/crystals/starling_bg_no_constraints.yaml +++ b/config/experiments/crystals/starling_bg_no_constraints.yaml @@ -55,13 +55,9 @@ env: train: type: csv path: /network/projects/crystalgfn/data/bandgap/train.csv - output_csv: crystal_train.csv - output_pkl: crystal_train.pkl test: type: csv path: /network/projects/crystalgfn/data/bandgap/val.csv - output_csv: crystal_val.csv - output_pkl: crystal_val.pkl # GFlowNet hyperparameters gflownet: diff --git a/config/experiments/crystals/starling_density.yaml b/config/experiments/crystals/starling_density.yaml index 31b050bd9..8baf0b645 100644 --- a/config/experiments/crystals/starling_density.yaml +++ b/config/experiments/crystals/starling_density.yaml @@ -54,13 +54,9 @@ env: train: type: csv path: /network/projects/crystalgfn/data/eform/train.csv - output_csv: crystal_train.csv - output_pkl: crystal_train.pkl test: type: csv path: /network/projects/crystalgfn/data/eform/val.csv - output_csv: crystal_val.csv - output_pkl: crystal_val.pkl # GFlowNet hyperparameters gflownet: diff --git a/config/experiments/crystals/starling_density_no_constraints.yaml b/config/experiments/crystals/starling_density_no_constraints.yaml index c8876ceae..56f826d82 100644 --- a/config/experiments/crystals/starling_density_no_constraints.yaml +++ b/config/experiments/crystals/starling_density_no_constraints.yaml @@ -54,13 +54,9 @@ env: train: type: csv path: /network/projects/crystalgfn/data/eform/train.csv - output_csv: crystal_train.csv - output_pkl: crystal_train.pkl test: type: csv path: /network/projects/crystalgfn/data/eform/val.csv - output_csv: crystal_val.csv - output_pkl: crystal_val.pkl # GFlowNet hyperparameters gflownet: diff --git a/config/experiments/crystals/starling_fe.yaml b/config/experiments/crystals/starling_fe.yaml index 41c09bae9..4251b0f4c 100644 --- a/config/experiments/crystals/starling_fe.yaml +++ b/config/experiments/crystals/starling_fe.yaml @@ -54,13 +54,9 @@ env: train: type: csv path: /network/projects/crystalgfn/data/eform/train.csv - output_csv: crystal_train.csv - output_pkl: crystal_train.pkl test: type: csv path: /network/projects/crystalgfn/data/eform/val.csv - output_csv: crystal_val.csv - output_pkl: crystal_val.pkl # GFlowNet hyperparameters gflownet: diff --git a/config/experiments/crystals/starling_fe_no_constraints.yaml b/config/experiments/crystals/starling_fe_no_constraints.yaml index 0c428d0e5..6a059eda0 100644 --- a/config/experiments/crystals/starling_fe_no_constraints.yaml +++ b/config/experiments/crystals/starling_fe_no_constraints.yaml @@ -54,13 +54,9 @@ env: train: type: csv path: /network/projects/crystalgfn/data/eform/train.csv - output_csv: crystal_train.csv - output_pkl: crystal_train.pkl test: type: csv path: /network/projects/crystalgfn/data/eform/val.csv - output_csv: crystal_val.csv - output_pkl: crystal_val.pkl # GFlowNet hyperparameters gflownet: diff --git a/config/experiments/scrabble/jay.yaml b/config/experiments/scrabble/jay.yaml index 2bc1bec28..bdbfbea56 100644 --- a/config/experiments/scrabble/jay.yaml +++ b/config/experiments/scrabble/jay.yaml @@ -18,8 +18,6 @@ env: test: type: random n: 1000 - output_csv: scrabble_test.csv - output_pkl: scrabble_test.pkl # Proxy proxy: diff --git a/config/experiments/scrabble/penguin.yaml b/config/experiments/scrabble/penguin.yaml index 02ba27158..0221f8179 100644 --- a/config/experiments/scrabble/penguin.yaml +++ b/config/experiments/scrabble/penguin.yaml @@ -17,8 +17,6 @@ env: test: type: random n: 1000 - output_csv: scrabble_test.csv - output_pkl: scrabble_test.pkl # Proxy proxy: diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index 45f60ddd0..1593c48ee 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -3,6 +3,7 @@ """ import pickle +from pathlib import Path from typing import List import numpy as np @@ -21,14 +22,17 @@ def __init__( env, proxy, replay_capacity=0, - output_csv=None, data_path=None, train=None, test=None, logger=None, **kwargs, ): - self.logger = logger + if logger is not None: + self.datadir = logger.datadir + else: + self.datadir = Path("./logs") + self.datadir.mkdir(parents=True, exist_ok=True) self.env = env self.proxy = proxy self.replay_capacity = replay_capacity @@ -43,7 +47,7 @@ def __init__( self.replay_states = {} self.replay_trajs = {} self.replay_rewards = {} - self.replay_pkl = "replay.pkl" + self.replay_pkl = self.datadir / "replay.pkl" self.train_csv = None self.train_pkl = None @@ -58,28 +62,19 @@ def __init__( else: self.train_type = None self.train, dict_tr = self.make_data_set(train) - if ( - self.train is not None - and "output_csv" in train - and train.output_csv is not None - ): - self.train.to_csv(train.output_csv) - self.train_csv = train.output_csv - if ( - dict_tr is not None - and "output_pkl" in train - and train.output_pkl is not None - ): - with open(train.output_pkl, "wb") as f: + if self.train is not None: + self.train_csv = self.datadir / "train.csv" + self.train.to_csv(self.train_csv) + if dict_tr is not None: + self.train_pkl = self.datadir / "train.pkl" + with open(self.train_pkl, "wb") as f: pickle.dump(dict_tr, f) - self.train_pkl = train.output_pkl else: print( """ Important: offline trajectories will NOT be sampled. In order to sample offline trajectories, the train configuration of the buffer should be - complete and feasible and an output pkl file should be defined in - env.buffer.train.output_pkl. + complete and feasible. It should at least specify env.buffer.train.type. """ ) self.train_pkl = None @@ -90,24 +85,19 @@ def __init__( else: self.train_type = None self.test, dict_tt = self.make_data_set(test) - if ( - self.test is not None - and "output_csv" in test - and test.output_csv is not None - ): - self.test_csv = test.output_csv - self.test.to_csv(test.output_csv) - if dict_tt is not None and "output_pkl" in test and test.output_pkl is not None: - with open(test.output_pkl, "wb") as f: + if self.test is not None: + self.test_csv = self.datadir / "test.csv" + self.test.to_csv(self.test_csv) + if dict_tt is not None: + self.test_pkl = self.datadir / "test.pkl" + with open(self.test_pkl, "wb") as f: pickle.dump(dict_tt, f) - self.test_pkl = test.output_pkl else: print( """ Important: test metrics will NOT be computed. In order to compute test metrics the test configuration of the buffer should be complete and - feasible and an output pkl file should be defined in - env.buffer.test.output_pkl. + feasible. It should at least specify env.buffer.test.type. """ ) self.test_pkl = None diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 86e554c09..9fe982b30 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -71,14 +71,21 @@ def __init__( self.lightweight = lightweight self.debug = debug # Log directory - self.logdir = Path(logdir.root) + if "path" in logdir: + self.logdir = Path(logdir.path) + else: + self.logdir = Path(logdir.root) if not self.logdir.exists() or logdir.overwrite: self.logdir.mkdir(parents=True, exist_ok=True) else: print(f"logdir {logdir} already exists! - Ending run...") sys.exit(1) + # Checkpoints directory self.ckpts_dir = self.logdir / logdir.ckpts self.ckpts_dir.mkdir(parents=True, exist_ok=True) + # Data directory + self.datadir = self.logdir / "data" + self.datadir.mkdir(parents=True, exist_ok=True) # Write wandb URL self.write_url_file() diff --git a/main.py b/main.py index 36607c276..c0036e745 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ import hydra import pandas as pd +from omegaconf import open_dict from gflownet.utils.policy import parse_policy_config @@ -16,12 +17,13 @@ @hydra.main(config_path="./config", config_name="main", version_base="1.1") def main(config): - # Print working and logging directory + # Set and print working and logging directory + with open_dict(config): + config.logger.logdir.path = ( + hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + ) print(f"\nWorking directory of this run: {os.getcwd()}") - print( - "Logging directory of this run: " - f"{hydra.core.hydra_config.HydraConfig.get().runtime.output_dir}" - ) + print(f"Logging directory of this run: {config.logger.logdir.path}\n") # Reset seed for job-name generation in multirun jobs random.seed(None) @@ -48,7 +50,7 @@ def main(config): _partial_=True, ) env = env_maker() - + # The evaluator is used to compute metrics and plots evaluator = hydra.utils.instantiate(config.evaluator) diff --git a/tests/gflownet/evaluator/test_base.py b/tests/gflownet/evaluator/test_base.py index f9616aa82..158b5b85b 100644 --- a/tests/gflownet/evaluator/test_base.py +++ b/tests/gflownet/evaluator/test_base.py @@ -221,7 +221,7 @@ def test__should_eval_top_k(constant_evaluator, period, step, target, first_it, ], ) def test__eval(gflownet_for_tests, parameterization): - assert Path("./replay.pkl").exists() + assert gflownet_for_tests.buffer.replay_pkl.exists() # results: {"metrics": dict[str, float], "figs": list[plt.Figure]} results = gflownet_for_tests.evaluator.eval() figs = gflownet_for_tests.evaluator.plot(**results["data"])