Skip to content

Commit

Permalink
Merge pull request #322 from alexhernandezgarcia/logdir
Browse files Browse the repository at this point in the history
[Small PR] Logging dir and buffer outputs: fix issue and improvements
  • Loading branch information
alexhernandezgarcia authored Jun 11, 2024
2 parents 0a256bf + f8c235c commit d741e08
Show file tree
Hide file tree
Showing 28 changed files with 38 additions and 103 deletions.
2 changes: 0 additions & 2 deletions config/env/alaninedipeptide.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,3 @@ buffer:
test:
type: grid
n: 1000
output_csv: alaninedipeptide_test.csv
output_pkl: alaninedipeptide_test.pkl
2 changes: 0 additions & 2 deletions config/env/ccube.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,3 @@ buffer:
test:
type: grid
n: 900
output_csv: ccube_test.csv
output_pkl: ccube_test.pkl
2 changes: 0 additions & 2 deletions config/env/crystals/composition.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,5 @@ buffer:
data_path: null
train:
type: all
output_csv: composition_train.csv
test:
type: all
output_csv: composition_test.csv
2 changes: 0 additions & 2 deletions config/env/crystals/lattice_parameters.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,3 @@ buffer:
test:
type: grid
n: 900
output_csv: clp_test.csv
output_pkl: clp_test.pkl
4 changes: 0 additions & 4 deletions config/env/crystals/spacegroup.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,5 @@ buffer:
data_path: null
train:
type: all
output_csv: spacegroup_train.csv
output_pkl: spacegroup_train.pkl
test:
type: all
output_csv: spacegroup_test.csv
output_pkl: spacegroup_test.pkl
2 changes: 0 additions & 2 deletions config/env/ctorus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,3 @@ buffer:
test:
type: grid
n: 1000
output_csv: ctorus_test.csv
output_pkl: ctorus_test.pkl
2 changes: 0 additions & 2 deletions config/env/grid.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,3 @@ buffer:
train: null
test:
type: all
output_csv: grid_test.csv
output_pkl: grid_test.pkl
2 changes: 0 additions & 2 deletions config/env/htorus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,3 @@ buffer:
test:
type: grid
n: 1000
output_csv: htorus_test.csv
output_pkl: htorus_test.pkl
2 changes: 0 additions & 2 deletions config/env/scrabble.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,3 @@ buffer:
test:
type: uniform
n: 10
output_csv: scrabble_test.csv
output_pkl: scrabble_test.pkl
3 changes: 0 additions & 3 deletions config/env/tetris.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,3 @@ buffer:
test:
type: random
n: 10
output_csv: tetris_test.csv
output_pkl: tetris_test.pkl

3 changes: 0 additions & 3 deletions config/env/torus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,5 @@ buffer:
data_path: null
train:
type: all
output_csv: torus_train.csv
test:
type: all
output_csv: torus_test.csv
output_pkl: torus_test.pkl
2 changes: 0 additions & 2 deletions config/env/tree.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,3 @@ buffer:
test:
type: random
n: 1000
output_csv: tree_test.csv
output_pkl: tree_test.pkl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ shared:
test:
type: grid
n: 1000
output_csv: ccube_test.csv
output_pkl: ccube_test.pkl
# Proxy
proxy: box/corners
# GFlowNet config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ shared:
test:
type: grid
n: 1000
output_csv: ccube_test.csv
output_pkl: ccube_test.pkl
# Proxy
proxy: box/corners
# GFlowNet config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ shared:
test:
type: grid
n: 1000
output_csv: ccube_test.csv
output_pkl: ccube_test.pkl
# Proxy
proxy: box/corners
# GFlowNet config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ shared:
test:
type: grid
n: 1000
output_csv: ccube_test.csv
output_pkl: ccube_test.pkl
# Proxy
proxy: box/corners
# GFlowNet config
Expand Down
4 changes: 0 additions & 4 deletions config/experiments/crystals/starling_bg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,9 @@ env:
train:
type: csv
path: /network/projects/crystalgfn/data/bandgap/train.csv
output_csv: crystal_train.csv
output_pkl: crystal_train.pkl
test:
type: csv
path: /network/projects/crystalgfn/data/bandgap/val.csv
output_csv: crystal_val.csv
output_pkl: crystal_val.pkl

