-
Notifications
You must be signed in to change notification settings - Fork 0
/
cnn_ctc.py
69 lines (59 loc) · 2.53 KB
/
cnn_ctc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import requests
import tensorflow as tf
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import PlainTextResponse
app = FastAPI()

# One TFLite interpreter per model variant, built in this order; each is
# allowed up to 4 CPU threads for inference.
model, model_3d = (
    tf.lite.Interpreter(model_path=path, num_threads=4)
    for path in (
        'models/cnn_ctc/model_quantized.tflite',
        'models/cnn_ctc_3d/model_quantized.tflite',
    )
)

# Output alphabet: the ten decimal digits.
characters = list('0123456789')

# Forward (char -> index) and inverse (index -> char) lookups. The inverse
# layer is built from the forward layer's vocabulary so both agree on the
# index assigned to every character.
char_to_num = tf.keras.layers.StringLookup(vocabulary=characters, mask_token=None)
num_to_char = tf.keras.layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(),
    mask_token=None,
    invert=True,
)
def preprocess_image(img_content, img_height, img_width):
    """Turn encoded PNG bytes into a one-item input batch.

    Args:
        img_content: encoded PNG image bytes.
        img_height: target height for the resize.
        img_width: target width for the resize.

    Returns:
        A float32 tensor of shape (1, img_width, img_height, 1): the image
        decoded as a single (grayscale) channel, converted to float32
        (``convert_image_dtype`` rescales integer pixels to [0, 1]), resized,
        with its height/width axes swapped and a leading batch axis added.
    """
    grayscale = tf.image.convert_image_dtype(
        tf.io.decode_png(img_content, channels=1),
        tf.float32,
    )
    resized = tf.image.resize(grayscale, [img_height, img_width])
    # Put the width axis first. NOTE(review): presumably the models treat
    # image columns as the CTC time axis — confirm against training code.
    width_major = tf.transpose(resized, perm=[1, 0, 2])
    return width_major[tf.newaxis, ...]
def decode_batch_predictions(pred, max_length=6):
    """Greedy-CTC-decode a batch of model outputs into strings.

    Args:
        pred: model output array whose axis 0 is the batch and axis 1 is the
            time axis; every item is decoded over all ``pred.shape[1]`` steps.
            NOTE(review): presumably (batch, time, num_classes) scores — confirm.
        max_length: keep at most this many decoded labels per item
            (defaults to 6, the original behavior).

    Returns:
        A list with one decoded string per batch item.
    """
    # Every sequence in the batch spans the full time axis.
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    decoded = tf.keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
    decoded = decoded[:, :max_length]
    output_text = []
    for labels in decoded:
        # ctc_decode pads sequences shorter than the longest one with -1.
        # Drop that padding: -1 is not a vocabulary index, so num_to_char
        # would render it as its OOV token instead of nothing.
        labels = tf.boolean_mask(labels, labels >= 0)
        output_text.append(tf.strings.reduce_join(num_to_char(labels)).numpy().decode("utf-8"))
    return output_text


def decode_batch_predictions_3d(pred):
    """Decode outputs of the 3D model, keeping at most 3 labels per item."""
    return decode_batch_predictions(pred, max_length=3)
# Seconds to wait on the remote image server before giving up.
_FETCH_TIMEOUT_SECONDS = 10


def _run_interpreter(interpreter, imgs):
    """Run one forward pass of a TFLite ``interpreter`` on ``imgs``; return its first output."""
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'], imgs)
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]['index'])


@app.get('/', response_class=PlainTextResponse)
async def uwu(url: str, d: int | None = None):
    """Fetch the image at ``url`` and return the text the model decodes from it.

    Args:
        url: address of the image; the response body goes to preprocess_image.
        d: model selector. ``3`` uses ``model_3d`` with a 30x42 resize and
            decode_batch_predictions_3d; any other value (or omitted) uses
            ``model`` with a 32x104 resize and decode_batch_predictions.

    Returns:
        The decoded string for the single fetched image.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx response (not caught here).
    """
    # SECURITY(review): this fetches an arbitrary caller-supplied URL from the
    # server (SSRF risk) — consider restricting allowed hosts/schemes.
    # requests is blocking; run it in a worker thread so the event loop can
    # keep serving other requests while the download is in flight.
    response = await run_in_threadpool(requests.get, url, timeout=_FETCH_TIMEOUT_SECONDS)
    response.raise_for_status()

    # Everything below runs on the event-loop thread with no further await,
    # so the shared interpreters are never invoked concurrently.
    if d == 3:
        imgs = preprocess_image(response.content, 30, 42)
        preds = _run_interpreter(model_3d, imgs)
        preds_texts = decode_batch_predictions_3d(preds)
    else:
        imgs = preprocess_image(response.content, 32, 104)
        preds = _run_interpreter(model, imgs)
        preds_texts = decode_batch_predictions(preds)
    # preprocess_image produces a batch of one, so take the only result.
    return preds_texts[0]