-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathml.py
59 lines (53 loc) · 1.64 KB
/
ml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
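'''
Recognize digits and arithmetic operator symbols (+, -, =) with scikit-learn's
logistic regression model, and save the trained model.
'''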
import os
import cv2
import numpy as np
import pickle
from sklearn.linear_model import LogisticRegression
import shutil
def _class_label(name):
    """Map a training sub-directory name to its integer class label.

    Decimal names map to their numeric value (the digit classes ``'0'``-``'9'``);
    ``'+'``, ``'-'`` and ``'='`` map to 10, 11 and 12. Any other name returns
    None so the caller can skip that directory.
    """
    # isdecimal (not isdigit) so names like '²' are rejected instead of
    # crashing int().
    if name.isdecimal():
        return int(name)
    return {'+': 10, '-': 11, '=': 12}.get(name)


def load_train_data(root_dir='TrainChar', shape=(60, 30)):
    """Load the labelled character images used to train the classifier.

    Walks ``root_dir``. Every sub-directory whose name is a known class label
    (see ``_class_label``) contributes its images. Files in any other
    directory, including ``root_dir`` itself, are ignored. Images that cannot
    be read, or whose grayscale shape differs from ``shape``, are skipped.

    Args:
        root_dir: directory holding one sub-directory per class.
        shape: required (rows, cols) of each grayscale image.

    Returns:
        (res, c): ``res`` is a NumPy array with one flattened image per row,
        in which every pixel value 255 has been replaced by 1; ``c`` is the
        list of integer class labels aligned with the rows of ``res``.
    """
    res = []
    c = []
    required_shape = tuple(shape)
    for root, _dirs, files in os.walk(root_dir):
        if not files:
            continue
        label = _class_label(os.path.basename(root))
        if label is None:
            # Previously these files were tagged with an undefined or stale
            # label left over from an earlier directory.
            continue
        for f in files:
            img = cv2.imread(os.path.join(root, f), 0)  # 0 -> grayscale
            if img is None or img.shape != required_shape:
                continue
            res.append(img.ravel().tolist())
            c.append(label)
    res = np.array(res)
    res[res == 255] = 1
    return res, c
def dumpModel():
    """Fit a class-balanced logistic-regression classifier and pickle it.

    Training data comes from ``load_train_data()``. The fitted estimator is
    written to ``lr.pickle`` (relative to the current working directory), and
    a confirmation message is printed.
    """
    # fit() returns the estimator itself, so construction and training chain.
    model = LogisticRegression(class_weight='balanced').fit(*load_train_data())
    with open('lr.pickle', 'wb') as out_file:
        pickle.dump(model, out_file)
    print('保存模型完毕')
def cleanTrainChar(root_dir='TrainChar'):
    """Empty the training-character folder so new samples can be imported.

    Deletes ``root_dir`` and everything in it (if it exists), then recreates
    it with one empty sub-directory per class: ``0``-``9``, ``+``, ``-`` and
    ``=``. Prints ``done`` when finished.

    Args:
        root_dir: path of the training-character folder.
    """
    # Tolerate a missing folder so this also works on a fresh checkout.
    if os.path.isdir(root_dir):
        shutil.rmtree(root_dir)
    os.mkdir(root_dir)
    # os.path.join requires str components; passing the int digit directly
    # raised TypeError.
    class_dirs = [str(num) for num in range(10)] + ['+', '-', '=']
    for name in class_dirs:
        os.mkdir(os.path.join(root_dir, name))
    print('done')
if __name__ == "__main__":
    # Running the file as a script trains the model and writes lr.pickle.
    dumpModel()