# GFlowNet hyperparameters
gflownet:
Expand Down
4 changes: 0 additions & 4 deletions config/experiments/crystals/starling_bg_no_constraints.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,9 @@ env:
train:
type: csv
path: /network/projects/crystalgfn/data/bandgap/train.csv
output_csv: crystal_train.csv
output_pkl: crystal_train.pkl
test:
type: csv
path: /network/projects/crystalgfn/data/bandgap/val.csv
output_csv: crystal_val.csv
output_pkl: crystal_val.pkl

# GFlowNet hyperparameters
gflownet:
Expand Down
4 changes: 0 additions & 4 deletions config/experiments/crystals/starling_density.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@ env:
train:
type: csv
path: /network/projects/crystalgfn/data/eform/train.csv
output_csv: crystal_train.csv
output_pkl: crystal_train.pkl
test:
type: csv
path: /network/projects/crystalgfn/data/eform/val.csv
output_csv: crystal_val.csv
output_pkl: crystal_val.pkl

# GFlowNet hyperparameters
gflownet:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@ env:
train:
type: csv
path: /network/projects/crystalgfn/data/eform/train.csv
output_csv: crystal_train.csv
output_pkl: crystal_train.pkl
test:
type: csv
path: /network/projects/crystalgfn/data/eform/val.csv
output_csv: crystal_val.csv
output_pkl: crystal_val.pkl

# GFlowNet hyperparameters
gflownet:
Expand Down
4 changes: 0 additions & 4 deletions config/experiments/crystals/starling_fe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@ env:
train:
type: csv
path: /network/projects/crystalgfn/data/eform/train.csv
output_csv: crystal_train.csv
output_pkl: crystal_train.pkl
test:
type: csv
path: /network/projects/crystalgfn/data/eform/val.csv
output_csv: crystal_val.csv
output_pkl: crystal_val.pkl

# GFlowNet hyperparameters
gflownet:
Expand Down
4 changes: 0 additions & 4 deletions config/experiments/crystals/starling_fe_no_constraints.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@ env:
train:
type: csv
path: /network/projects/crystalgfn/data/eform/train.csv
output_csv: crystal_train.csv
output_pkl: crystal_train.pkl
test:
type: csv
path: /network/projects/crystalgfn/data/eform/val.csv
output_csv: crystal_val.csv
output_pkl: crystal_val.pkl

# GFlowNet hyperparameters
gflownet:
Expand Down
2 changes: 0 additions & 2 deletions config/experiments/scrabble/jay.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ env:
test:
type: random
n: 1000
output_csv: scrabble_test.csv
output_pkl: scrabble_test.pkl

# Proxy
proxy:
Expand Down
2 changes: 0 additions & 2 deletions config/experiments/scrabble/penguin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ env:
test:
type: random
n: 1000
output_csv: scrabble_test.csv
output_pkl: scrabble_test.pkl

# Proxy
proxy:
Expand Down
52 changes: 21 additions & 31 deletions gflownet/utils/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import pickle
from pathlib import Path
from typing import List

