forked from edward-zhu/umaru
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsolve.lua
117 lines (82 loc) · 2.09 KB
/
solve.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
require 'nn'
require 'rnn'
require 'GRU'
require 'image'
require 'optim'
require 'loader'
require 'ctc_log'
require 'utils.decoder'
require 'utils.levenshtein'
-- initialize
-- All tensors default to float (Torch7 defaults to double otherwise);
-- must run before any network/tensor is created or loaded.
torch.setdefaulttensortype('torch.FloatTensor')
-- debug switch
DEBUG = false
-- timer initialize
-- `base` and `timer` are intentionally global: show_log() below reads the
-- timer and rewrites `base` on every call to measure per-call elapsed time.
base = 0
timer = torch.Timer()
--- Print a log line and refresh the global elapsed-time bookkeeping.
-- Reads the global `timer` and updates the global `base` timestamp so the
-- (currently disabled) timed format below can report per-call cost.
-- @tparam string log message to print
function show_log(log)
	local elapsed_total = timer:time().real
	local elapsed_step = elapsed_total - base
	base = elapsed_total
	-- print(string.format("[%.4f][%.4f]%s", elapsed_total, elapsed_step, log))
	print(string.format("%s", log))
end
-- settings
GPU_ENABLED = false
-- Target height (pixels) every sample image is scaled to before the net.
local input_size = 48
-- configuration
-- List of sample files to evaluate, one path per line (read by Loader).
list_file = "wwr.txt"
-- Pre-trained model snapshot to restore; solving aborts if this is unset.
using_model_file = "umaru_model_15-09-10_21:51:30_30000.uma"
-- Pre-built character codec; when set, used instead of rebuilding from samples.
using_codec = "full-train.codec"
-- GPU
-- cutorch/cunn are only required when GPU mode is on, so CPU-only machines
-- can run this script without the CUDA packages installed.
if GPU_ENABLED then
	require 'cutorch'
	require 'cunn'
end
-- load samples
show_log("Loading samples...")
loader = Loader()
loader:load(list_file)
loader:targetHeight(input_size)
-- Fix: the original always built a codec from the samples via loader:codec()
-- and then immediately discarded it whenever `using_codec` was set (it is a
-- non-nil string above), wasting a full codec construction. Build one only
-- when no pre-built codec file is configured.
-- NOTE(review): assumes loader:codec() has no side effects the later code
-- relies on beyond returning the codec — confirm against Loader.
if using_codec then
	codec = loader:loadCodec(using_codec)
else
	codec = loader:codec()
end
show_log(string.format("Loading finished. Got %d samples, %d classes of characters.", #loader.samples, codec.codec_size))
-- Number of character classes (kept for parity with the training scripts;
-- not read again in this file).
local class_num = codec.codec_size
-- build network
show_log("Building networks...")
-- Guard clause: this script only evaluates, so a saved model is mandatory.
if not using_model_file then
	error("There must be a model file.")
end
-- Restore the serialized network and switch it to inference mode
-- (disables dropout/batch-norm training behavior).
local net = torch.load(using_model_file)
net:evaluate()
if GPU_ENABLED then
	net:cuda()
end
show_log(string.format("Start solving with model file: %s", using_model_file))
-- solving
-- Iterate over every sample in order, decode the network output, and report
-- the per-sample and cumulative character error rate (Levenshtein / length).
local sample = loader:pickInSequential()
begin_time = timer:time().real
-- Accumulators across all samples: total edit distance and total gt length.
local dist, len = 0, 0
while sample do
	local im = sample.img
	-- Encoded ground truth; not consumed below, but kept because encode()
	-- may validate that every gt character exists in the codec —
	-- TODO confirm against the codec implementation before removing.
	local target = codec:encode(sample.gt)
	-- Clear recurrent state so samples do not leak context into each other.
	net:forget()
	-- Fix: outputTable was an accidental global in the original; out/tmp_*
	-- were also hoisted outside the loop for no reason. All are loop-local now.
	local outputTable = net:forward(im)
	local out = decoder.best_path_decode(outputTable, codec)
	local tmp_dist = utf8.levenshtein(out, sample.gt)
	local tmp_len = utf8.len(sample.gt)
	dist = dist + tmp_dist
	len = len + tmp_len
	print("")
	show_log("FILE " .. sample.src)
	show_log("TARGET " .. sample.gt)
	show_log("OUTPUT " .. out)
	show_log("DISTANCE " .. tmp_dist)
	-- Fix: guard against division by zero (empty ground truths would have
	-- printed "nan%"/"inf%"); report 0% until any gt characters are seen.
	local err_rate = len > 0 and (dist / len * 100) or 0
	show_log("ERROR " .. string.format("%.2f%%", err_rate))
	sample = loader:pickInSequential()
end