
Commit 4bb9aac: fix adsorbates filtering

vict0rsch committed Jan 16, 2024
1 parent 3f025ac
Showing 6 changed files with 108 additions and 25 deletions.
4 changes: 2 additions & 2 deletions configs/exps/alvaro/gflownet.yaml
@@ -42,7 +42,7 @@ default:
eval_every: 0.4

runs:

- config: faenet-is2re-all
note: baseline faenet

@@ -51,7 +51,7 @@ runs:

- config: depfaenet-is2re-all
note: depfaenet per-adsorbate
-  adsorbates: {'*O', '*OH', '*OH2', '*H'}
+  adsorbates: '*O, *OH, *OH2, *H'

- config: depfaenet-is2re-all
note: To be used for continue from dir
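Reviewer note: the yaml change above swaps a flow-style set for a plain
comma-separated string. A minimal, hedged sketch of how such a value could be
normalized back into a set downstream; the helper name and parsing site are
assumptions, only the two value formats come from the diff:

# Hedged sketch: accept the new comma-separated string or the old set/list.
def parse_adsorbates(value):
    if isinstance(value, str):
        return {a.strip() for a in value.split(",") if a.strip()}
    return set(value)

assert parse_adsorbates("*O, *OH, *OH2, *H") == {"*O", "*OH", "*OH2", "*H"}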
32 changes: 23 additions & 9 deletions mila/launch_exp.py
@@ -1,14 +1,13 @@
+import copy
import os
import re
import subprocess
import sys
from pathlib import Path

from minydra import resolved_args
-from yaml import safe_load, dump

from sbatch import now
-import copy
+from yaml import dump, safe_load

ROOT = Path(__file__).resolve().parent.parent

@@ -143,14 +142,16 @@ def cli_arg(args, key=""):
s += cli_arg(v, key=f"{parent}{k}")
else:
if " " in str(v) or "," in str(v) or isinstance(v, str):
if "'" in str(v) and '"' in str(v):
v = str(v).replace("'", "\\'")
if '"' in str(v):
v = str(v).replace('"', '\\"')
v = f"'{v}'"
elif "'" in str(v):
v = f'"{v}"'
v = f'\\"{v}\\"'
else:
v = f"'{v}'"
s += f" --{parent}{k}={v}"
if "ads" in k:
print(s.split(" --")[-1])
return s
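
Reviewer note: a standalone, hedged sketch of the quoting rules in cli_arg
after this commit. The helper below mirrors the visible branches but is not
the function itself; treat it as an illustration, not the implementation:

# Hedged sketch: string values, or values containing spaces or commas, are
# wrapped so they survive the shell layer of the generated sbatch command.
def quote_cli_value(v):
    v = str(v)
    if "'" in v and '"' in v:
        # escape both quote styles, then wrap in single quotes
        v = v.replace("'", "\\'").replace('"', '\\"')
        return f"'{v}'"
    if "'" in v:
        # single quotes inside: wrap in escaped double quotes
        return f'\\"{v}\\"'
    return f"'{v}'"

print(quote_cli_value("*O, *OH, *OH2, *H"))  # -> '*O, *OH, *OH2, *H'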


@@ -175,10 +176,15 @@ def get_args_or_exp(key, args, exp):
n_jobs = None
args = resolved_args()
assert "exp" in args
-regex = args.get("match", ".*")

+regex = args.pop("match", ".*")
+exp_name = args.pop("exp").replace(".yml", "").replace(".yaml", "")
+no_confirm = args.pop("no_confirm", False)

+sbatch_overrides = args.to_dict()

ts = now()

-exp_name = args.exp.replace(".yml", "").replace(".yaml", "")
exp_file = find_exp(exp_name)

exp = safe_load(exp_file.open("r"))
@@ -231,6 +237,8 @@ def get_args_or_exp(key, args, exp):
else:
params["wandb_tags"] = exp_name

job = merge_dicts(job, sbatch_overrides)

py_args = f'py_args="{cli_arg(params).strip()}"'

sbatch_args = " ".join(
@@ -253,7 +261,7 @@ def get_args_or_exp(key, args, exp):
text += "\n<><><> Experiment config:\n\n-----" + exp_file.read_text() + "-----"
text += "\n<><><> Experiment runs:\n\n • " + "\n\n • ".join(commands) + separator

-confirm = args.no_confirm or "y" in input("\n🚦 Confirm? [y/n] : ")
+confirm = no_confirm or "y" in input("\n🚦 Confirm? [y/n] : ")

if confirm:
try:
@@ -267,6 +275,10 @@ def get_args_or_exp(key, args, exp):
for c, command in enumerate(commands):
print(f"Launching job {c+1:3}", end="\r")
outputs.append(os.popen(command).read().strip())
if "Aborting" in outputs[-1]:
print("\nError submitting job", c + 1, ":", command)
print(outputs[-1].replace("Error while launching job:\n", ""))
print("\n")
if " verbose=true" in command.lower():
print(outputs[-1])
except KeyboardInterrupt:
@@ -283,6 +295,8 @@ def get_args_or_exp(key, args, exp):

if is_interrupted:
print("\n💀 Interrupted. Kill jobs with:\n$ scancel" + " ".join(jobs))
elif not jobs:
print("\n❌ No jobs launched")
else:
text += f"{separator}All jobs launched: {' '.join(jobs)}"
with outfile.open("w") as f:
28 changes: 19 additions & 9 deletions mila/sbatch.py
@@ -1,12 +1,13 @@
-from minydra import resolved_args, MinyDict
-from pathlib import Path
-from datetime import datetime
import os
+import re
import subprocess
-from shutil import copyfile
import sys
-import re
+from datetime import datetime
+from pathlib import Path
+from shutil import copyfile

import yaml
+from minydra import MinyDict, resolved_args

IS_DRAC = (
"narval.calcul.quebec" in os.environ.get("HOSTNAME", "")
@@ -24,13 +25,13 @@
# git commit: {git_commit}
# cwd: {cwd}
-{git_checkout}
{sbatch_py_vars}
export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
echo "Master port $MASTER_PORT"
cd {code_loc}
+{git_checkout}
{modules}
@@ -41,7 +42,7 @@
conda activate {env}
fi
{wandb_offline}
-srun --output={output} {python_command}
+srun --gpus-per-task=1 --output={output} {python_command}
"""


@@ -247,7 +248,6 @@ def load_sbatch_args_from_dir(dir):
"cpus": int(sbatch_args["cpus-per-task"]),
"mem": sbatch_args["mem"],
"gres": sbatch_args["gres"],
"output": sbatch_args["output"],
}
return args

@@ -417,7 +417,17 @@ def load_sbatch_args_from_dir(dir):
print("\nDev mode: not actually executing the command 🤓\n")
else:
# not dev mode: run the command, make directories
-    out = subprocess.check_output(command.split(" ")).decode("utf-8").strip()
+    try:
+        out = (
+            subprocess.check_output(command.split(" "), stderr=subprocess.STDOUT)
+            .decode("utf-8")
+            .strip()
+        )
+    except subprocess.CalledProcessError as error:
+        print("Error while launching job:\n```")
+        print(error.output.decode("utf-8").strip())
+        print("```\nAborting...")
+        sys.exit(1)
jobid = out.split(" job ")[-1].strip()
success = out.startswith("Submitted batch job")

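Reviewer note: with this try/except, sbatch.py now prints the captured sbatch
output and "Aborting..." before exiting 1, which is exactly the marker the new
check in mila/launch_exp.py (above) scans for. A hedged sketch of that round
trip; the command string is hypothetical:

import os

command = "python mila/sbatch.py py_args=..."  # hypothetical command
output = os.popen(command).read().strip()
if "Aborting" in output:
    # mirror of the launcher-side handling added in this commit
    print("Error submitting job:", command)
    print(output.replace("Error while launching job:\n", ""))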
32 changes: 32 additions & 0 deletions ocpmodels/common/utils.py
@@ -901,6 +901,37 @@ def set_cpus_to_workers(config, silent=False):
return config


def set_dataset_split(config):
"""
Set the split for all datasets in the config to the one specified in the
config's name.
Resulting dict:
{
"dataset": {
"train": {
"split": "all"
...
},
...
}
}
Args:
config (dict): The full trainer config dict
Returns:
dict: The updated config dict
"""
split = config["config"].split("-")[-1]
for d, dataset in config["dataset"].items():
if d == "default_val":
continue
assert isinstance(dataset, dict)
config["dataset"][d]["split"] = split
return config


def check_regress_forces(config):
if "regress_forces" in config["model"]:
if config["model"]["regress_forces"] == "":
@@ -1182,6 +1213,7 @@ def build_config(args, args_override=[], silent=None):
config = override_drac_paths(config)
config = continue_from_slurm_job_id(config)
config = read_slurm_env(config)
config = set_dataset_split(config)
config["optim"]["eval_batch_size"] = config["optim"]["batch_size"]
dist_utils.setup(config)

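Reviewer note: a quick, hedged usage sketch of the new set_dataset_split
helper wired into build_config above. The src paths and dataset names are
hypothetical; the "default_val" skip and the split extraction come from the
function itself:

from ocpmodels.common.utils import set_dataset_split

config = {
    "config": "depfaenet-is2re-all",
    "dataset": {
        "train": {"src": "/hypothetical/train.lmdb"},
        "val_id": {"src": "/hypothetical/val.lmdb"},
        "default_val": "val_id",  # skipped: a pointer, not a dataset dict
    },
}
config = set_dataset_split(config)
assert config["dataset"]["train"]["split"] == "all"
assert config["dataset"]["val_id"]["split"] == "all"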
34 changes: 30 additions & 4 deletions ocpmodels/datasets/lmdb_dataset.py
@@ -53,11 +53,13 @@ def __init__(
lmdb_glob=None,
adsorbates=None,
adsorbates_ref_dir=None,
silent=False,
):
super().__init__()
self.config = config
self.adsorbates = adsorbates
self.adsorbates_ref_dir = adsorbates_ref_dir
self.silent = silent

self.path = Path(self.config["src"])
if not self.path.is_file():
@@ -128,10 +130,23 @@ def filter_per_adsorbates(self):
if not ref_path.is_dir():
print(f"Adsorbate reference directory {ref_path} does not exist.")
return
pattern = "-".join(self.path.parts[-3:])
pattern = f"{self.config['split']}-{self.path.parts[-1]}"
candidates = list(ref_path.glob(f"*{pattern}*.json"))
if not candidates:
print(f"No adsorbate reference files found for {self.path.name}.")
print(
f"No adsorbate reference files found for {self.path.name}.:"
+ "\n".join(
[
str(p)
for p in [
ref_path,
pattern,
list(ref_path.glob(f"*{pattern}*.json")),
list(ref_path.glob("*")),
]
]
)
)
return
if len(candidates) > 1:
print(
@@ -147,6 +162,8 @@ def filter_per_adsorbates(self):
if a in ads
)

previous_samples = self.num_samples

# filter the dataset indices
if isinstance(self._keys[0], bytes):
self._keys = [i for i in self._keys if i in allowed_idxs]
@@ -158,6 +175,12 @@ def filter_per_adsorbates(self):
self._keylen_cumulative = np.cumsum(keylens).tolist()
self.num_samples = sum(keylens)

if not self.silent:
print(
f"Filtered dataset {pattern} from {previous_samples} to",
f"{self.num_samples} samples. (adsorbates: {ads})",
)

assert self.num_samples > 0, f"No samples found for adsorbates {ads}."

def __len__(self):
@@ -229,14 +252,17 @@ def close_db(self):

@registry.register_dataset("deup_lmdb")
class DeupDataset(LmdbDataset):
-    def __init__(self, all_datasets_configs, deup_split, transform=None):
+    def __init__(self, all_datasets_configs, deup_split, transform=None, silent=False):
# ! WARNING: this does not (yet?) handle adsorbate filtering
super().__init__(
all_datasets_configs[deup_split],
lmdb_glob=deup_split.replace("deup-", "").split("-"),
silent=silent,
)
ocp_splits = deup_split.split("-")[1:]
self.ocp_datasets = {
-            d: LmdbDataset(all_datasets_configs[d], transform) for d in ocp_splits
+            d: LmdbDataset(all_datasets_configs[d], transform, silent=silent)
+            for d in ocp_splits
}

def __getitem__(self, idx):
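Reviewer note: stepping back, filter_per_adsorbates now matches the reference
JSON via the "{split}-{lmdb filename}" pattern and keeps only samples whose
adsorbate is requested. A hedged sketch of that lookup; the reference file
format ({sample_key: adsorbate_symbol}) is an assumption, only the pattern
construction and the membership test are visible in the diff:

import json
from pathlib import Path

def allowed_keys(ref_dir, split, lmdb_path, ads):
    # e.g. split="all", lmdb_path=".../train/data.lmdb" -> "all-data.lmdb"
    pattern = f"{split}-{Path(lmdb_path).parts[-1]}"
    ref_file = next(Path(ref_dir).glob(f"*{pattern}*.json"))
    ref = json.loads(ref_file.read_text())
    return {k for k, a in ref.items() if a in ads}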
3 changes: 2 additions & 1 deletion ocpmodels/trainers/base_trainer.py
@@ -302,6 +302,7 @@ def load_datasets(self):
transform=transform,
adsorbates=self.config.get("adsorbates"),
adsorbates_ref_dir=self.config.get("adsorbates_ref_dir"),
silent=self.silent,
)

elif self.data_mode == "heterogeneous":
@@ -1134,7 +1135,7 @@ def measure_inference_time(self, loops=1):
self.config["model"].get("regress_forces") == "from_energy"
)
self.model.eval()
-        timer = Times(gpu=True)
+        timer = Times(gpu=torch.cuda.is_available())

# average inference over multiple loops
for _ in range(loops):