import numpy as np
Expand All @@ -21,14 +22,17 @@ def __init__(
env,
proxy,
replay_capacity=0,
output_csv=None,
data_path=None,
train=None,
test=None,
logger=None,
**kwargs,
):
self.logger = logger
if logger is not None:
self.datadir = logger.datadir
else:
self.datadir = Path("./logs")
self.datadir.mkdir(parents=True, exist_ok=True)
self.env = env
self.proxy = proxy
self.replay_capacity = replay_capacity
Expand All @@ -43,7 +47,7 @@ def __init__(
self.replay_states = {}
self.replay_trajs = {}
self.replay_rewards = {}
self.replay_pkl = "replay.pkl"
self.replay_pkl = self.datadir / "replay.pkl"

self.train_csv = None
self.train_pkl = None
Expand All @@ -58,28 +62,19 @@ def __init__(
else:
self.train_type = None
self.train, dict_tr = self.make_data_set(train)
if (
self.train is not None
and "output_csv" in train
and train.output_csv is not None
):
self.train.to_csv(train.output_csv)
self.train_csv = train.output_csv
if (
dict_tr is not None
and "output_pkl" in train
and train.output_pkl is not None
):
with open(train.output_pkl, "wb") as f:
if self.train is not None:
self.train_csv = self.datadir / "train.csv"
self.train.to_csv(self.train_csv)
if dict_tr is not None:
self.train_pkl = self.datadir / "train.pkl"
with open(self.train_pkl, "wb") as f:
pickle.dump(dict_tr, f)
self.train_pkl = train.output_pkl
else:
print(
"""
Important: offline trajectories will NOT be sampled. In order to sample
offline trajectories, the train configuration of the buffer should be
complete and feasible and an output pkl file should be defined in
env.buffer.train.output_pkl.
complete and feasible. It should at least specify env.buffer.train.type.
"""
)
self.train_pkl = None
Expand All @@ -90,24 +85,19 @@ def __init__(
else:
self.train_type = None
self.test, dict_tt = self.make_data_set(test)
if (
self.test is not None
and "output_csv" in test
and test.output_csv is not None
):
self.test_csv = test.output_csv
self.test.to_csv(test.output_csv)
if dict_tt is not None and "output_pkl" in test and test.output_pkl is not None:
with open(test.output_pkl, "wb") as f:
if self.test is not None:
self.test_csv = self.datadir / "test.csv"
self.test.to_csv(self.test_csv)
if dict_tt is not None:
self.test_pkl = self.datadir / "test.pkl"
with open(self.test_pkl, "wb") as f:
pickle.dump(dict_tt, f)
self.test_pkl = test.output_pkl
else:
print(
"""
Important: test metrics will NOT be computed. In order to compute
test metrics the test configuration of the buffer should be complete and
feasible and an output pkl file should be defined in
env.buffer.test.output_pkl.
feasible. It should at least specify env.buffer.test.type.
"""
)
self.test_pkl = None
Expand Down
9 changes: 8 additions & 1 deletion gflownet/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,21 @@ def __init__(
self.lightweight = lightweight
self.debug = debug
# Log directory
self.logdir = Path(logdir.root)
if "path" in logdir:
self.logdir = Path(logdir.path)
else:
self.logdir = Path(logdir.root)
if not self.logdir.exists() or logdir.overwrite:
self.logdir.mkdir(parents=True, exist_ok=True)
else:
print(f"logdir {logdir} already exists! - Ending run...")
sys.exit(1)
# Checkpoints directory
self.ckpts_dir = self.logdir / logdir.ckpts
self.ckpts_dir.mkdir(parents=True, exist_ok=True)
# Data directory
self.datadir = self.logdir / "data"
self.datadir.mkdir(parents=True, exist_ok=True)
# Write wandb URL
self.write_url_file()

Expand Down
14 changes: 8 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@

import hydra
import pandas as pd
from omegaconf import open_dict

from gflownet.utils.policy import parse_policy_config


@hydra.main(config_path="./config", config_name="main", version_base="1.1")
def main(config):

# Print working and logging directory
# Set and print working and logging directory
with open_dict(config):
config.logger.logdir.path = (
hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
)
print(f"\nWorking directory of this run: {os.getcwd()}")
print(
"Logging directory of this run: "
f"{hydra.core.hydra_config.HydraConfig.get().runtime.output_dir}"
)
print(f"Logging directory of this run: {config.logger.logdir.path}\n")

# Reset seed for job-name generation in multirun jobs
random.seed(None)
Expand All @@ -48,7 +50,7 @@ def main(config):
_partial_=True,
)
env = env_maker()

# The evaluator is used to compute metrics and plots
evaluator = hydra.utils.instantiate(config.evaluator)

Expand Down
2 changes: 1 addition & 1 deletion tests/gflownet/evaluator/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test__should_eval_top_k(constant_evaluator, period, step, target, first_it,
],
)
def test__eval(gflownet_for_tests, parameterization):
assert Path("./replay.pkl").exists()
assert gflownet_for_tests.buffer.replay_pkl.exists()
# results: {"metrics": dict[str, float], "figs": list[plt.Figure]}
results = gflownet_for_tests.evaluator.eval()
figs = gflownet_for_tests.evaluator.plot(**results["data"])
Expand Down

0 comments on commit d741e08

Please sign in to comment.