-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_optimization.py
143 lines (122 loc) · 4.42 KB
/
run_optimization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import numpy as np
import time
from icecream import ic
from optimization.particle_swarm_optimization_wrapper import ParticleSwarmOptimizationWrapper
from args.args import Args
from datasets.dataset_rh import DatasetRH
from datasets.dataset_ethz import DatasetETHZ
from training.trainer import Trainer
from helpers.system_fcts import checkGPUMemory
def _apply_hparams(args, hparams_dict):
    """Write one set of PSO-proposed hyper-parameters into ``args`` in place.

    Args:
        args: Hyper-parameter container; its ``training``, ``occ_grid`` and
            ``tof`` attributes are overwritten.
        hparams_dict: Grouped hyper-parameters with the groups ``"training"``,
            ``"occ_grid"`` and ``"ToF"``. The ``"training"`` group is modified
            in place when the pixel-sampling fractions have to be rescaled.
    """
    training_hparams = hparams_dict["training"]

    # Rescale the two pixel-sampling fractions when together they exceed 1.
    # The divisor is rounded up to two decimals, so the rescaled sum is <= 1.
    # NOTE(review): indentation was lost in the reviewed copy; both divisions
    # are assumed to belong to the `if` body -- confirm against the repository.
    sampling_pix_sum = training_hparams["pixs_valid_uss"] + training_hparams["pixs_valid_tof"]
    if sampling_pix_sum > 1.0:
        sampling_pix_sum = np.ceil(100 * sampling_pix_sum) / 100
        training_hparams["pixs_valid_uss"] /= sampling_pix_sum
        training_hparams["pixs_valid_tof"] /= sampling_pix_sum

    sampling_strategy = {
        "imgs": "all",
        "pixs": {
            "valid_uss": training_hparams["pixs_valid_uss"],
            "valid_tof": training_hparams["pixs_valid_tof"],
        },
    }

    # The two fraction entries are not attributes of their own; they are
    # folded into the single `sampling_strategy` attribute instead.
    for key, value in training_hparams.items():
        if key in ("pixs_valid_uss", "pixs_valid_tof"):
            args.training.sampling_strategy = sampling_strategy
            continue
        setattr(args.training, key, value)

    # The optimizer proposes real numbers; these two settings are stored as ints.
    for key, value in hparams_dict["occ_grid"].items():
        if key in ("update_interval", "decay_warmup_steps"):
            value = int(np.round(value))
        setattr(args.occ_grid, key, value)

    args.tof.tof_pix_size = int(np.round(hparams_dict["ToF"]["tof_pix_size"]))


def main():
    """Run a particle-swarm hyper-parameter search over training runs.

    Loads the base hyper-parameters, builds the train and test datasets once,
    and then loops: fetch the next candidate from the PSO wrapper, write it
    into ``args``, train and evaluate a new ``Trainer``, and report the score
    back to the swarm. The loop stops when ``pso.update`` or
    ``checkGPUMemory()`` returns a truthy value.

    Raises:
        ValueError: If ``args.dataset.name`` is neither ``'RH2'`` nor ``'ETHZ'``.
    """
    # define parameters
    T = 36000  # if termination_by_time: T is time in seconds, else T is number of iterations
    termination_by_time = True  # whether to terminate by time or iterations
    hparams_file = "ethz_usstof_gpu.json"
    hparams_lims_file = "optimization/hparams_lims.json"
    save_dir = "results/pso/opt32_2"

    # get hyper-parameters and other variables
    args = Args(
        file_name=hparams_file,
    )
    args.model.save = False
    args.training.debug_mode = False
    # One step past max_steps -- presumably disables evaluation during
    # training (TODO confirm in Trainer).
    args.eval.eval_every_n_steps = args.training.max_steps + 1
    args.eval.plot_results = False
    args.eval.sensors = ["GT", "NeRF"]
    args.eval.num_color_pts = 0
    args.eval.batch_size = 8192
    args.training.batch_size = 4096
    args.seed = np.random.randint(0, 2**8 - 1)  # random base seed for this run

    # datasets
    dataset_classes = {
        'RH2': DatasetRH,
        'ETHZ': DatasetETHZ,
    }
    dataset = dataset_classes.get(args.dataset.name)
    if dataset is None:
        # Previously this was only logged and the run then crashed with an
        # UnboundLocalError on `dataset`; fail with a clear error instead.
        args.logger.error("Invalid dataset name.")
        raise ValueError(f"Invalid dataset name: {args.dataset.name!r}")

    train_dataset = dataset(
        args=args,
        split="train",
    ).to(args.device)
    test_dataset = dataset(
        args=args,
        split='test',
        scene=train_dataset.scene,
    ).to(args.device)

    # pso
    pso = ParticleSwarmOptimizationWrapper(
        hparams_lims_file=hparams_lims_file,
        save_dir=save_dir,
        T=T,
        termination_by_time=termination_by_time,
        rng=np.random.default_rng(args.seed),
    )

    # run optimization
    terminate = False
    iteration = 0
    while not terminate:
        iteration += 1

        # get hparams to evaluate (grouped layout: "training", "occ_grid", "ToF")
        hparams_dict = pso.getNextHparams(
            group_dict_layout=True,
            name_dict_layout=False,
        )

        # set hparams; each iteration gets its own seed derived from the base seed
        args.setRandomSeed(
            seed=args.seed + iteration,
        )
        _apply_hparams(args, hparams_dict)

        print("\n\n----- NEW PARAMETERS -----")
        print(f"Time: {time.time()-pso.time_start+pso.time_offset:.1f}s, particle: {pso.n}")
        ic(hparams_dict)
        print(f"Current best mnn: {np.min(pso.best_score):.3f}, best particle: {np.argmin(pso.best_score)}")

        # load trainer
        trainer = Trainer(
            args=args,
            train_dataset=train_dataset,
            test_dataset=test_dataset,
        )

        # train and evaluate model
        trainer.train()
        metrics_dict = trainer.evaluate()

        # get score
        score = metrics_dict['NeRF']["nn_mean"]['zone3']
        # NaN never compares equal to anything (including itself), so it has
        # to be detected with isnan; a NaN score is replaced by +inf before it
        # is handed to the swarm.
        if np.isnan(score):
            score = np.inf

        # update particle swarm
        terminate = pso.update(
            score=score,
        )  # bool

        # save state
        pso.saveState(
            score=score,
        )

        # Drop the trainer before the memory check.
        # NOTE(review): the check is assumed to run once per iteration (inside
        # the loop); a truthy result stops the optimization -- confirm the
        # exact meaning in helpers.system_fcts.
        del trainer
        if checkGPUMemory():
            terminate = True
# Script entry point: run the optimization only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()