-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmodel_utils.py
59 lines (41 loc) · 1.65 KB
/
model_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Model utilities for TF1 / tf.contrib.slim training.

Provides:
  get_init_fn: build a session callback that restores the trainable model
    variables found under a set of scopes from a checkpoint.
  get_train_op_for_scope: build a slim train op restricted to the trainable
    model variables (and UPDATE_OPS) found under a set of scopes.
"""
import tensorflow as tf
# NOTE: tf.contrib (and therefore slim) is only available in TensorFlow 1.x.
slim = tf.contrib.slim
def get_init_fn(scopes, init_model):
    """Build a session callback that restores scoped variables from a checkpoint.

    Side effect: every trainable variable that is not yet registered as a
    model variable is added to the model-variables collection, so that the
    per-scope lookup via ``tf.contrib.framework.get_model_variables`` sees it.

    Args:
      scopes: Iterable of scope strings. A variable is restored when it is a
        trainable model variable under at least one of these scopes; a
        variable matched by several (overlapping) scopes is restored once.
      init_model: Checkpoint path handed to ``slim.assign_from_checkpoint``.
        When falsy (e.g. ``None`` or ``''``) nothing is restored.

    Returns:
      ``None`` when ``init_model`` is falsy; otherwise a function taking a
      session and running the checkpoint assignment op in it.
    """
    if not init_model:
        return None

    trainable = tf.trainable_variables()

    # Fetch each collection once and test membership by identity (which is
    # what `in` did for graph-mode variables) instead of rescanning a fresh
    # copy of the collection for every variable.
    model_var_ids = {id(v) for v in tf.model_variables()}
    for var in trainable:
        if id(var) not in model_var_ids:
            tf.contrib.framework.add_model_variable(var)
            model_var_ids.add(id(var))
    trainable_ids = {id(v) for v in trainable}

    # Overlapping scopes (e.g. 'net' and 'net/head') would otherwise list the
    # same variable twice; keep only its first occurrence, in scope order.
    var_list = []
    seen = set()
    for scope in scopes:
        for var in tf.contrib.framework.get_model_variables(scope):
            if id(var) in trainable_ids and id(var) not in seen:
                seen.add(id(var))
                var_list.append(var)

    init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
        init_model, var_list)

    def init_assign_function(sess):
        # Runs the restore; the feed dict supplies the checkpoint values.
        sess.run(init_assign_op, init_feed_dict)

    return init_assign_function
def get_train_op_for_scope(loss, optimizer, scopes, clip_gradient_norm):
    """Create a slim train op that only updates variables under ``scopes``.

    Side effects:
      * every trainable variable not yet registered as a model variable is
        added to the model-variables collection;
      * for each scope, every model variable under it is printed to stdout as
        a tab-separated ``scope`` / ``variable`` line.

    Args:
      loss: Loss passed to ``slim.learning.create_train_op``.
      optimizer: Optimizer passed to ``slim.learning.create_train_op``.
      scopes: Iterable of scope strings. The trainable model variables and the
        ``UPDATE_OPS`` found under these scopes are used; an item matched by
        several (overlapping) scopes is included only once.
      clip_gradient_norm: Forwarded unchanged to
        ``slim.learning.create_train_op``.

    Returns:
      The train op returned by ``slim.learning.create_train_op``.
    """
    trainable = tf.trainable_variables()

    # Fetch each collection once and test membership by identity (which is
    # what `in` did for graph-mode variables) instead of rescanning a fresh
    # copy of the collection for every variable.
    model_var_ids = {id(v) for v in tf.model_variables()}
    for var in trainable:
        if id(var) not in model_var_ids:
            tf.contrib.framework.add_model_variable(var)
            model_var_ids.add(id(var))
    trainable_ids = {id(v) for v in trainable}

    var_list = []
    update_ops = []
    seen_vars = set()
    seen_ops = set()
    for scope in scopes:
        scope_vars = tf.contrib.framework.get_model_variables(scope)
        # A variable listed twice in variables_to_train gets two optimizer
        # update ops, i.e. it is stepped twice per iteration; de-duplicate.
        for var in scope_vars:
            if id(var) in trainable_ids and id(var) not in seen_vars:
                seen_vars.add(id(var))
                var_list.append(var)
        for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope):
            if id(op) not in seen_ops:
                seen_ops.add(id(op))
                update_ops.append(op)
        # Debug listing of every model variable in this scope (trainable or not).
        for var in scope_vars:
            print('%s\t%s' % (scope, var))

    return slim.learning.create_train_op(
        loss,
        optimizer,
        update_ops=update_ops,
        variables_to_train=var_list,
        clip_gradient_norm=clip_gradient_norm)