Skip to content

Commit

Permalink
Fix ICLR2024 paths
Browse files Browse the repository at this point in the history
  • Loading branch information
bogdan-kulynych committed Jun 16, 2024
1 parent 54b3640 commit e60d62b
Show file tree
Hide file tree
Showing 15 changed files with 17 additions and 30 deletions.
1 change: 1 addition & 0 deletions research/iclr2024/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
runs
4 changes: 4 additions & 0 deletions research/iclr2024/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Example run:
```
python scripts/run_experiment_pipeline.py --data_name=german --action_set=simple_1D
```
2 changes: 0 additions & 2 deletions research/iclr2024/scripts/compute_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import sys
import psutil

sys.path.append(os.path.join(os.getcwd(), "iclr2024"))

import pandas as pd
import argparse

Expand Down
2 changes: 0 additions & 2 deletions research/iclr2024/scripts/generate_reachable_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import psutil
import argparse

sys.path.append(os.path.join(os.getcwd(), "iclr2024"))

settings = {
"data_name": "german",
"action_set_name": "complex_1D",
Expand Down
2 changes: 0 additions & 2 deletions research/iclr2024/scripts/get_model_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import psutil
import rich

sys.path.append(os.path.join(os.getcwd(), "iclr2024"))

import numpy as np
import pandas as pd
import argparse
Expand Down
4 changes: 1 addition & 3 deletions research/iclr2024/scripts/merge_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
import sys
import itertools

sys.path.append(os.path.join(os.getcwd(), "iclr2024"))

import pandas as pd
from iclr2024.src.paths import *
from src.paths import *
from src import fileutils
from src.data import BinaryClassificationDataset

Expand Down
1 change: 0 additions & 1 deletion research/iclr2024/scripts/print_actionset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import psutil
import argparse

sys.path.append(os.path.join(os.getcwd(), "iclr2024"))
from src.paths import *
from src import fileutils
from pathlib import Path
Expand Down
4 changes: 1 addition & 3 deletions research/iclr2024/scripts/run_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@

warnings.simplefilter(action="ignore", category=FutureWarning)

sys.path.append(os.path.join(os.getcwd(), "iclr2024"))

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import recourse as rs
from tqdm.auto import tqdm

from iclr2024.src.paths import *
from src.paths import *
from src import fileutils


Expand Down
2 changes: 0 additions & 2 deletions research/iclr2024/scripts/run_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

warnings.simplefilter(action="ignore", category=UserWarning)

sys.path.append(os.path.join(os.getcwd(), "iclr2024"))

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
Expand Down
14 changes: 7 additions & 7 deletions research/iclr2024/scripts/run_experiment_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ def main():

pipeline = []
if "setup" in args.stages:
pipeline.append(f"python iclr2024/scripts/setup_dataset_actionset_{args.data_name}.py")
#todo: pipeline.append(f"python scripts/setup_dataset_actionset_{args.data_name}.py")
pipeline.append(f"python scripts/setup_dataset_actionset_{args.data_name}.py")
# todo: pipeline.append(f"python scripts/setup_dataset_actionset_{args.data_name}.py")

if "db" in args.stages:
if args.action_set_name == GEN_DB_ACTION_SET:
pipeline.append(
f"python iclr2024/scripts/generate_reachable_sets.py --data_name={args.data_name} --action_set_name={args.action_set_name} {'--overwrite' if args.overwrite else ''}"
f"python scripts/generate_reachable_sets.py --data_name={args.data_name} --action_set_name={args.action_set_name} {'--overwrite' if args.overwrite else ''}"
)

if "train" in args.stages:
for model in args.models:
pipeline.append(
f"python iclr2024/scripts/train_models.py --data_name={args.data_name} --action_set_name={args.action_set_name} --model_type={model}"
f"python scripts/train_models.py --data_name={args.data_name} --action_set_name={args.action_set_name} --model_type={model}"
)

if "baselines" in args.stages:
Expand All @@ -53,7 +53,7 @@ def main():
continue
else:
pipeline.append(
f"python iclr2024/scripts/run_{method}.py --data_name={args.data_name} --action_set_name={args.action_set_name} --model_type={model}"
f"python scripts/run_{method}.py --data_name={args.data_name} --action_set_name={args.action_set_name} --model_type={model}"
)

if "audit" in args.stages:
Expand All @@ -62,7 +62,7 @@ def main():
continue
else:
pipeline.append(
f"python iclr2024/scripts/run_audit.py --data_name={args.data_name} --action_set_name={args.action_set_name} --model_type={model} --method_name={method}",
f"python scripts/run_audit.py --data_name={args.data_name} --action_set_name={args.action_set_name} --model_type={model} --method_name={method}",
)

if "stats" in args.stages:
Expand All @@ -71,7 +71,7 @@ def main():
continue
else:
pipeline.append(
f"python iclr2024/scripts/compute_stats.py --data_name={args.data_name} --action_set_name={args.action_set_name} --model_type={model} --method_name={method}",
f"python scripts/compute_stats.py --data_name={args.data_name} --action_set_name={args.action_set_name} --model_type={model} --method_name={method}",
)

# Run each command in the list
Expand Down
2 changes: 0 additions & 2 deletions research/iclr2024/scripts/setup_dataset_actionset_fico.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
import sys

sys.path.append(os.path.join(os.getcwd(), "iclr2024"))

import pandas as pd
import numpy as np
import itertools
Expand Down
3 changes: 1 addition & 2 deletions research/iclr2024/scripts/setup_dataset_actionset_german.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# fmt: off
import os
import sys
sys.path.append(os.path.join(os.getcwd(), "iclr2024"))

import numpy as np
import pandas as pd
Expand All @@ -10,7 +9,7 @@
pd.set_option('display.max_columns', 500)


from iclr2024.src.paths import *
from src.paths import *
from src import fileutils
from src.data import BinaryClassificationDataset
from reachml import ActionSet, ReachableSetDatabase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import sys

# fmt:off
sys.path.append(os.path.join(os.getcwd(), "iclr2024"))
import numpy as np
import pandas as pd
from iclr2024.src.paths import *
from src.paths import *
from src import fileutils
from src.data import BinaryClassificationDataset
from reachml import ActionSet, ReachableSetDatabase
Expand Down
1 change: 0 additions & 1 deletion research/iclr2024/scripts/train_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys
import os

sys.path.append(os.path.join(os.getcwd(), "iclr2024"))
import psutil
import argparse
from src import fileutils
Expand Down
2 changes: 1 addition & 1 deletion research/iclr2024/src/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# Directories

# path to the GitHub repository
paper_dir = Path(os.path.join(os.getcwd(), 'iclr2024')).resolve()
paper_dir = Path(os.path.join(os.getcwd(), "")).resolve()
repo_dir = Path(os.getcwd()).resolve()

# directory where we store datasets
Expand Down

0 comments on commit e60d62b

Please sign in to comment.