train.py
import os
import time
from environment import create_env
from callback import Callbacks
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
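# environment.create_env and callback.Callbacks are project-local modules
# (environment.py / callback.py in this repo). Callbacks is assumed here to be
# a stable_baselines3 BaseCallback subclass that saves a checkpoint named
# f'model_{timesteps}' into save_path every check_freq steps; that naming is
# what the later load of 'model_100000' relies on.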
# ------------------------------------------------------------------------------------
CHECKPOINT_DIR = './train/'
OPT_DIR = './opt/'
LOG_DIR = './logs/'
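# Safeguard sketch: create the working directories up front so model loading
# and checkpoint saving don't fail on a fresh clone (this assumes the
# project's Callbacks does not create save_path itself).
for _d in (CHECKPOINT_DIR, OPT_DIR, LOG_DIR):
    os.makedirs(_d, exist_ok=True)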
# ------------------------------------------------------------------------------------
def main():
    print('... Training ...\n')
    env = create_env(LOG_DIR=LOG_DIR)
    # Resume from the best model found during hyperparameter optimization
    model = PPO.load(os.path.join(OPT_DIR, 'best_model'))
    model.set_env(env)
    # Checkpoint callback: saves the model every check_freq steps
    callback = Callbacks(check_freq=100000, save_path=CHECKPOINT_DIR)
    model.learn(total_timesteps=5500000, callback=callback)
    print('\nDone!\n\nEvaluating the model...\n')
    model = PPO.load(os.path.join(CHECKPOINT_DIR, 'model_100000'))
    env = create_env(LOG_DIR=LOG_DIR)
    mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=1, render=True)
    print(f'\n\n\nThe mean_reward is: {mean_reward}\n\n\n')
    env.close()
    print('\nDone!\n\nTesting the model...\n')
    env = create_env(LOG_DIR=LOG_DIR)
    obs = env.reset()
    done = False
    total_reward = 0
    info = {'matches_won': 0}
    while True:
        # Stop once the agent has won two matches
        if info.get('matches_won') == 2:
            break
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(action)
        info = dict(info[0])  # unwrap the info list from the vectorized env
        print(info)
        env.render()
        time.sleep(0.01)
        total_reward += reward
        print(reward)
        if done:
            obs = env.reset()
    env.close()
# ------------------------------------------------------------------------------------
if __name__ == '__main__':
    main()
# ------------------------------------------------------------------------------------
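# Usage sketch (assumed layout): place environment.py and callback.py next to
# this file, and run the optimization step first so that ./opt/best_model
# exists, then:
#   python train.py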