```python
import tensorflow as tf
import numpy as np
import gym
import random
from collections import deque
from keras.utils.np_utils import to_categorical
import tensorflow.keras.backend as K


class QNetwork(tf.keras.Model):
    # NOTE: the class body was missing from the post; everything below is
    # reconstructed from how qNet is used later (dense1/dense2 as the online
    # net, a target copy living in variables[4:8], tarNet_Q, get_action),
    # so treat the layer sizes as a guess.
    def __init__(self):
        super(QNetwork, self).__init__()
        # online network -> variables[0:4] (kernel + bias per layer)
        self.dense1 = tf.keras.layers.Dense(units=24, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=2)
        # target network -> variables[4:8]
        self.tar_dense1 = tf.keras.layers.Dense(units=24, activation=tf.nn.relu)
        self.tar_dense2 = tf.keras.layers.Dense(units=2)

    def call(self, inputs):
        return self.dense2(self.dense1(inputs))

    def tarNet_Q(self, inputs):
        return self.tar_dense2(self.tar_dense1(inputs))

    def get_action(self, inputs):
        return int(tf.argmax(self(inputs), axis=-1).numpy()[0])


env = gym.make('CartPole-v0')
num_episodes = 300
num_exploration = 200
max_len = 400
batch_size = 32
lr = 1e-3
gamma = 0.9
initial_epsilon = 0.5
final_epsilon = 0.01
replay_buffer = deque(maxlen=10000)
tarNet_update_frequence = 10
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
qNet = QNetwork()
for i in range(1, num_episodes + 1):
    state = env.reset()
    # linearly anneal epsilon over the exploration episodes
    epsilon = max(initial_epsilon * (num_exploration - i) / num_exploration,
                  final_epsilon)
    for t in range(max_len):  # cap the episode length / score
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = qNet.get_action(
                tf.constant(np.expand_dims(state, axis=0), dtype=tf.float32))
        next_state, reward, done, info = env.step(action)
        reward = -1. if done else reward
        replay_buffer.append((state, action, reward, next_state, done))
        state = next_state
        if done:
            print('episode %d, epsilon %f, score %d' % (i, epsilon, t))
            break
        if len(replay_buffer) >= batch_size:
            # zip(*...) (the * was missing) transposes the sampled
            # transitions into five per-field batches
            batch_state, batch_action, batch_reward, batch_next_state, batch_done = \
                [np.array(a, dtype=np.float32)
                 for a in zip(*random.sample(replay_buffer, batch_size))]
            q_value = qNet.tarNet_Q(tf.constant(batch_next_state, dtype=tf.float32))
            # TD target: y = r + gamma * max_a' Q_target(s', a'), zeroed at terminals
            y = batch_reward + (gamma * tf.reduce_max(q_value, axis=1)) * (1 - batch_done)
            with tf.GradientTape() as tape:
                # reduce_sum picks the Q-value of the taken action out of the
                # one-hot mask; the original reduce_max would clip negative
                # Q-values to zero
                loss = tf.losses.mean_squared_error(
                    y, tf.reduce_sum(qNet(tf.constant(batch_state))
                                     * to_categorical(batch_action, num_classes=2),
                                     axis=1))
            grads = tape.gradient(loss, qNet.variables[:4])
            optimizer.apply_gradients(grads_and_vars=zip(grads, qNet.variables[:4]))
    if i % tarNet_update_frequence == 0:
        # copy the online weights into the target network (see question below)
        for j in range(2):
            tf.assign(qNet.variables[4 + j], qNet.dense1.get_weights()[j])
            tf.assign(qNet.variables[6 + j], qNet.dense2.get_weights()[j])
env.close()
```
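One bug worth calling out explicitly: the sampling line was originally written without the unpacking star, and `zip(random.sample(...))` does not transpose anything. A tiny standalone illustration of why the `*` matters:

```python
# zip(*batch) transposes a list of transition tuples into per-field tuples
transitions = [('s1', 'a1'), ('s2', 'a2'), ('s3', 'a3')]
states, actions = zip(*transitions)
print(states)   # ('s1', 's2', 's3')
print(actions)  # ('a1', 'a2', 'a3')
# Without the *, zip(transitions) yields one-element tuples
# (('s1', 'a1'),), (('s2', 'a2'),), ... and the five-way unpacking
# in the training loop raises a ValueError.
```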
I think it runs slowly because the way I copy the network parameters is wrong. Could anyone who sees this give me some advice?
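One likely culprit: `get_weights()` copies every kernel and bias back to host NumPy arrays, and `tf.assign` then re-uploads them. A minimal sketch of a device-side sync instead, assuming the reconstructed variable layout above (online weights in `variables[:4]`, target weights in `variables[4:8]`; the helper name `sync_target` is made up here):

```python
def sync_target(qNet):
    """Copy online-net weights into the target net without a NumPy round trip."""
    for online_var, target_var in zip(qNet.variables[:4], qNet.variables[4:8]):
        # Variable.assign runs on-device in eager mode, so no host copy
        target_var.assign(online_var)
```

This would be called every `tarNet_update_frequence` episodes in place of the two `tf.assign` lines. That said, copying four small dense-layer tensors once every 10 episodes is cheap either way, so the slowdown may just as well come from running a `GradientTape` training step eagerly on every environment step.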