-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path rl_saveeEnv.py
79 lines (58 loc) · 2.34 KB
/
rl_saveeEnv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gym
import numpy as np
import pandas as pd
from Datastore import Datastore
from constants import EMOTIONS, NUM_MFCC, NO_features
from data import FeatureType
from data_versions import DataVersions
from hashing_util import get_hash
from savee_datastore import SAVEEDatastore
class SAVEEEnv(gym.Env):
    """Gym environment that replays SAVEE training samples one step at a time.

    Each observation is one training sample from ``self.X``. The action is a
    class index in ``[0, num_classes)``. The reward is ``0.9`` when the action
    equals ``argmax`` of the current sample's label vector in ``self.Y``, and
    ``-0.1`` otherwise.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, data_version) -> None:
        """Load the data for ``data_version`` and build spaces and hash tracking.

        Args:
            data_version: dataset selector; only ``DataVersions.Vsavee`` is
                supported.

        Raises:
            ValueError: if ``data_version`` is not supported.
        """
        super().__init__()
        # Fail fast: without a datastore nothing below can work.
        if data_version != DataVersions.Vsavee:
            raise ValueError(f"Unsupported data_version: {data_version!r}")
        self.itr = 0
        self.X = []
        self.Y = []
        self.num_classes = len(EMOTIONS)
        self.data_version = data_version
        self.datastore: Datastore = SAVEEDatastore(FeatureType.MFCC)
        self.set_data()
        self.action_space = gym.spaces.Discrete(self.num_classes)
        # NOTE(review): the [-1, 1] bounds assume normalised MFCC features -- confirm.
        self.observation_space = gym.spaces.Box(-1, 1, [NUM_MFCC, NO_features])
        # One row per entry from the datastore's hash list, plus a flag that
        # step() sets once the matching observation has been served.
        self.data_hashes = pd.DataFrame(self.datastore.get_data_hash_list())
        self.data_hashes['used'] = False
        self.data_hashes.columns = ['file_hash', 'used']

    def step(self, action):
        """Score ``action`` against the current sample and advance by one.

        Args:
            action: a class index contained in ``self.action_space``.

        Returns:
            ``(next_state, reward, done, info)``, the classic 4-tuple gym form.
            ``info`` holds the ground-truth class index, the current step index,
            and the number of hash rows flagged as used so far.
        """
        assert self.action_space.contains(action)
        ground_truth = np.argmax(self.Y[self.itr])
        reward = -0.1 + int(action == ground_truth)
        # True once next_state is the last sample in X.
        done = (len(self.X) - 2 <= self.itr)
        next_state = self.X[self.itr + 1]
        # The hash may match zero or several rows, so use a boolean mask with
        # .loc; .at is a scalar-only accessor.
        matches = self.data_hashes['file_hash'] == get_hash(next_state)
        self.data_hashes.loc[matches, 'used'] = True
        info = {
            "ground_truth": ground_truth,
            "itr": self.itr,
            "used_data_count": int(self.data_hashes['used'].sum()),
        }
        self.itr += 1
        return next_state, reward, done, info

    def render(self, mode='human'):
        """Print a placeholder line with the current step index."""
        print("Not implemented \t i: {}".format(self.itr))

    def reset(self):
        """Reload the data, rewind to the first sample, and return it.

        The ``used`` flags in ``self.data_hashes`` are not cleared here, so
        ``used_data_count`` accumulates across episodes.
        """
        self.itr = 0
        self.set_data()
        # NOTE(review): X[0] is returned without being flagged in data_hashes;
        # only states served by step() are marked -- confirm this is intended.
        return self.X[self.itr]

    def set_data(self):
        """(Re)load training features and labels from the datastore into X / Y.

        Raises:
            ValueError: if the training features and labels differ in length.
        """
        self.X = []
        self.Y = []
        if self.data_version == DataVersions.Vsavee:
            # get_data() yields (train, test) triples; only the training
            # features and emotion labels are used by this environment.
            (x_train, y_train, _), (_, _, _) = self.datastore.get_data()
            # An explicit check, unlike assert, survives `python -O`.
            if len(x_train) != len(y_train):
                raise ValueError(
                    f"Feature/label length mismatch: {len(x_train)} != {len(y_train)}"
                )
            self.X = x_train
            self.Y = y_train