-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathinit_weights.py
43 lines (33 loc) · 1.06 KB
/
init_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
"""
Run this code to build and save tensorflow model with corresponding weight values for VNect
"""
import os
import tensorflow as tf
from src.caffe2pkl import caffe2pkl
from src.vnect_model import VNect
def init_tf_weights(pfile, spath, model, ckpt_name='vnect_tf'):
    """Load pickled weights into *model* and save them as a TensorFlow checkpoint.

    Args:
        pfile: Path of the weights file passed to ``model.load_weights``.
        spath: Directory the checkpoint is written to; created if missing.
        model: Object exposing ``load_weights(sess, pfile)``. Its variables
            must already exist in the default graph: a ``tf.train.Saver``
            built without ``var_list`` snapshots the graph's variables at
            construction time.
        ckpt_name: File-name prefix of the checkpoint inside *spath*
            (defaults to the historical ``'vnect_tf'``).

    Returns:
        The checkpoint path prefix returned by ``Saver.save``.
    """
    if not tf.gfile.Exists(spath):
        tf.gfile.MakeDirs(spath)
    with tf.Session() as sess:
        # The Saver is created after the model graph was built by the caller,
        # so it covers all of the model's variables.
        saver = tf.train.Saver()
        model.load_weights(sess, pfile)
        return saver.save(sess, os.path.join(spath, ckpt_name))
# caffe model basepath
caffe_bpath = './models/caffe_model'
# caffe model files
prototxt_name = 'vnect_net.prototxt'
caffemodel_name = 'vnect_model.caffemodel'
# pickle file name; the converted weights are expected at caffe_bpath/pkl_name
pkl_name = 'params.pkl'
pkl_file = os.path.join(caffe_bpath, pkl_name)
# tensorflow model path
tf_save_path = './models/tf_model'


def main():
    """Convert the Caffe weights to a pickle if needed, then save a TF checkpoint."""
    # The Caffe -> pickle conversion is skipped when the pickle already exists,
    # so repeated runs reuse the earlier conversion.
    if not os.path.exists(pkl_file):
        caffe2pkl(caffe_bpath, prototxt_name, caffemodel_name, pkl_name)
    # The model graph is built before init_tf_weights creates its Saver.
    model = VNect()
    init_tf_weights(pkl_file, tf_save_path, model)


# Only do the (heavy) conversion when run as a script, not on import.
if __name__ == '__main__':
    main()