Initial commit
bjliaa committed Apr 7, 2020
1 parent 52e11c6 commit b827e9d
Showing 44 changed files with 12,695 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -1,2 +1,4 @@
# ecc
Distributional Reinforcement Learning with Ensembles
## 40M Evaluation of 4 Atari 2600 Games
![atari](images/atari.png)
Empty file added accord/__init__.py
Empty file.
Empty file added accord/agents/__init__.py
Empty file.
166 changes: 166 additions & 0 deletions accord/agents/ampdist.py
@@ -0,0 +1,166 @@
import tensorflow as tf
from accord.agents.models import Probnet, Ampnet


class AmpDistAgent(tf.Module):
def __init__(self,
action_len,
vsize=5,
dense=512,
supportsize=51,
vmin=-10.0,
vmax=10.0,
starteps=1.0,
lr=1e-4,
adameps=1.5e-4,
name="distagent"):
super(AmpDistAgent, self).__init__(name=name)

self.action_len = action_len
self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr,
epsilon=adameps)
        # Per-sample KL divergences (no reduction), returned by train().
        self.losses = tf.keras.losses.KLDivergence(
            reduction=tf.keras.losses.Reduction.NONE)
        # Mean KL divergence used as the training loss.
        self.kldloss = tf.keras.losses.KLDivergence()

with tf.name_scope("probnet"):
self.probnet = Probnet(action_len=action_len,
dense=dense,
supportsize=supportsize)

with tf.name_scope("tnet"):
self.ampnet = Ampnet(action_len=action_len,
vsize=vsize,
dense=dense)

with tf.name_scope("selfnet"):
self.selfnet = Probnet(action_len=action_len,
dense=dense,
supportsize=supportsize)

self.supp = tf.constant(tf.linspace(vmin, vmax, supportsize),
shape=(supportsize, 1))
self.dz = tf.constant((vmax - vmin) / (supportsize - 1))
self.vmin = tf.constant(vmin)
self.vmax = tf.constant(vmax)
self.supportsize = tf.constant(supportsize, dtype=tf.int32)
self.eps = tf.Variable(starteps, trainable=False, name="epsilon")

@tf.function
def eps_greedy_action(self, state, epsval):
self.eps.assign(epsval)
dice = (tf.random.uniform([1], minval=0, maxval=1, dtype=tf.float32) <
self.eps)
raction = tf.random.uniform([1],
minval=0,
maxval=self.action_len,
dtype=tf.int64)
qaction = tf.argmax(self.qvalues(state))
return tf.where(dice, raction, qaction)

@tf.function
def amp_action(self, state, epsval):
self.eps.assign(epsval)
dice = (tf.random.uniform([1], minval=0, maxval=1, dtype=tf.float32) <
self.eps)
raction = tf.random.uniform([1],
minval=0,
maxval=self.action_len,
dtype=tf.int64)
qaction = tf.argmax(self.t_qvalues(state))
return tf.where(dice, raction, qaction)

@tf.function
def probvalues(self, states):
return tf.squeeze(self.probnet(states))

@tf.function
def qvalues(self, states):
ds = self.probnet(states)
return tf.squeeze(tf.matmul(ds, self.supp))

@tf.function
def t_probvalues(self, states):
return self.ampnet(states)

@tf.function
def t_qvalues(self, states):
ds = self.ampnet(states)
return tf.squeeze(tf.matmul(ds, self.supp))

@tf.function
def s_probvalues(self, states):
return self.selfnet(states)

@tf.function
def s_qvalues(self, states):
ds = self.selfnet(states)
return tf.squeeze(tf.matmul(ds, self.supp))

# @tf.function
def update_target(self, wlst):
self.ampnet.update(wlst)

@tf.function
def update_self(self):
q_vars = self.probnet.trainable_variables
t_vars = self.selfnet.trainable_variables
for var_q, var_t in zip(q_vars, t_vars):
var_t.assign(var_q)

@tf.function
def train(self, states, actions, drews, gexps, endstates, dones):
with tf.GradientTape() as tape:
batch_size = tf.shape(states)[0]
brange = tf.range(0, batch_size)
indices = tf.stack([brange, actions], axis=1)
chosen_dists = tf.gather_nd(self.probvalues(states), indices)

end_actions = tf.cast(tf.argmax(self.t_qvalues(endstates), axis=1),
dtype=tf.int32)

indices = tf.stack([brange, end_actions], axis=1)
chosen_end_dists = tf.gather_nd(self.t_probvalues(endstates),
indices)

            # Zero the bootstrap term for terminal transitions and apply the
            # per-sample (n-step) discount factors.
            dmask = (1.0 - dones) * gexps
            # Bellman-updated support, clipped to [vmin, vmax].
            Tzs = tf.clip_by_value(drews + dmask * self.supp, self.vmin,
                                   self.vmax)
            Tzs = tf.transpose(Tzs)
            # Fractional atom positions of the updated support.
            bs = (Tzs - self.vmin) / self.dz

            # Categorical (C51-style) projection: spread each target atom's
            # probability over its lower and upper neighbouring atoms.
            ls = tf.cast(tf.floor(bs), tf.int32)
            us = tf.cast(tf.math.ceil(bs), tf.int32)
            # When bs lands exactly on an atom (ls == us), shift one index so
            # the probability mass is not dropped.
            condl = tf.cast(
                tf.cast((us > 0), tf.float32) * tf.cast(
                    (us == ls), tf.float32), tf.bool)
            ls = tf.where(condl, ls - 1, ls)
            # Re-test the equality against the shifted ls so each atom's mass
            # is moved at most once.
            condu = tf.cast(
                tf.cast((ls < self.supportsize - 1), tf.float32) * tf.cast(
                    (us == ls), tf.float32), tf.bool)
            us = tf.where(condu, us + 1, us)

            luprob = (tf.cast(us, tf.float32) - bs) * chosen_end_dists
            lshot = tf.one_hot(ls, self.supportsize)
            ml = tf.einsum('aj,ajk->ak', luprob, lshot)
            ulprob = (bs - tf.cast(ls, tf.float32)) * chosen_end_dists
            ushot = tf.one_hot(us, self.supportsize)
            mu = tf.einsum('aj,ajk->ak', ulprob, ushot)

            # Projected target distribution and per-sample KL divergences.
            target = ml + mu
            losses = self.losses(target, chosen_dists)

            # Mean Kullback–Leibler divergence; the target is excluded from
            # the gradient.
            loss = self.kldloss(tf.stop_gradient(target), chosen_dists)

gradients = tape.gradient(loss, self.probnet.trainable_variables)
gradients, _ = tf.clip_by_global_norm(gradients, 10.0)
self.optimizer.apply_gradients(
zip(gradients, self.probnet.trainable_variables))
return losses

def save(self, filestr):
self.probnet.save_weights(filestr)

def load(self, filestr):
self.probnet.load_weights(filestr)
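For orientation, here is a small, standalone sketch of the categorical projection that `train` builds with `Tzs`, `bs`, `ls`, and `us`. The numbers are made up and plain NumPy stands in for the agent's TensorFlow graph; the `l == u` case is handled with an explicit branch rather than the index shift used above, but the projected mass ends up the same.

```python
# Toy, standalone illustration of the categorical (C51-style) projection
# used in train(); all numbers are illustrative and no Probnet is involved.
import numpy as np

vmin, vmax, supportsize = -10.0, 10.0, 51
supp = np.linspace(vmin, vmax, supportsize)         # atoms z_j
dz = (vmax - vmin) / (supportsize - 1)              # atom spacing (0.4)

reward, gamma, done = 1.0, 0.99, 0.0                # gamma stands in for gexps
end_dist = np.full(supportsize, 1.0 / supportsize)  # next-state distribution

# Bellman update of every atom, clipped to the support.
Tz = np.clip(reward + (1.0 - done) * gamma * supp, vmin, vmax)
b = (Tz - vmin) / dz                                # fractional atom positions
l, u = np.floor(b).astype(int), np.ceil(b).astype(int)

target = np.zeros(supportsize)
for j in range(supportsize):
    if l[j] == u[j]:                       # b[j] landed exactly on an atom
        target[l[j]] += end_dist[j]
    else:                                  # split mass between neighbours
        target[l[j]] += (u[j] - b[j]) * end_dist[j]
        target[u[j]] += (b[j] - l[j]) * end_dist[j]

assert abs(target.sum() - 1.0) < 1e-6      # projection preserves total mass
```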
130 changes: 130 additions & 0 deletions accord/agents/distributional.py
@@ -0,0 +1,130 @@
import tensorflow as tf
from accord.agents.models import Probnet


class DistAgent(tf.Module):
def __init__(self,
action_len,
dense=512,
supportsize=51,
vmin=-10.0,
vmax=10.0,
starteps=1.0,
lr=1e-4,
adameps=1.5e-4,
name="distagent"):
super(DistAgent, self).__init__(name=name)

self.action_len = action_len
self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr,
epsilon=adameps)
        # Per-sample KL divergences (no reduction), returned by train().
        self.losses = tf.keras.losses.KLDivergence(
            reduction=tf.keras.losses.Reduction.NONE)
        # Mean KL divergence used as the training loss.
        self.kldloss = tf.keras.losses.KLDivergence()

with tf.name_scope("probnet"):
self.probnet = Probnet(action_len, dense, supportsize)
with tf.name_scope("target_probnet"):
self.targetnet = Probnet(action_len, dense, supportsize)

self.supp = tf.constant(tf.linspace(vmin, vmax, supportsize),
shape=(supportsize, 1))
self.dz = tf.constant((vmax - vmin) / (supportsize - 1))
self.vmin = tf.constant(vmin)
self.vmax = tf.constant(vmax)
self.supportsize = tf.constant(supportsize, dtype=tf.int32)
self.eps = tf.Variable(starteps, trainable=False, name="epsilon")

@tf.function
def eps_greedy_action(self, state, epsval):
self.eps.assign(epsval)
dice = (tf.random.uniform([1], minval=0, maxval=1, dtype=tf.float32) <
self.eps)
raction = tf.random.uniform([1],
minval=0,
maxval=self.action_len,
dtype=tf.int64)
qaction = tf.argmax(self.qvalues(state))
return tf.where(dice, raction, qaction)

@tf.function
def probvalues(self, states):
return tf.squeeze(self.probnet(states))

@tf.function
def qvalues(self, states):
ds = self.probnet(states)
return tf.squeeze(tf.matmul(ds, self.supp))

@tf.function
def t_probvalues(self, states):
return tf.squeeze(self.targetnet(states))

@tf.function
def t_qvalues(self, states):
ds = self.targetnet(states)
return tf.squeeze(tf.matmul(ds, self.supp))

@tf.function
def update_target(self):
q_vars = self.probnet.trainable_variables
t_vars = self.targetnet.trainable_variables
for var_q, var_t in zip(q_vars, t_vars):
var_t.assign(var_q)

@tf.function
def train(self, states, actions, drews, gexps, endstates, dones):
with tf.GradientTape() as tape:
batch_size = tf.shape(states)[0]
brange = tf.range(0, batch_size)
indices = tf.stack([brange, actions], axis=1)
chosen_dists = tf.gather_nd(self.probvalues(states), indices)

end_actions = tf.cast(tf.argmax(self.t_qvalues(endstates), axis=1),
dtype=tf.int32)

indices = tf.stack([brange, end_actions], axis=1)
chosen_end_dists = tf.gather_nd(self.t_probvalues(endstates),
indices)

            # Zero the bootstrap term for terminal transitions and apply the
            # per-sample (n-step) discount factors.
            dmask = (1.0 - dones) * gexps
            # Bellman-updated support, clipped to [vmin, vmax].
            Tzs = tf.clip_by_value(drews + dmask * self.supp, self.vmin,
                                   self.vmax)
            Tzs = tf.transpose(Tzs)
            # Fractional atom positions of the updated support.
            bs = (Tzs - self.vmin) / self.dz

            # Categorical (C51-style) projection: spread each target atom's
            # probability over its lower and upper neighbouring atoms.
            ls = tf.cast(tf.floor(bs), tf.int32)
            us = tf.cast(tf.math.ceil(bs), tf.int32)
            # When bs lands exactly on an atom (ls == us), shift one index so
            # the probability mass is not dropped.
            condl = tf.cast(
                tf.cast((us > 0), tf.float32) * tf.cast(
                    (us == ls), tf.float32), tf.bool)
            ls = tf.where(condl, ls - 1, ls)
            # Re-test the equality against the shifted ls so each atom's mass
            # is moved at most once.
            condu = tf.cast(
                tf.cast((ls < self.supportsize - 1), tf.float32) * tf.cast(
                    (us == ls), tf.float32), tf.bool)
            us = tf.where(condu, us + 1, us)

            luprob = (tf.cast(us, tf.float32) - bs) * chosen_end_dists
            lshot = tf.one_hot(ls, self.supportsize)
            ml = tf.einsum('aj,ajk->ak', luprob, lshot)
            ulprob = (bs - tf.cast(ls, tf.float32)) * chosen_end_dists
            ushot = tf.one_hot(us, self.supportsize)
            mu = tf.einsum('aj,ajk->ak', ulprob, ushot)

            # Projected target distribution and per-sample KL divergences.
            target = ml + mu
            losses = self.losses(target, chosen_dists)

            # Mean Kullback–Leibler divergence; the target is excluded from
            # the gradient.
            loss = self.kldloss(tf.stop_gradient(target), chosen_dists)

gradients = tape.gradient(loss, self.probnet.trainable_variables)
gradients, _ = tf.clip_by_global_norm(gradients, 10.0)
self.optimizer.apply_gradients(
zip(gradients, self.probnet.trainable_variables))
return losses

def save(self, filestr):
self.probnet.save_weights(filestr)

def load(self, filestr):
self.probnet.load_weights(filestr)
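As a rough, hypothetical usage sketch (not part of this commit), `DistAgent` might be driven as follows. The input shape accepted by `Probnet`, the epsilon value, and the batch layout for `train` are assumptions made purely for illustration; the actual training loop and replay buffer are defined elsewhere in this commit.

```python
# Hypothetical usage sketch; shapes and hyperparameters are illustrative only.
import tensorflow as tf
from accord.agents.distributional import DistAgent

agent = DistAgent(action_len=4)                     # e.g. a 4-action Atari game
state = tf.zeros((1, 84, 84, 4), dtype=tf.float32)  # assumed stacked-frame input

# Act epsilon-greedily with the online network.
action = agent.eps_greedy_action(state, epsval=0.05)

# Periodically copy the online weights into the target network.
agent.update_target()

# A training step would take a replay batch of size B:
#   states, endstates : observations, shape [B, ...]
#   actions           : int32 action indices, shape [B]
#   drews             : n-step discounted rewards, shape [B]
#   gexps             : per-sample discount (e.g. gamma**n), shape [B]
#   dones             : 0/1 terminal flags, shape [B]
# per_sample_kl = agent.train(states, actions, drews, gexps, endstates, dones)
```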