
Hello, could you help me take a look at this Nature-DQN code? It runs very slowly #11

Open
miss-fang opened this issue Oct 21, 2019 · 0 comments


miss-fang commented Oct 21, 2019

```python
import tensorflow as tf
import numpy as np
import gym
import random
from collections import deque
from keras.utils.np_utils import to_categorical

import tensorflow.keras.backend as K


class QNetwork(tf.keras.Model):

    def __init__(self):
        super().__init__()
        # dense1/dense2 form the online Q-network, dense3/dense4 the target network
        self.dense1 = tf.keras.layers.Dense(24, activation='relu')
        self.dense2 = tf.keras.layers.Dense(2)
        self.dense3 = tf.keras.layers.Dense(24, activation='relu')
        self.dense4 = tf.keras.layers.Dense(2)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return x

    def tarNet_Q(self, inputs):
        x = self.dense3(inputs)
        x = self.dense4(x)
        return x

    def get_action(self, inputs):
        q_values = self(inputs)
        return K.eval(tf.argmax(q_values, axis=-1))[0]


env = gym.make('CartPole-v0')

num_episodes = 300
num_exploration = 200
max_len = 400
batch_size = 32
lr = 1e-3
gamma = 0.9
initial_epsilon = 0.5
final_epsilon = 0.01
replay_buffer = deque(maxlen=10000)
tarNet_update_frequence = 10
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
qNet = QNetwork()

for i in range(1, num_episodes + 1):
    state = env.reset()
    epsilon = max(initial_epsilon * (num_exploration - i) / num_exploration, final_epsilon)
    for t in range(max_len):  # cap the episode length
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = qNet.get_action(tf.constant(np.expand_dims(state, axis=0), dtype=tf.float32))
        next_state, reward, done, info = env.step(action)
        reward = -1. if done else reward
        replay_buffer.append((state, action, reward, next_state, done))
        state = next_state
        if done:
            print('episode %d, epsilon %f, score %d' % (i, epsilon, t))
            break
        if len(replay_buffer) >= batch_size:
            batch_state, batch_action, batch_reward, batch_next_state, batch_done = \
                [np.array(a, dtype=np.float32) for a in zip(*random.sample(replay_buffer, batch_size))]
            # bootstrap target from the target network
            q_value = qNet.tarNet_Q(tf.constant(batch_next_state, dtype=tf.float32))
            y = batch_reward + (gamma * tf.reduce_max(q_value, axis=1)) * (1 - batch_done)
            with tf.GradientTape() as tape:
                loss = tf.losses.mean_squared_error(
                    y,
                    tf.reduce_max(qNet(tf.constant(batch_state)) * to_categorical(batch_action, num_classes=2), axis=1))
            grads = tape.gradient(loss, qNet.variables[:4])
            optimizer.apply_gradients(grads_and_vars=zip(grads, qNet.variables[:4]))
    if i % tarNet_update_frequence == 0:
        # copy the online-network weights into the target network
        for j in range(2):
            tf.assign(qNet.variables[4 + j], qNet.dense1.get_weights()[j])
            tf.assign(qNet.variables[6 + j], qNet.dense2.get_weights()[j])
env.close()
```
I think it runs slowly because the way I copy the network parameters is wrong. Could anyone who sees this give me a suggestion?
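
A minimal sketch of an alternative parameter copy (not the original code), assuming the `QNetwork` above where `dense3`/`dense4` mirror `dense1`/`dense2`: if this runs in graph mode, every `tf.assign(...)` call inside the training loop adds new ops to the graph, so the program can get slower over time; copying the weights as NumPy arrays with Keras `set_weights()` avoids building any new ops.

```python
# Hypothetical helper, not from the issue: sync the target layers via
# set_weights()/get_weights() instead of tf.assign, so no new graph ops
# are created on each update.
def update_target(qnet):
    qnet.dense3.set_weights(qnet.dense1.get_weights())  # target hidden layer <- online hidden layer
    qnet.dense4.set_weights(qnet.dense2.get_weights())  # target output layer <- online output layer

# usage inside the episode loop:
# if i % tarNet_update_frequence == 0:
#     update_target(qNet)
```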
