-
Notifications
You must be signed in to change notification settings - Fork 1
/
dqnx_vs_dqndet.py
93 lines (88 loc) · 2.3 KB
/
dqnx_vs_dqndet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
from copy import deepcopy
import sys
import os
from utilities.DQN_Det import DQN_Det
from utilities.DQN_Agent import DQN_Agent
from game import game
import time
import matplotlib.pyplot as plt
# Adversarial self-play training: a DQN agent for "x" (DQN_Agent) and a DQN
# agent for the detectives (DQN_Det) act in the same game, each storing its
# own transitions and replaying every `batch_size` environment steps.
#
# Usage: python dqnx_vs_dqndet.py <playouts> <learning_rate> <model_name>
if len(sys.argv) < 4:
    sys.exit("usage: {} <playouts> <learning_rate> <model_name>".format(sys.argv[0]))

playouts = int(sys.argv[1])      # number of episodes to play
lr = float(sys.argv[2])          # learning rate passed to both agents
model_name = str(sys.argv[3])    # prefix for the saved model files

MODEL_DIR = "Model"
RESULT_DIR = "Result"
SAVE_EVERY = 100        # checkpoint period, in completed episodes
X_WIN_MOVE = 20         # an episode counts as a win for X when G.move reaches this
NUM_DETECTIVES = 4      # detective slots addressed by index 0..3 in take_action
X_TWO_PART_MODE = 4     # act_X[0] value whose action carries two mode entries
                        # NOTE(review): presumably a double move - confirm in game.py

# Create the output directories up front so a long run cannot fail at the
# first checkpoint or at the final savefig.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)


def _save_models():
    """Write both agents' models under Model/ using the CLI model name."""
    D_agent.save_model("Model/"+model_name+"_det")
    X_agent.save_model("Model/"+model_name+"_X")


def _save_plot(fig_num, series, title, ylabel, filename):
    """Plot `series` on figure `fig_num` and save it as Result/<filename>."""
    plt.figure(fig_num)
    plt.plot(series)
    plt.title(title)
    plt.xlabel("Episode")
    plt.ylabel(ylabel)
    # os.path.join keeps the path portable; a literal "Result\\..." only
    # addresses the Result directory on Windows.
    plt.savefig(os.path.join(RESULT_DIR, filename))


surv = []          # per episode: 1 if X won (G.move >= X_WIN_MOVE), else 0
win_rate_X = []    # running mean of `surv` after each episode
D_agent = DQN_Det(lr)
X_agent = DQN_Agent(lr)
total_steps = 0    # game moves played across all episodes

for episode in range(1, playouts + 1):
    G = game()
    # Cumulative rewards seen so far this episode; per-step rewards are
    # obtained as the difference between consecutive cumulative values.
    rew = 0
    rew_X = 0

    while not G.finish():
        # Both agents choose their action from the same pre-move state.
        act = D_agent.train_action(G)
        state_action = G.f_d_action(act)
        act_X = X_agent.train_action(G)
        state_action_X = G.f_x_action(act_X)

        # Split X's action into its mode part and its target part.
        if act_X[0] == X_TWO_PART_MODE:
            mode = [act_X[0], act_X[1]]
            target = [act_X[2]]
        else:
            mode = [act_X[0]]
            target = [act_X[1]]

        # X acts first, then every detective that has an action.
        G.take_action(target, "x", mode, 0)
        for i in range(NUM_DETECTIVES):
            if act[0][i] is not None:
                G.take_action([act[1][i]], "detective", [act[0][i]], i)

        total_steps += 1
        G.update_fv()

        step_rew = G.D_reward - rew
        rew = G.D_reward
        step_rew_X = G.X_reward - rew_X
        rew_X = G.X_reward

        # Snapshot the post-move game; the same snapshot object is stored in
        # both agents' memories.
        next_state = deepcopy(G)
        D_agent.add_to_memory(state_action, next_state, step_rew)
        X_agent.add_to_memory(state_action_X, next_state, step_rew_X)

        if total_steps % D_agent.batch_size == 0:
            print("Replaying ... ")
            D_agent.replay()
        if total_steps % X_agent.batch_size == 0:
            print("Replaying ... ")
            X_agent.replay()

    print(episode)

    surv.append(1 if G.move >= X_WIN_MOVE else 0)
    win_rate_X.append(sum(surv)/len(surv))

    # Checkpoint on the number of *completed* episodes.
    if episode % SAVE_EVERY == 0:
        _save_models()

# Persist the final weights when training did not end exactly on a
# checkpoint; otherwise the last episodes (or a whole short run) are lost.
if playouts > 0 and playouts % SAVE_EVERY != 0:
    _save_models()

_save_plot(1, X_agent.loss, "X Loss vs Episodes",
           "MSE Loss for X in Q value", "loss_X_adv.png")
_save_plot(2, win_rate_X, "Win rate of X vs Episodes",
           "Win rate of X", "win_rate_X_adv.png")
_save_plot(3, D_agent.loss, "Det Loss vs Episodes",
           "MSE Loss for Dets in Q value", "loss_dets_adv.png")

print(len(X_agent.memory),X_agent.loss)