-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_ablation.py
70 lines (57 loc) · 1.73 KB
/
run_ablation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
# Force synchronous CUDA kernel launches so CUDA errors surface at the
# offending call site instead of asynchronously later. Must be set before
# torch/CUDA is initialized by the imports below.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import sys
from training.trainer import Trainer          # training/evaluation loop driver
from args.args import Args                    # hyper-parameter container loaded from JSON
from datasets.dataset_ethz import DatasetETHZ # ETHZ scene dataset (train/test splits)
from helpers.system_fcts import checkGPUMemory  # presumably True when GPU memory is (nearly) full — TODO confirm
def main():
    """Run the instant-ngp ablation on the ETHZ common-room scene for
    several random seeds, resuming after previously completed seeds.

    Datasets and hyper-parameters are loaded once and shared; a fresh
    ``Trainer`` is built, trained, and evaluated per seed, saving into
    ``<base_dir>/seed_<seed>``. The script exits early when all seeds are
    done or a seed directory already exists (e.g. a concurrent run), and
    stops between seeds when GPU memory is reported full.
    """
    hparams_file = "ethz_usstof_ablation_gpu.json"
    num_trainings = 10
    base_dir = "results/ETHZ/ablation_commonroom/instant_ngp"
    base_seed = 21

    # Create the base dir idempotently (no exists/makedirs race) and count
    # seeds already trained. Only seed_* entries are counted so stray files
    # in the results dir cannot skew the resume index.
    os.makedirs(base_dir, exist_ok=True)
    num_seeds_already_trained = sum(
        1 for entry in os.listdir(base_dir) if entry.startswith("seed_")
    )
    if num_seeds_already_trained >= num_trainings:
        print("All seeds already trained.")
        sys.exit()

    # Hyper-parameters, loaded once and reused for every seed.
    args = Args(
        file_name=hparams_file,
    )

    # Datasets are loaded once; only the trainer is rebuilt per seed.
    train_dataset = DatasetETHZ(
        args=args,
        split="train",
    ).to(args.device)
    test_dataset = DatasetETHZ(
        args=args,
        split='test',
        scene=train_dataset.scene,  # reuse the scene parsed for the train split
    ).to(args.device)

    for i in range(num_seeds_already_trained, num_trainings):
        # Set random seed and per-seed output directory.
        args.setRandomSeed(
            seed=base_seed + i,
        )
        args.save_dir = os.path.join(base_dir, f"seed_{args.seed}")
        if os.path.exists(args.save_dir):
            # Another run got here first; bail out rather than overwrite.
            print(f"Seed {args.seed} already trained.")
            sys.exit()
        os.makedirs(args.save_dir)

        # Train and evaluate this seed.
        trainer = Trainer(
            args=args,
            train_dataset=train_dataset,
            test_dataset=test_dataset,
        )
        trainer.train()
        trainer.evaluate()

        # Stop before launching another seed if GPU memory is full.
        if checkGPUMemory():
            break
# Standard entry-point guard: run the ablation only when executed as a script.
if __name__ == "__main__":
    main()