# main_dqn_basic.py
import torch
from torch import optim
from matplotlib import pyplot as plt

# Wildcard imports provide VizDoomEnv, DoomDQN and DoomAgentDQN.
from env.doom_enviroment import *
from networks.doom_network_dqn import *
from agents.doom_agent_dqn import *

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Run mode: 'training' trains a new policy, anything else evaluates a saved one.
case = 'testing'
scenario = 'basic'

# Hyperparameters passed to the agent.
episodes = 200        # number of training episodes
update = 5            # generic update interval (not used in this script)
starting_eps = 0.9    # initial epsilon for epsilon-greedy exploration
ending_eps = 0.05     # final epsilon after decay
total_steps = 7000    # horizon for the epsilon decay schedule
batch_size = 256      # minibatch size sampled from the replay buffer
capacity = 10000      # replay buffer capacity
gamma = 0.95          # discount factor
lr = 1e-3             # learning rate
frames = 4            # frames stacked per state
update_target = 5     # how often the target network is synced with the policy
update_policy = 1     # how often the policy network is optimized
update_eval = 10      # how often the agent is evaluated (also the plot's x-axis step)
update_save = 50      # how often a checkpoint is saved
if case == 'training':
    # Build the environment and the policy/target Q-networks.
    doom_env = VizDoomEnv(render=True)
    policy = DoomDQN(doom_env.get_num_actions(), device).to(device)
    target = DoomDQN(doom_env.get_num_actions(), device).to(device)
    target.load_state_dict(policy.state_dict())
    optimizer = optim.RMSprop(policy.parameters(), lr=lr)

    doom_agent = DoomAgentDQN(starting_eps=starting_eps, ending_eps=ending_eps, env=doom_env,
                              policy=policy, target=target, total_steps=total_steps,
                              batch_size=batch_size, capacity=capacity, device=device, gamma=gamma,
                              optimizer=optimizer, update_target=update_target, episodes=episodes,
                              frames=frames, update_policy=update_policy, update_eval=update_eval,
                              update_save=update_save)

    # Train and save the final policy weights.
    rewards = doom_agent.learn()
    torch.save(doom_agent.policy.state_dict(),
               "./checkpoint/policy_dqn_{}_{}_{}_{}_ft.pth".format(lr, episodes, batch_size, scenario))

    # Plot the evaluation reward collected every `update_eval` episodes.
    plt.plot(range(update_eval, episodes + 1, update_eval), rewards)
    plt.title("Reward")
    plt.xlabel("Episodes")
    plt.ylabel("Reward")
    plt.savefig("Reward_basic.png")
    plt.show()
else:
    # Evaluation: load the trained weights and run the agent without further training.
    doom_env = VizDoomEnv(render=True)
    policy = DoomDQN(doom_env.get_num_actions(), device).to(device)
    # map_location keeps loading working on CPU-only machines.
    policy.load_state_dict(torch.load("./final_models/basic.pth", map_location=device))
    target = DoomDQN(doom_env.get_num_actions(), device).to(device)
    target.load_state_dict(policy.state_dict())
    optimizer = optim.RMSprop(policy.parameters(), lr=lr)

    doom_agent = DoomAgentDQN(starting_eps=starting_eps, ending_eps=ending_eps, env=doom_env,
                              policy=policy, target=target, total_steps=total_steps,
                              batch_size=batch_size, capacity=capacity, device=device, gamma=gamma,
                              optimizer=optimizer, update_target=update_target, episodes=episodes,
                              frames=frames, update_policy=update_policy, update_eval=update_eval,
                              update_save=update_save)

    # Report the evaluation reward.
    reward = doom_agent.eval()
    print(reward)