-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
76 lines (64 loc) · 2.33 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import tensorflow as tf
import copy
import numpy as np
def logit(x):
    """Return the logit (inverse logistic sigmoid) of ``x``: log(x / (1 - x)).

    Evaluated as -log(1/x - 1); finite only for ``x`` strictly inside (0, 1).
    ``tf.log`` is the TensorFlow 1.x spelling of the natural logarithm.
    """
    inverse_odds = 1. / x - 1.  # equals (1 - x) / x
    return -tf.log(inverse_odds)
def clip_b4_exp(x):
    """Clamp ``x`` element-wise to the closed range [-10, 10].

    As the name says, this is meant to be applied before ``exp`` so that the
    exponentiated value stays within [exp(-10), exp(10)].
    """
    return tf.clip_by_value(x, clip_value_min=-10, clip_value_max=10)
def gaussian_likelihood(x, mu, log_std):
    """Element-wise Gaussian log-density of ``x`` under N(mu, exp(log_std)**2).

    Despite the name, this is the *log* likelihood, and nothing is summed:
    the result has the broadcast shape of ``x``, ``mu`` and ``log_std``.
    A 1e-8 is added to the standard deviation to guard against division by 0.
    """
    std = tf.exp(log_std) + 1e-8
    z = (x - mu) / std
    return -0.5 * (z ** 2 + 2 * log_std + np.log(2 * np.pi))
def get_gaes(rs, ds, vs, next_vs, gamma, _lambda, normalize):
    """Compute Generalized Advantage Estimates (GAE) and value targets.

    Args:
        rs: per-step rewards r_t.
        ds: per-step done flags d_t; a truthy d_t cuts both the bootstrap term
            of delta_t and the backward recursion at step t.
        vs: value estimates V(s_t).
        next_vs: value estimates V(s_{t+1}).
        gamma: discount factor.
        _lambda: GAE lambda.
        normalize: if true, the returned advantages are standardized to zero
            mean / unit std; ``target`` always uses the raw advantages.

    Returns:
        ``(gaes, target)`` as numpy arrays stacked along axis 0 (time), with
        ``target = gaes_raw + vs``.

    Raises:
        ValueError: from ``np.stack`` when the inputs are empty.
    """
    gaes = np.stack([r + gamma * (1 - d) * nv - v
                     for r, d, nv, v in zip(rs, ds, next_vs, vs)])
    # Integer inputs would give an integer array, and the in-place recursion
    # below would silently truncate the fractional gamma*lambda terms.
    if not np.issubdtype(gaes.dtype, np.inexact):
        gaes = gaes.astype(np.float64)
    # np.stack already returned a fresh array, so it is safe to update in place.
    for t in reversed(range(len(gaes) - 1)):
        gaes[t] = gaes[t] + (1 - ds[t]) * gamma * _lambda * gaes[t + 1]
    target = gaes + vs  # new array: unaffected by the normalization below
    if normalize:
        gaes = (gaes - gaes.mean()) / (gaes.std() + 1e-8)
    return gaes, target
