adder_env.py

from constants import *

import nengo
import numpy as np
from random import shuffle

import logging
logging.basicConfig(filename='env.log', level=logging.DEBUG)


def create_adder_env(q_list, q_norm_list, ans_list, op_val, num_vocab, ans_dur=0.3):
    with nengo.Network(label="env") as env:
        env.env_cls = AdderEnv(q_list, q_norm_list, ans_list, op_val, num_vocab, ans_dur)

        env.get_ans = nengo.Node(env.env_cls.get_answer)
        env.set_ans = nengo.Node(env.env_cls.set_answer, size_in=D)
        env.env_keys = nengo.Node(env.env_cls.input_func)
        env.env_norm_keys = nengo.Node(env.env_cls.input_func_normed)
        env.op_in = nengo.Node(env.env_cls.op_state_input)
        env.q_in = nengo.Node(env.env_cls.q_inputs)
        env.learning = nengo.Node(lambda t: env.env_cls.learning)
        env.count_reset = nengo.Node(lambda t: -env.env_cls.learning - 1)

    return env
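

# Nodes exposed by the environment network, summarized from the AdderEnv
# methods below (this summary comment is an addition, not original code):
#   get_ans        -- correct-answer feedback while an answer is being held
#   set_ans        -- sink for the model's answer; drives the trial state machine
#   env_keys       -- question vector, held until the trial resets
#   env_norm_keys  -- normalized question vector, held until the trial resets
#   q_in           -- question vector, pulsed briefly at trial start
#   op_in          -- operator semantic pointer, pulsed briefly at trial start
#   learning       -- 0 while feedback is being given, -1 otherwise
#   count_reset    -- derived gate, -learning - 1 (-1 during feedback, 0 otherwise)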


class AdderEnv(object):

    def __init__(self, q_list, q_norm_list, ans_list, op_val, num_vocab, ans_dur,
                 filename="paper5_reactions.txt"):
        ## Time constants (simulation seconds)
        self.rest = 0.05
        self.ans_duration = ans_dur
        self.q_duration = 0.08
        self.op_duration = 0.05

        ## Value variables
        self.list_index = 0
        self.q_list = q_list
        self.q_norm_list = q_norm_list
        self.ans_list = ans_list
        self.op_val = op_val
        self.num_items = len(q_list)
        # list() so that random.shuffle can reorder the indices in place
        self.indices = list(range(self.num_items))
        self.num_vocab = num_vocab

        ## Timing variables
        self.learning = -1
        self.ans_arrive = 0.0
        self.time = 0.0
        self.train = False
        self.reset = False

        # For measuring progress
        self.questions_answered = 0
        # For detecting a crash
        self.time_since_last_answer = 0.0

        # Logging for reaction times
        self.fi = open("data/%s" % filename, "w")
        self.react_time = open("data/paper5_react_time.txt", "w")

    def sp_text(self, x):
        """Return the name of the best-matching vocabulary key,
        e.g. "0.98ONE;0.31TWO" -> "ONE"."""
        return self.num_vocab.text(x).split(';')[0].split(".")[1][2:]
    # TODO: These functions should be combined as a closure
    def input_func(self, t):
        """Question vector, held from the end of the rest period until reset."""
        if self.time > self.rest and not self.reset:
            return self.q_list[self.indices[self.list_index]]
        else:
            return np.zeros(2 * D)

    def input_func_normed(self, t):
        """Normalized question vector, held from the end of the rest period until reset."""
        if self.time > self.rest and not self.reset:
            return self.q_norm_list[self.indices[self.list_index]]
        else:
            return np.zeros(2 * D)

    def q_inputs(self, t):
        """Question vector, pulsed for q_duration at the start of a trial."""
        if self.rest < self.time < (self.q_duration + self.rest):
            return self.q_list[self.indices[self.list_index]]
        else:
            return np.zeros(2 * D)

    def op_state_input(self, t):
        """Operator semantic pointer, pulsed for op_duration at the start of a trial."""
        if self.rest < self.time < (self.op_duration + self.rest):
            return self.op_val
        else:
            return np.zeros(less_D)

    def get_answer(self, t):
        """Correct-answer vector, supplied as feedback while an answer is being held."""
        if t < (self.ans_arrive + self.ans_duration) and self.ans_arrive != 0.0:
            return self.ans_list[self.indices[self.list_index]]
        else:
            return np.zeros(D)
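
    # Trial timeline (added summary; self.time is advanced by set_answer below):
    #   0 .. rest                        -- everything silent
    #   rest .. rest + q_duration        -- q_inputs pulses the question
    #   rest .. rest + op_duration       -- op_state_input pulses the operator
    #   rest .. answer detected          -- env_keys / env_norm_keys hold the question
    #   answer .. answer + ans_duration  -- get_answer gives feedback, learning = 0
    #   then wait for the similarity to drop before moving to the next question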

    def set_answer(self, t, x):
        """Timekeeping function.

        If an answer is coming from the basal ganglia, detected by its
        similarity to the vocabulary not being (effectively) zero, give
        feedback for a fixed amount of time before resetting the answer and
        starting the system again.

        This is effectively a small, time-dependent state machine written by
        hand.
        """
        self.time += dt
        self.time_since_last_answer += dt

        max_sim = np.max(np.dot(self.num_vocab.vectors, x))

        # when an answer arrives, note its time of arrival and turn on learning
        if max_sim > 0.45 and self.ans_arrive == 0.0 and not self.train:
            self.ans_arrive = t
            self.learning = 0
            self.train = True

            # check whether the answer is correct
            correct_text = self.sp_text(self.ans_list[self.indices[self.list_index]])
            ans_text = self.sp_text(x)

            self.react_time.write("%s\n" % self.time_since_last_answer)
            self.time_since_last_answer = 0.0
            self.questions_answered += 1

            q_ans = self.q_list[self.indices[self.list_index]]
            addend_1 = self.sp_text(q_ans[:D])
            addend_2 = self.sp_text(q_ans[D:])
            print("Answered %s+%s" % (addend_1, addend_2))
            print("Answered %s questions at %s\n" % (self.questions_answered, t))
            self.fi.write("Question answered %s at %s\n" % (self.questions_answered, t))
            self.fi.write("max_sim: %s\n" % max_sim)

            if correct_text != ans_text:
                logging.debug("%s != %s" % (correct_text, ans_text))
                print("%s != %s\n" % (correct_text, ans_text))
                self.fi.write("Error: %s\n" % t)

        # sustain the answer for training purposes; once we're done sustaining
        # it, turn off learning and clear the answer arrival time, then wait
        # until the similarity drops before presenting a new question
        if t > (self.ans_arrive + self.ans_duration) and self.train:
            if not self.reset:
                self.ans_arrive = 0.0
                self.learning = -1
                self.reset = True
                self.fi.write("Turning off: %s\n" % t)

            if max_sim < 0.1:
                self.fi.write("Next question: %s\n" % t)
                self.fi.write("max_sim: %s\n" % max_sim)
                self.time = 0.0
                self.train = False
                self.reset = False

                if self.list_index < self.num_items - 1:
                    self.list_index += 1
                    print("Increment:: %s" % self.list_index)
                else:
                    print("Shuffling\n")
                    self.fi.write("Shuffling\n")
                    shuffle(self.indices)
                    self.list_index = 0
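

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hypothetical example of wiring the environment into a model, only
# to show how create_adder_env is meant to be called. The vocabulary keys
# (ONE, TWO, THREE), the way the question/answer vectors are built, and the use
# of nengo.spa.Vocabulary are assumptions; the real vectors are constructed
# elsewhere in this project. Running it also assumes a "data/" directory exists
# for the log files opened in AdderEnv.__init__.
if __name__ == "__main__":
    from nengo import spa

    vocab = spa.Vocabulary(D)
    one = vocab.parse("ONE").v
    two = vocab.parse("TWO").v
    three = vocab.parse("THREE").v

    # a single question "ONE + TWO" with answer "THREE"
    q = np.concatenate((one, two))
    q_list = [q]
    q_norm_list = [q / np.linalg.norm(q)]
    ans_list = [three]
    op_val = np.zeros(less_D)  # placeholder operator signal

    with nengo.Network() as model:
        env = create_adder_env(q_list, q_norm_list, ans_list, op_val, vocab)
        # pretend the model always answers THREE, to exercise the state machine
        ans_node = nengo.Node(three)
        nengo.Connection(ans_node, env.set_ans, synapse=None)

    sim = nengo.Simulator(model)
    sim.run(1.0)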