-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheck.py
55 lines (51 loc) · 2.17 KB
/
check.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging
from time import perf_counter
from ex3 import DroneAgent, ids
from inputs import inputs_list
from drone_env import DroneEnv
from trainer import DroneTrainer
# Time budgets in seconds (compared against perf_counter() differences).
AGENT_INIT_TIME_LIMIT = 1.0   # budget for the DroneAgent constructor
EPISODE_TIME_LIMIT = 0.004    # per-episode budget; phase limit = this * episode count

# Number of episodes run in the training and evaluation phases.
NR_TRAIN_EPISODES = 200_000
NR_TEST_EPISODES = 10_000

# Let INFO-level progress messages through the root logger.
logging.getLogger().setLevel(level=logging.INFO)
# Score recorded for an input whose run raised TimeoutError (e.g. exceeded a time limit).
TIMEOUT_PENALTY = -50.


def _check_time_limit(phase, elapsed, limit):
    """Raise ``TimeoutError`` if a measured phase took longer than allowed.

    The check runs after the measured phase has already returned, so a phase
    that never returns is not interrupted by this script.

    :param phase: label used in the log message, e.g. ``"train"``.
    :param elapsed: measured duration of the phase, in seconds.
    :param limit: maximum allowed duration, in seconds.
    :raises TimeoutError: when ``elapsed > limit``; the reason is logged at
        CRITICAL level first.
    """
    if elapsed > limit:
        logging.critical(f"timed out on {phase}, time: {round(elapsed, 2)}")
        raise TimeoutError(f"{phase} took {elapsed:.2f}s (limit {limit}s)")


def _evaluate_input(params):
    """Build, train and evaluate one agent on a single input; return its test score.

    :param params: one entry of ``inputs_list``. It is passed to ``DroneEnv``,
        and ``params['map']`` / ``params['map'][0]`` are measured with ``len``
        to size the agent.
    :returns: the value returned by the evaluation run
        (``trainer.run(..., train=False)``).
    :raises TimeoutError: if the constructor, training or evaluation exceeds
        its time limit, or if the agent/trainer code itself raises it.
    """
    # Initialize environment.
    drone_env = DroneEnv(params)

    # Create agent; only the constructor call is timed.
    n = len(params['map'])
    m = len(params['map'][0])
    start = perf_counter()
    drone_agent = DroneAgent(n, m)
    _check_time_limit("agent constructor", perf_counter() - start, AGENT_INIT_TIME_LIMIT)

    # Run trainer; the budget scales with the number of episodes.
    trainer = DroneTrainer(drone_agent, drone_env)
    start = perf_counter()
    average_score_train = trainer.run(nr_episodes=NR_TRAIN_EPISODES, train=True)
    elapsed = perf_counter() - start
    _check_time_limit("train", elapsed, EPISODE_TIME_LIMIT * NR_TRAIN_EPISODES)
    logging.info(f'train score: {average_score_train}, time: {round(elapsed, 2)}')

    # Evaluate agent with the same trainer, in non-training mode.
    start = perf_counter()
    average_score_test = trainer.run(nr_episodes=NR_TEST_EPISODES, train=False)
    elapsed = perf_counter() - start
    _check_time_limit("test", elapsed, EPISODE_TIME_LIMIT * NR_TEST_EPISODES)
    logging.info(f'test score: {average_score_test}, time: {round(elapsed, 2)}')
    return average_score_test


def main():
    """Evaluate every entry of ``inputs_list`` and log the collected test scores.

    An input whose evaluation raises ``TimeoutError`` is scored
    ``TIMEOUT_PENALTY``. Any other exception propagates and aborts the
    remaining inputs.
    """
    logging.info(f"IDS:: {ids}")
    test_scores = []
    for idx, params in enumerate(inputs_list):
        logging.info(f'input_id:: {idx}')
        try:
            score = _evaluate_input(params)
        except TimeoutError:
            # Also covers a TimeoutError raised by agent code, which
            # _check_time_limit would not have logged.
            logging.warning(f'input_id:: {idx} penalized with score {TIMEOUT_PENALTY}')
            score = TIMEOUT_PENALTY
        test_scores.append(score)
    logging.info("Done!")
    logging.info(f"scores: {test_scores}")


if __name__ == '__main__':
    main()