def vec_data_augmentation(s, a, next_s):
    """Expand one vector-observation transition into 8 augmented copies.

    The outer loop (i = 0..3) applies a mirror cumulatively on top of the
    previous step:
    - i == 0: no mirror.
    - Odd i: negate the odd-indexed entries s[1], s[3], s[5], s[7] and map
      a[0] -> -a[0] (mod 1).
    - Even i > 0: negate the even-indexed entries s[0], s[2], s[4], s[6] and
      map a[0] -> 1/2 - a[0] (mod 1).

    So the four mirror states are: identity, odd-negated, all-negated and
    even-negated. For each of them, two samples are emitted: the current
    entry order, and the order with s[4:6] and s[6:8] exchanged. The swap
    persists into the next outer step.

    ``next_s`` receives exactly the same negations and swaps as ``s``.

    NOTE(review): this presumably treats s[0:8] as four (x, y) pairs and a[0]
    as a heading given as a fraction of a full turn in [0, 1]. Confirm
    against the caller.

    Args:
        s, next_s: arrays with at least 8 entries along axis 0.
        a: array; only a[0] is modified.

    Returns:
        Three lists of 8 independent arrays: states, actions, next states.
        The caller's arrays are never modified.
    """
    s, a, next_s = map(np.copy, [s, a, next_s])  # private working copies
    s_augs, a_augs, next_s_augs = [], [], []
    for i in range(4):
        if i > 0:
            # Split a[0] into quadrant index `quad` (1..4) and offset within it,
            # negate the offset, add the quadrant base back, and add 1/2 when
            # the parity of i and quad differ. For a[0] in [0, 1] this yields
            # -a[0] (odd i) or 1/2 - a[0] (even i), wrapped into [0, 1].
            quad = 1
            while a[0] >= 1/4 and quad < 4:
                a[0] -= 1/4
                quad += 1
            a[0] *= -1
            if i % 2 != quad % 2:
                a[0] += 1/2
            a[0] += (quad - 1) / 4
            a[0] = a[0] - (a[0] > 1) + (a[0] < 0)
            # Mirror the state and the next state identically; otherwise the
            # augmented transition would be inconsistent.
            for arr in (s, next_s):
                arr[i % 2:8:2] *= -1
        for j in range(2):
            if j > 0:
                # Exchange entries 4:6 with 6:8. Fancy indexing copies the
                # right-hand side first; a tuple swap of two slice views would
                # overwrite one pair with the other.
                for arr in (s, next_s):
                    arr[4:8] = arr[[6, 7, 4, 5]]
            # Store snapshots: the working arrays keep being mutated in place.
            s_augs.append(s.copy())
            a_augs.append(a.copy())
            next_s_augs.append(next_s.copy())
    return s_augs, a_augs, next_s_augs
def vis_data_augmentation(s, a, next_s):
    """Expand one image-observation transition into 4 mirror-augmented copies.

    Step i == 0 is the unmodified transition. Each later step mirrors the
    previous one again: np.flipud for odd i and np.fliplr for even i. The four
    outputs are therefore: original, flipud, flipud+fliplr, fliplr.

    The action a[0] is mirrored alongside. For a[0] in [0, 1]:
    - Odd steps map a[0] -> -a[0] (mod 1).
    - Even steps map a[0] -> 1/2 - a[0] (mod 1).
    In both cases the result is wrapped back into [0, 1].

    NOTE(review): a[0] appears to be a heading expressed as a fraction of a
    full turn. Confirm against the caller.

    Args:
        s, next_s: arrays with at least 2 dims (required by np.fliplr),
            presumably images. Both are flipped identically.
        a: array; only a[0] is modified.

    Returns:
        Three lists of 4 independent arrays: states, actions, next states.
        The caller's arrays are never modified.
    """
    s, a, next_s = map(np.copy, [s, a, next_s])  # private working copies
    s_augs, a_augs, next_s_augs = [], [], []
    for i in range(4):
        if i > 0:
            # Quadrant decomposition of a[0]: see the docstring for the net
            # effect (mirror about the matching axis, wrapped into [0, 1]).
            quad = 1
            while a[0] >= 1/4 and quad < 4:
                a[0] -= 1/4
                quad += 1
            a[0] *= -1
            if i % 2 != quad % 2:
                a[0] += 1/2
            a[0] += (quad - 1) / 4
            a[0] = a[0] - (a[0] > 1) + (a[0] < 0)
            flip = np.flipud if i % 2 else np.fliplr
            s = flip(s)
            next_s = flip(next_s)
        # Append snapshots: `a` is updated in place on the next iteration, so
        # appending the working array itself would overwrite earlier entries.
        s_augs.append(np.copy(s))
        a_augs.append(np.copy(a))
        next_s_augs.append(np.copy(next_s))
    return s_augs, a_augs, next_s_augs