forked from oyxhust/CNN-LSTM-CTC-text-recognition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: crnn_predictor.py
122 lines (110 loc) · 4.67 KB
/
crnn_predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python2.7
# coding=utf-8
from __future__ import print_function
import sys, os
mxnet_root = os.path.expanduser("~/github/mxnet")
sys.path.append(os.path.join(mxnet_root, "amalgamation/python/"))
sys.path.append(os.path.join(mxnet_root, "python/"))
import argparse
from mxnet_predict import Predictor
import mxnet as mx
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/symbol')
from crnn import crnn
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
class lstm_ocr_model(object):
    """CRNN (CNN + LSTM + CTC) text-recognition predictor.

    Loads an exported MXNet symbol/params pair and greedily decodes the
    network's per-timestep class probabilities into a string over
    ``classes``.  Index 0 is reserved for the CTC blank label, so label
    ``l`` maps to ``classes[l - 1]``.
    """

    def __init__(self, path_of_json, path_of_params, classes, data_shape, batch_size, num_label, num_hidden, num_lstm_layer):
        """Store the configuration and build the MXNet predictor.

        path_of_json / path_of_params: exported symbol graph and weights.
        classes: decodable characters; index 0 is the CTC blank.
        data_shape: (width, height) the input image is resized to.
        num_label / num_hidden / num_lstm_layer: must match training.
        """
        super(lstm_ocr_model, self).__init__()
        self.path_of_json = path_of_json
        self.path_of_params = path_of_params
        self.classes = classes
        self.batch_size = batch_size
        self.data_shape = data_shape
        self.num_label = num_label
        self.num_hidden = num_hidden
        self.num_lstm_layer = num_lstm_layer
        self.predictor = None
        self.__init_ocr()

    def __init_ocr(self):
        """Create the Predictor with shapes for data, LSTM states and label."""
        # One (c, h) state pair per direction for each LSTM layer,
        # hence num_lstm_layer * 2 of each.
        init_c = [('l%d_init_c' % l, (self.batch_size, self.num_hidden))
                  for l in range(self.num_lstm_layer * 2)]
        init_h = [('l%d_init_h' % l, (self.batch_size, self.num_hidden))
                  for l in range(self.num_lstm_layer * 2)]
        init_states = init_c + init_h
        # data is NCHW with a single grayscale channel; note that
        # data_shape is (width, height), so it is swapped here.
        all_shapes = ([('data', (self.batch_size, 1, self.data_shape[1], self.data_shape[0]))]
                      + init_states
                      + [('label', (self.batch_size, self.num_label))])
        all_shapes_dict = dict(all_shapes)
        self.predictor = Predictor(open(self.path_of_json).read(),
                                   open(self.path_of_params).read(),
                                   all_shapes_dict, dev_type="gpu", dev_id=0)

    def forward_ocr(self, img):
        """Run one BGR image through the network and return the decoded string."""
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = cv2.resize(img, self.data_shape)
        img = img.reshape((1, self.data_shape[1], self.data_shape[0]))
        img = np.multiply(img, 1 / 255.0)  # scale pixels to [0, 1]
        inputs = dict(data=img)
        # NOTE(review): the initial LSTM states are filled with ONES even
        # though the original variable was named 'nd_zero'.  The value is
        # kept as-is to preserve the trained model's behavior -- confirm
        # against the training code whether zeros were intended.
        init_state = np.ones((self.batch_size, self.num_hidden), dtype=np.float32)
        for l in range(self.num_lstm_layer * 2):
            inputs['l%d_init_c' % l] = init_state
            inputs['l%d_init_h' % l] = init_state
        self.predictor.forward(**inputs)
        prob = self.predictor.get_output(0)
        # Greedy decode: most likely class at every timestep.  np.argmax
        # replaces the original np.argsort(p)[::-1][0], which sorted the
        # whole row just to take its maximum; leftover debug prints removed.
        label_list = [int(np.argmax(p)) for p in prob]
        return self.__get_string(label_list)

    def __get_string(self, label_list):
        """Collapse a per-timestep label sequence with the CTC rule.

        Drops blanks (label 0) and merges consecutive repeats (CTC cannot
        emit the same symbol on consecutive timesteps), then maps each
        surviving label l to classes[l - 1].  Out-of-range labels become ''.
        """
        ret = []
        label_list2 = [0] + list(label_list)
        for i in range(len(label_list)):
            c1 = label_list2[i]
            c2 = label_list2[i + 1]
            if c2 == 0 or c2 == c1:
                continue
            ret.append(c2)
        # Map label indices to characters.
        s = ''
        for l in ret:
            if l > 0 and l < (len(self.classes) + 1):
                c = self.classes[l - 1]
            else:
                c = ''
            s += c
        return s
def parse_args(argv=None):
    """Parse command-line options for the demo.

    argv: optional list of argument strings; when None (the default, and
        the original behavior) argparse falls back to sys.argv[1:].
        Accepting an explicit list makes the function callable and
        testable without touching the process arguments.

    Returns the argparse.Namespace with attribute ``img`` (path of the
    demo image to recognize).
    """
    parser = argparse.ArgumentParser(description='predictor')
    parser.add_argument('--img', dest='img', help='which image to use',
                        default=os.path.join(os.getcwd(), 'data', 'demo', '20150105_14543723_Z.jpg'), type=str)
    args = parser.parse_args(argv)
    return args
if __name__ == '__main__':
    # Demo entry point: load the trained CRNN model and recognize one image.
    args = parse_args()

    # Trained model artifacts (symbol graph + weights), relative to CWD.
    json_path = os.path.join(os.getcwd(), 'model', 'crnn_ctc-symbol.json')
    param_path = os.path.join(os.getcwd(), 'model', 'crnn_ctc-0100.params')

    # Hyper-parameters; these must match the values used at training time.
    num_label = 9       # max label length, with one extra slot for the blank
    batch_size = 1
    num_hidden = 256
    num_lstm_layer = 2
    data_shape = (100, 32)  # (width, height) fed to the network
    classes = list("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ")

    demo_img = args.img
    _lstm_ocr_model = lstm_ocr_model(json_path, param_path, classes, data_shape, batch_size,
                                     num_label, num_hidden, num_lstm_layer)

    img = cv2.imread(demo_img)
    _str = _lstm_ocr_model.forward_ocr(img)
    print('Result: ', _str)

    # Display the input image with the predicted text overlaid.
    plt.imshow(img)
    plt.gca().text(0, 6.8,
                   '{:s} {:s}'.format("prediction", _str),
                   fontsize=12, color='red')
    plt.show()