dqnn_test.py
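
"""Evaluate a trained DQN agent playing X against randomly moving detectives.

Loads a saved checkpoint from Model/, plays `playouts` complete games, and
saves plots of the running win rate and the running average rewards for X
and the detectives under Result/.

Usage:
    python dqnn_test.py <model_name>
"""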
import sys
from copy import deepcopy

import matplotlib.pyplot as plt

from utilities.DQN_Agent import DQN_Agent
from game import game


def mean(x):
    """Return the arithmetic mean of a sequence."""
    return sum(x) / len(x)

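
# Evaluation settings: number of test games to play and the checkpoint
# name taken from the command line.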
playouts = 500
model_name = sys.argv[1]

surv = []                  # 1 if X survived an episode, else 0
X_reward = []              # X's final reward per episode
det_reward = []            # detectives' reward per episode (G.D_reward split over 4)
avg_x, avg_det = [], []    # running means of the two reward lists
win_rate = []              # running fraction of episodes survived

episode = 1
X_agent = DQN_Agent()
X_agent.load_model("Model/" + model_name)
total_steps = 0
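
# Roll out complete games: X plays greedily from the trained network while
# every detective moves at random; each transition is also pushed into the
# agent's memory via add_to_memory.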
while episode <= playouts:
    G = game()
    step_rew, rew = 0, 0
    while not G.finish():
        print("Move No ", G.move, "\n")
        # greedy action for X from the trained network
        act, _ = X_agent.best_action(G)
        state_action = G.f_x_action(act)
        # when act[0] == 4 the action carries two mode entries and the
        # target at index 2; otherwise a single mode and the target at index 1
        if act[0] == 4:
            mode = [act[0], act[1]]
            target = [act[2]]
        else:
            mode = [act[0]]
            target = [act[1]]
        print(act)
        G.take_action(target, "x", mode, 0)
        total_steps += 1
        print("\n")
        # each detective then takes a random move
        for i in range(G.no_of_players):
            G.take_action(None, "detective", [], i, "random")
        G.update_fv()
        # per-step reward is the change in X's cumulative reward
        step_rew = G.X_reward - rew
        rew = G.X_reward
        next_state = deepcopy(G)
        X_agent.add_to_memory(state_action, next_state, step_rew)
        print(rew, step_rew)
    episode += 1
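    # an episode counts as a win for X if it survived at least 20 moves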
    if G.move >= 20:
        surv.append(1)
    else:
        surv.append(0)
    win_rate.append(sum(surv) / len(surv))
    X_reward.append(G.X_reward)
    det_reward.append(G.D_reward / 4)
    avg_x.append(mean(X_reward))
    avg_det.append(mean(det_reward))
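
# Plot the running win rate and running average rewards, then save both
# figures under Result/.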
plot1 = plt.figure(1)
plt.title("Win rate vs Episodes")
plt.xlabel("Episode")
plt.ylabel("Win rate")
plt.plot(win_rate)
plt.savefig("Result/win_rate_" + model_name + ".png")

plot2 = plt.figure(2)
plt.title("X Reward vs Episodes")
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.plot(avg_x)
plt.plot(avg_det)
plt.legend(["X", "Detectives"])
plt.savefig("Result/reward_" + model_name + ".png")