-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutil.py
23 lines (21 loc) · 834 Bytes
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
from rl.core import Processor
from rl.util import WhiteningNormalizer
from sklearn.preprocessing import MinMaxScaler, StandardScaler
# Number of trailing feature columns in each state sample that are passed
# through unchanged (not standardized) by Normalizerprocessor.
ADDITIONAL_STATE = 5


class Normalizerprocessor(Processor):
    """Processor that standardizes the observation part of each state sample.

    Every sample in a state batch is split along its last axis into
    observation columns (all but the trailing ``additional_state`` columns)
    and pass-through columns (the trailing ``additional_state`` columns).
    Only the observation columns are standardized with ``StandardScaler``;
    the pass-through columns are re-attached unchanged on the right, so each
    sample keeps its original shape.
    """

    def __init__(self, additional_state=ADDITIONAL_STATE):
        """Create the processor.

        Args:
            additional_state: Number of trailing feature columns per sample
                that must not be standardized. Defaults to
                ``ADDITIONAL_STATE``.

        Raises:
            ValueError: If ``additional_state`` is negative.
        """
        if additional_state < 0:
            raise ValueError(
                "additional_state must be non-negative, got %r" % (additional_state,)
            )
        self.scaler = StandardScaler()
        # Not used by this class; kept so any code reading it keeps working.
        self.normalizer = None
        self.additional_state = additional_state

    # NOTE(review): keras-rl's Processor hook is named `process_state_batch`;
    # under the name below the agent will not call this automatically unless
    # other project code invokes it explicitly — confirm this is intended.
    def state_batch_process(self, batch):
        """Standardize the observation columns of every sample in ``batch``.

        The scaler is refit on each sample, so every sample is standardized
        with its own per-column mean/std; ``self.scaler`` is left holding the
        statistics of the last sample processed.

        Args:
            batch: Array of samples indexed along axis 0. Each sample must be
                2-D (presumably ``(window_length, n_features)``), because
                ``StandardScaler.fit_transform`` and the ``axis=1``
                concatenation below both require 2-D input.

        Returns:
            Array with the same shape as ``batch`` (dtype may be promoted to
            float by the scaler output). The first
            ``n_features - additional_state`` columns of each sample are
            standardized; the trailing ``additional_state`` columns are
            unchanged. An empty batch is returned as-is.
        """
        if len(batch) == 0:
            # np.stack/np.concatenate raise on an empty sequence.
            return batch

        processed = []
        for sample in batch:
            # An explicit split index (rather than `:-k` / `k:` slicing) keeps
            # the two parts complementary and also works when k == 0.
            split = sample.shape[-1] - self.additional_state
            observe = self.scaler.fit_transform(sample[..., :split])
            agent_state = sample[..., split:]
            processed.append(np.concatenate((observe, agent_state), axis=1))
        # Stacking along a new leading axis rebuilds the batch dimension.
        return np.stack(processed)