-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathparam_util.py
60 lines (53 loc) · 1.6 KB
/
param_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Utilities for loading pickled model parameters and parsing the
hyper-parameter strings encoded in model file names.
"""
import os
try:
from cPickle import load
except ImportError:
from pickle import load
class Params(object):
    """Plain attribute container for model hyper-parameters and weights.

    Instances start out empty; fields are attached dynamically with
    ``setattr`` by the parsing/loading helpers in this module.
    """

    def __repr__(self):
        """Return ``Params(name=value, ...)`` with fields sorted by name."""
        # Sorting by attribute name keeps the output deterministic; names are
        # unique, so the tuple comparison never falls through to the values.
        fields = ", ".join(
            "%s=%r" % (key, value) for key, value in sorted(vars(self).items())
        )
        return "%s(%s)" % (type(self).__name__, fields)
def load_dcnn_model_params(path, param_str=None):
    """
    Load a pickled model file and return its contents as a ``Params`` object.

    The hyper-parameters encoded in ``param_str`` are parsed first with
    ``parse_param_string``; the ``(name, value)`` pairs stored in the pickle
    are then attached to the same object.

    :param path: path of the pickle file. The unpickled object must be an
        iterable of ``(name, value)`` pairs.
    :param param_str: hyper-parameter string to parse. When ``None`` it is
        taken from the file name of ``path`` with its extension removed.
    :return: a ``Params`` instance. A name that occurs once keeps its plain
        value; a name that occurs several times (or that was already set
        from ``param_str``) is collected into a flat list, in order.

    >>> p = load_dcnn_model_params("models/filter_widths=8,6,,batch_size=10,,ks=20,8,,fold=1,1,,conv_layer_n=2,,ebd_dm=48,,l2_regs=1e-06,1e-06,1e-06,0.0001,,dr=0.5,0.5,,nkerns=7,12.pkl")  # doctest: +SKIP
    >>> p.ks  # doctest: +SKIP
    (20, 8)
    >>> len(p.W)  # doctest: +SKIP
    2
    >>> type(p.logreg_W).__name__  # doctest: +SKIP
    'ndarray'
    """
    if param_str is None:
        # Strip only the final extension: the name itself contains dots
        # (e.g. "dr=0.5,0.5"), so cutting at the first '.' would silently
        # drop every field that follows the first decimal value.
        param_str = os.path.splitext(os.path.basename(path))[0]

    p = parse_param_string(param_str)

    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    # Pickle data is binary, so the file is opened in "rb" mode, and the
    # context manager guarantees the handle is closed.
    with open(path, "rb") as f:
        stuff = load(f)

    # Names that this loop has already promoted to a list.
    promoted = set()
    for name, value in stuff:
        if not hasattr(p, name):
            setattr(p, name, value)
        elif name in promoted:
            # Third and later occurrences extend the existing list instead
            # of nesting it as [[a, b], c].
            getattr(p, name).append(value)
        else:
            # Second occurrence: turn the single value into a list.
            setattr(p, name, [getattr(p, name), value])
            promoted.add(name)
    return p
def _parse_number(text):
    """Convert *text* to an ``int`` when possible, otherwise to a ``float``."""
    try:
        return int(text)
    except ValueError:
        return float(text)


def parse_param_string(s, desired_fields=frozenset(("ks", "fold", "conv_layer_n"))):
    """
    Extract selected numeric fields from a ``,,``-separated parameter string.

    The string is a sequence of ``key=value`` segments separated by ``,,``.
    Segments without ``=`` and segments whose key is not in
    ``desired_fields`` are ignored.  A value containing ``,`` becomes a tuple
    of numbers, any other value a single number.  Numbers are parsed as
    ``int`` when possible and as ``float`` otherwise.

    :param s: the parameter string.
    :param desired_fields: container of the keys to extract.
    :return: a ``Params`` instance with one attribute per extracted key.
    :raises ValueError: if the value of a desired field is not numeric.

    >>> p = parse_param_string("twitter4,,filter_widths=8,6,,batch_size=10,,ks=20,8,,fold=1,1,,conv_layer_n=2,,ebd_dm=48,,l2_regs=1e-06,1e-06,1e-06,0.0001,,dr=0.5,0.5,,nkerns=7,12")
    >>> p.ks
    (20, 8)
    >>> p.fold
    (1, 1)
    >>> p.conv_layer_n
    2
    >>> q = parse_param_string("dr=0.5,0.5,,ks=3", desired_fields={"dr", "ks"})
    >>> q.dr
    (0.5, 0.5)
    >>> q.ks
    3
    """
    p = Params()
    for segment in s.split(',,'):
        # partition() splits on the first '=' only, so a segment that holds
        # more than one '=' no longer raises while unpacking.
        key, sep, value = segment.partition('=')
        if not sep or key not in desired_fields:
            continue
        if ',' in value:
            setattr(p, key, tuple(map(_parse_number, value.split(','))))
        else:
            setattr(p, key, _parse_number(value))
    return p