-
Notifications
You must be signed in to change notification settings - Fork 169
/
wit.py
315 lines (284 loc) · 12 KB
/
wit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import ipywidgets as widgets
import tensorflow as tf
from IPython.core.display import display, HTML
from ipywidgets import Layout
from traitlets import Dict
from traitlets import Int
from traitlets import List
from traitlets import observe
from traitlets import Unicode
from traitlets import Set
from witwidget.notebook import base
@widgets.register
class WitWidget(widgets.DOMWidget, base.WitWidgetBase):
    """WIT widget for Jupyter.

    Examples and inference results are not synced to the frontend as whole
    traits. They are sent in batches of at most `batch_size` items through
    the `examples_batch` and `inferences_batch` traits. While such a transfer
    is running `transfer_block` is True and no other transfer is started.
    """

    _view_name = Unicode('WITView').tag(sync=True)
    _view_module = Unicode('wit-widget').tag(sync=True)
    _view_module_version = Unicode('^0.1.0').tag(sync=True)

    # Traitlets for communicating between python and javascript.
    config = Dict(dict()).tag(sync=True)
    # Examples and inferences are not synced directly because large datasets
    # cause websocket issues. Instead, we indirectly update them through
    # batch updates.
    examples = List([])
    frontend_ready = Int(0).tag(sync=True)
    examples_batch = List([]).tag(sync=True)
    examples_batch_id = Int(0).tag(sync=True)
    inferences = Dict(dict())
    inferences_batch = Dict(dict()).tag(sync=True)
    inferences_batch_id = Int(0).tag(sync=True)
    infer = Int(0).tag(sync=True)
    infer_counter = Int(0)
    update_example = Dict(dict()).tag(sync=True)
    delete_example = Dict(dict()).tag(sync=True)
    duplicate_example = Dict(dict()).tag(sync=True)
    updated_example_indices = Set(set())
    get_eligible_features = Int(0).tag(sync=True)
    sort_eligible_features = Dict(dict()).tag(sync=True)
    eligible_features = List([]).tag(sync=True)
    infer_mutants = Dict(dict()).tag(sync=True)
    # Default is an empty dict, like every other Dict trait of this class.
    mutant_charts = Dict(dict()).tag(sync=True)
    mutant_charts_counter = Int(0)
    sprite = Unicode('').tag(sync=True)
    error = Dict(dict()).tag(sync=True)
    compute_custom_distance = Dict(dict()).tag(sync=True)
    custom_distance_dict = Dict(dict()).tag(sync=True)

    def __init__(self, config_builder, height=1000, delay_rendering=False):
        """Constructor for Jupyter notebook WitWidget.

        Args:
          config_builder: WitConfigBuilder object containing settings for WIT.
          height: Optional height in pixels for WIT to occupy. Defaults to
            1000.
          delay_rendering: Optional. This argument is ignored in the Jupyter
            implementation but is included for API compatibility between
            Colab and Jupyter implementations. Rendering in Jupyter is always
            delayed until the render method is called or the WitWidget object
            is directly evaluated in a notebook cell.
        """
        # Transfer state is set up before the base initializers run, in the
        # same order as before; set_examples() reads these attributes.
        self.transfer_block = False
        self.examples_generator = None
        self.inferences_generator = None
        # TODO(wit-dev) This should depend on the example size targeting less
        # than 10MB per batch to avoid websocket issues.
        self.batch_size = 10000
        widgets.DOMWidget.__init__(self, layout=Layout(height='%ipx' % height))
        base.WitWidgetBase.__init__(self, config_builder)
        self.error_counter = 0

        # Ensure the visualization takes all available width.
        display(HTML("<style>.container { width:100% !important; }</style>"))

    def render(self):
        """Render the widget to the display.

        Returns:
          This widget itself.
        """
        return self

    def set_examples(self, examples):
        """Set the examples and start sending them to the frontend.

        Does nothing except print a message when a batched transfer is
        already in progress.

        Args:
          examples: The examples, forwarded unchanged to
            `WitWidgetBase.set_examples`.
        """
        if self.transfer_block:
            print('Cannot set examples while transfer is in progress.')
            return
        base.WitWidgetBase.set_examples(self, examples)
        self.examples_generator = self.generate_next_example_batch()
        # If this is called after frontend is ready this makes sure examples
        # are updated.
        self._start_examples_sync()
        self._generate_sprite()

    def generate_next_example_batch(self):
        """Yield `self.examples` in slices of at most `batch_size` items.

        Yields:
          Tuples `(batch, num_remaining)` where `batch` is a slice of
          `self.examples` and `num_remaining` is the number of batches that
          still follow it. Nothing is yielded when there are no examples.
        """
        n_examples = len(self.examples)
        n_batches = n_examples // self.batch_size
        batch_end = n_batches * self.batch_size
        has_partial_batch = batch_end != n_examples
        num_remaining = n_batches + int(has_partial_batch)
        for i in range(n_batches):
            num_remaining -= 1
            start = i * self.batch_size
            yield self.examples[start:start + self.batch_size], num_remaining
        if has_partial_batch:
            num_remaining -= 1
            yield self.examples[batch_end:], num_remaining

    def _report_error(self, err):
        """Publish `repr(err)` to the frontend through the `error` trait.

        A running counter is sent along with the message, presumably so that
        two identical messages in a row still change the trait value.
        """
        self.error = {
            'msg': repr(err),
            'counter': self.error_counter,
        }
        self.error_counter += 1

    def _start_examples_sync(self):
        """Send the first batch of examples, if a transfer can start now.

        Nothing is sent when the frontend is not ready, when no examples
        generator exists, or when another transfer is in progress.
        """
        if (not self.frontend_ready or self.examples_generator is None
                or self.transfer_block):
            return
        # Send the first batch
        next_batch, self.examples_batch_id = next(
            self.examples_generator, ([], -1))
        self.transfer_block = True
        self.examples_batch = next_batch

    @observe('infer')
    def _infer(self, change):
        """Run inference and start sending the results to the frontend.

        Any exception is reported through `_report_error` instead of being
        raised.
        """
        try:
            self.inferences = base.WitWidgetBase.infer_impl(self)
            self.inferences_generator = self.generate_next_inference_batch()
            self._start_inferences_sync()
        except Exception as e:
            self._report_error(e)

    @staticmethod
    def _take_results(model_inference):
        """Copy out one model's result list and leave an empty list behind."""
        if 'classificationResult' in model_inference:
            container = model_inference['classificationResult']
            key = 'classifications'
        else:
            container = model_inference['regressionResult']
            key = 'regressions'
        taken = container[key][:]
        container[key] = []
        return taken

    def _take_extra_outputs(self, model_index):
        """Copy out one model's extra outputs and blank them in place.

        Returns:
          A dict mapping each extra-output key of model `model_index` to a
          copy of its values, or an empty dict when `self.inferences` has no
          non-empty extra outputs for that model.
        """
        taken = {}
        if 'extra_outputs' not in self.inferences:
            return taken
        per_model = self.inferences['extra_outputs']
        if len(per_model) > model_index and per_model[model_index]:
            outputs = per_model[model_index]
            for key in outputs:
                taken[key] = outputs[key][:]
                outputs[key] = []
        return taken

    def generate_next_inference_batch(self):
        """Yield the inference results in chunks of at most `batch_size`.

        Parses out the inferences from the returned structure and empties the
        structure of contents, keeping its nested structure. Chunks of the
        inference results will be sent to the front-end and re-assembled.

        Yields:
          Tuples `(data, num_remaining)`. `data` is a dict with the keys
          'results', 'indices', 'extra' and 'counter', plus 'inferences' (the
          emptied structure) on the first chunk. `num_remaining` is the
          number of chunks that still follow.
        """
        inferences = self.inferences['inferences']
        indices = inferences['indices'][:]
        inferences['indices'] = []

        model_results = inferences['results']
        first_model = model_results[0]
        extra = self._take_extra_outputs(0)
        res = self._take_results(first_model)

        # A second model is optional; its containers stay empty without it.
        res2 = []
        extra2 = {}
        if len(model_results) > 1:
            extra2 = self._take_extra_outputs(1)
            res2 = self._take_results(model_results[1])

        # Integer ceiling division: the count is always an int, whatever the
        # semantics of `/` are, and it is later stored in an Int trait.
        num_pieces = -(-len(indices) // self.batch_size)
        i = 0
        while num_pieces > 0:
            num_pieces -= 1
            end = i + self.batch_size
            piece = [res[i:end]]
            extra_piece = [dict((key, extra[key][i:end]) for key in extra)]
            if res2:
                piece.append(res2[i:end])
                extra_piece.append(
                    dict((key, extra2[key][i:end]) for key in extra2))
            data = {
                'results': piece,
                'indices': indices[i:end],
                'extra': extra_piece,
                'counter': self.infer_counter,
            }
            self.infer_counter += 1
            # For the first segment to send, also send the blank inferences
            # structure to be filled in. This was cleared of contents above
            # but is used to maintain the nested structure of the results.
            if i == 0:
                data['inferences'] = self.inferences
            i = end
            yield data, num_pieces

    def _start_inferences_sync(self):
        """Send the first batch of inferences, if a transfer can start now.

        Nothing is sent when no inferences generator exists or when another
        transfer is in progress.
        """
        if self.inferences_generator is None or self.transfer_block:
            return
        # Send the first batch
        next_batch, self.inferences_batch_id = next(
            self.inferences_generator, ({}, -1))
        self.transfer_block = True
        self.inferences_batch = next_batch

    # When frontend processes sent inferences, it updates batch id to request
    # the next batch
    @observe('inferences_batch_id')
    def _send_inferences_batch(self, change):
        """Send the next batch of inferences, or finish the transfer."""
        if not self.transfer_block:
            return
        # Do not trigger at the end of a transfer.
        if self.inferences_batch_id < 0 or self.inferences_generator is None:
            self.transfer_block = False
            return
        self.inferences_batch, num_remaining = next(
            self.inferences_generator, ({}, -1))
        if num_remaining == 0:
            # That was the last batch: release the transfer lock.
            self.inferences_generator = None
            self.transfer_block = False

    # Finish setup items that require frontend to be ready.
    @observe('frontend_ready')
    def _finish_setup(self, change):
        """Start the examples transfer when `frontend_ready` changes."""
        # Start examples transfer
        self._start_examples_sync()

    # When frontend processes sent examples, it updates batch id to request
    # the next batch
    @observe('examples_batch_id')
    def _send_example_batch(self, change):
        """Send the next batch of examples, or finish the transfer."""
        if not self.transfer_block:
            return
        # Do not trigger at the end of a transfer.
        if self.examples_batch_id < 0 or self.examples_generator is None:
            self.transfer_block = False
            return
        self.examples_batch, num_remaining = next(
            self.examples_generator, ([], -1))
        if num_remaining == 0:
            # That was the last batch: release the transfer lock.
            self.examples_generator = None
            self.transfer_block = False

    # Observer callbacks for changes from javascript.
    @observe('get_eligible_features')
    def _get_eligible_features(self, change):
        """Publish the eligible features through `eligible_features`."""
        features_list = base.WitWidgetBase.get_eligible_features_impl(self)
        self.eligible_features = features_list

    @observe('sort_eligible_features')
    def _sort_eligible_features(self, change):
        """Publish the features sorted per the frontend's request."""
        info = self.sort_eligible_features
        features_list = base.WitWidgetBase.sort_eligible_features_impl(
            self, info)
        self.eligible_features = features_list

    @observe('infer_mutants')
    def _infer_mutants(self, change):
        """Run mutant inference and publish it through `mutant_charts`.

        Any exception is reported through `_report_error` instead of being
        raised.
        """
        info = self.infer_mutants
        try:
            json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info)
            json_mapping['counter'] = self.mutant_charts_counter
            self.mutant_charts_counter += 1
            self.mutant_charts = json_mapping
        except Exception as e:
            self._report_error(e)

    @observe('update_example')
    def _update_example(self, change):
        """Replace one example and mark it as updated."""
        index = self.update_example['index']
        self.updated_example_indices.add(index)
        self.examples[index] = self.update_example['example']
        self._generate_sprite()

    @observe('duplicate_example')
    def _duplicate_example(self, change):
        """Append a duplicate of one example and mark it as updated."""
        self.examples.append(self.examples[self.duplicate_example['index']])
        self.updated_example_indices.add(len(self.examples) - 1)
        self._generate_sprite()

    @observe('delete_example')
    def _delete_example(self, change):
        """Delete one example and re-number the updated indices.

        The deleted index is dropped from `updated_example_indices` and every
        larger index is shifted down by one. Without dropping it, the entry
        for the deleted example would be remapped to `index - 1`, which marks
        an unrelated example as updated (or yields -1 for index 0).
        """
        index = self.delete_example['index']
        self.examples.pop(index)
        self.updated_example_indices = set(
            i if i < index else i - 1
            for i in self.updated_example_indices
            if i != index)
        self._generate_sprite()

    @observe('compute_custom_distance')
    def _compute_custom_distance(self, change):
        """Compute custom distances and publish them to the frontend.

        The result is written to `custom_distance_dict`. Any exception raised
        while computing is reported through `_report_error` instead of being
        raised.
        """
        info = self.compute_custom_distance
        index = info['index']
        params = info['params']
        callback_fn = info['callback']
        try:
            distances = base.WitWidgetBase.compute_custom_distance_impl(
                self, index, params['distanceParams'])
            self.custom_distance_dict = {
                'distances': distances,
                'exInd': index,
                'funId': callback_fn,
                'params': params['callbackParams'],
            }
        except Exception as e:
            self._report_error(e)

    def _generate_sprite(self):
        """Update the `sprite` trait when the base class creates a sprite."""
        sprite = base.WitWidgetBase.create_sprite(self)
        if sprite is not None:
            self.sprite = sprite