From 89d5d0603aebe5da842c18a701f6082386170b82 Mon Sep 17 00:00:00 2001 From: xiaotinghe Date: Mon, 26 Feb 2024 09:12:20 +0000 Subject: [PATCH] layout analysis --- .../layout-analysis/model/Dockerfile.lambda | 26 ++ .../model/Dockerfile.sagemaker | 27 ++ .../layout-analysis/model/aikits_utils.py | 50 +++ .../layout-analysis/model/imaug/__init__.py | 35 ++ .../layout-analysis/model/imaug/operators.py | 209 ++++++++++ .../layout-analysis/model/imaug/table_ops.py | 229 +++++++++++ .../layout-analysis/model/infer_layout_app.py | 49 +++ .../layout-analysis/model/layout.py | 68 ++++ src/containers/layout-analysis/model/main.py | 247 +++++++++++ .../layout-analysis/model/matcher.py | 298 ++++++++++++++ src/containers/layout-analysis/model/ocr.py | 383 ++++++++++++++++++ .../model/postprocess/__init__.py | 27 ++ .../model/postprocess/cls_postprocess.py | 15 + .../model/postprocess/db_postprocess.py | 139 +++++++ .../model/postprocess/rec_postprocess.py | 216 ++++++++++ .../model/postprocess/table_postprocess.py | 120 ++++++ .../layout-analysis/model/requirements.txt | 13 + .../layout-analysis/model/sm_predictor.py | 39 ++ src/containers/layout-analysis/model/table.py | 158 ++++++++ src/containers/layout-analysis/model/utils.py | 168 ++++++++ src/containers/layout-analysis/model/xycut.py | 125 ++++++ 21 files changed, 2641 insertions(+) create mode 100644 src/containers/layout-analysis/model/Dockerfile.lambda create mode 100644 src/containers/layout-analysis/model/Dockerfile.sagemaker create mode 100644 src/containers/layout-analysis/model/aikits_utils.py create mode 100644 src/containers/layout-analysis/model/imaug/__init__.py create mode 100644 src/containers/layout-analysis/model/imaug/operators.py create mode 100644 src/containers/layout-analysis/model/imaug/table_ops.py create mode 100644 src/containers/layout-analysis/model/infer_layout_app.py create mode 100644 src/containers/layout-analysis/model/layout.py create mode 100644 src/containers/layout-analysis/model/main.py create mode 100644 src/containers/layout-analysis/model/matcher.py create mode 100644 src/containers/layout-analysis/model/ocr.py create mode 100644 src/containers/layout-analysis/model/postprocess/__init__.py create mode 100644 src/containers/layout-analysis/model/postprocess/cls_postprocess.py create mode 100644 src/containers/layout-analysis/model/postprocess/db_postprocess.py create mode 100644 src/containers/layout-analysis/model/postprocess/rec_postprocess.py create mode 100644 src/containers/layout-analysis/model/postprocess/table_postprocess.py create mode 100644 src/containers/layout-analysis/model/requirements.txt create mode 100644 src/containers/layout-analysis/model/sm_predictor.py create mode 100644 src/containers/layout-analysis/model/table.py create mode 100644 src/containers/layout-analysis/model/utils.py create mode 100644 src/containers/layout-analysis/model/xycut.py diff --git a/src/containers/layout-analysis/model/Dockerfile.lambda b/src/containers/layout-analysis/model/Dockerfile.lambda new file mode 100644 index 00000000..ac530847 --- /dev/null +++ b/src/containers/layout-analysis/model/Dockerfile.lambda @@ -0,0 +1,26 @@ +FROM public.ecr.aws/lambda/python:3.9 + +ARG FUNCTION_DIR="/opt/program" +ARG MODEL_URL="https://aws-gcr-solutions-assets.s3.cn-northwest-1.amazonaws.com.cn/ai-solution-kit/layout-analysis" + +ARG MODEL_VERSION="1.4.0" + +ADD / ${FUNCTION_DIR}/ + +RUN pip3 install -r ${FUNCTION_DIR}/requirements.txt +RUN pip3 install --target ${FUNCTION_DIR} awslambdaric + +RUN mkdir -p 
${FUNCTION_DIR}/model +RUN yum install -y wget unzip +RUN wget -c ${MODEL_URL}/${MODEL_VERSION}/layout_weight.zip -O ${FUNCTION_DIR}/model/layout_weight.zip +RUN unzip ${FUNCTION_DIR}/model/layout_weight.zip -d ${FUNCTION_DIR}/model/ + +WORKDIR ${FUNCTION_DIR} +ENV PYTHONUNBUFFERED=TRUE +ENV PYTHONDONTWRITEBYTECODE=TRUE +ENV PYTHONIOENCODING="utf8" +ENV MODEL_NAME="standard" +ENV MODEL_PATH="${FUNCTION_DIR}/model/" + +ENTRYPOINT [ "python3", "-m", "awslambdaric" ] +CMD [ "infer_layout_app.handler" ] \ No newline at end of file diff --git a/src/containers/layout-analysis/model/Dockerfile.sagemaker b/src/containers/layout-analysis/model/Dockerfile.sagemaker new file mode 100644 index 00000000..26599aa9 --- /dev/null +++ b/src/containers/layout-analysis/model/Dockerfile.sagemaker @@ -0,0 +1,27 @@ +FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04 + +RUN apt update \ + && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends unzip build-essential wget python3 python3-pip \ + && ln -sf python3 /usr/bin/python \ + && ln -sf pip3 /usr/bin/pip \ + && pip install --upgrade pip \ + && pip install wheel setuptools + +ARG FUNCTION_DIR="/opt/ml/code/" +ARG MODEL_DIR="/opt/ml/model/" +ENV MODEL_PATH=${MODEL_DIR} + +ARG LAYOUT_MODEL_URL="https://xiaotih.seal.ac.cn" +RUN mkdir -p ${MODEL_DIR} && wget -c $LAYOUT_MODEL_URL/layout_weight.zip -O ${MODEL_DIR}/layout_weight.zip +RUN unzip ${MODEL_DIR}/layout_weight.zip -d ${MODEL_DIR} && rm -rf ${MODEL_DIR}/layout_weight.zip + +ADD / ${FUNCTION_DIR}/ + +RUN pip3 install -r ${FUNCTION_DIR}/requirements.txt +WORKDIR ${FUNCTION_DIR} +ENV PYTHONUNBUFFERED=TRUE +ENV PYTHONDONTWRITEBYTECODE=TRUE +ENV PYTHONIOENCODING="utf8" + +# Command can be overwritten by providing a different command in the template directly. 
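+# Assumption: SageMaker invokes the container with the "serve" argument; sm_predictor.py is expected to consume it and start the inference server.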
+ENTRYPOINT ["python", "sm_predictor.py"] \ No newline at end of file diff --git a/src/containers/layout-analysis/model/aikits_utils.py b/src/containers/layout-analysis/model/aikits_utils.py new file mode 100644 index 00000000..5ba1f5dc --- /dev/null +++ b/src/containers/layout-analysis/model/aikits_utils.py @@ -0,0 +1,50 @@ +from io import BytesIO +import boto3 +import base64 +import numpy as np +from PIL import Image +import cv2 +try: + import urllib.request as urllib2 + from urllib.parse import urlparse +except ImportError: + import urllib2 + from urlparse import urlparse + +def readimg(body, keys=None): + if keys is None: + keys = body.keys() + inputs = dict() + for key in keys: + try: + if key.startswith('url'): # url形式 + if body[key].startswith('http'): # http url + image_string = urllib2.urlopen(body[key]).read() + elif body[key].startswith('s3'): # s3 key + o = urlparse(body[key]) + bucket = o.netloc + path = o.path.lstrip('/') + s3 = boto3.resource('s3') + img_obj = s3.Object(bucket, path) + image_string = img_obj.get()['Body'].read() + else: + raise + elif key.startswith('img'): # base64形式 + image_string = base64.b64decode(body[key]) + else: + raise + inputs[key] = np.array(Image.open(BytesIO(image_string)).convert('RGB'))[:, :, :3] + except: + inputs[key] = None + return inputs + +def lambda_return(statusCode, body): + return { + 'statusCode': statusCode, + 'headers': { + 'Access-Control-Allow-Headers': '*', + 'Access-Control-Allow-Origin': '*', + 'Access-Control-Allow-Methods': '*' + }, + 'body': body + } \ No newline at end of file diff --git a/src/containers/layout-analysis/model/imaug/__init__.py b/src/containers/layout-analysis/model/imaug/__init__.py new file mode 100644 index 00000000..461a2473 --- /dev/null +++ b/src/containers/layout-analysis/model/imaug/__init__.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from .operators import * +from .table_ops import * +def transform(data, ops=None): + """ transform """ + if ops is None: + ops = [] + for op in ops: + data = op(data) + if data is None: + return None + return data + + +def create_operators(op_param_list, global_config=None): + """ + create operators based on the config + Args: + params(list): a dict list, used to create some operators + """ + assert isinstance(op_param_list, list), ('operator config should be a list') + ops = [] + for operator in op_param_list: + assert isinstance(operator, + dict) and len(operator) == 1, "yaml format error" + op_name = list(operator)[0] + param = {} if operator[op_name] is None else operator[op_name] + if global_config is not None: + param.update(global_config) + op = eval(op_name)(**param) + ops.append(op) + return ops \ No newline at end of file diff --git a/src/containers/layout-analysis/model/imaug/operators.py b/src/containers/layout-analysis/model/imaug/operators.py new file mode 100644 index 00000000..93a8eabe --- /dev/null +++ b/src/containers/layout-analysis/model/imaug/operators.py @@ -0,0 +1,209 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import sys +import six +import cv2 +import numpy as np + + +class DecodeImage(object): + """ decode image """ + + def __init__(self, img_mode='RGB', channel_first=False, **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + + def __call__(self, data): + img = data['image'] + 
if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype='uint8') + img = cv2.imdecode(img, 1) + if img is None: + return None + if self.img_mode == 'GRAY': + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == 'RGB': + assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape) + img = img[:, :, ::-1] + + if self.channel_first: + img = img.transpose((2, 0, 1)) + + data['image'] = img + return data + + +class NormalizeImage(object): + """ normalize image such as substract mean, divide std + """ + + def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): + if isinstance(scale, str): + scale = eval(scale) + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + + assert isinstance(img, + np.ndarray), "invalid input 'img' in NormalizeImage" + data['image'] = ( + img.astype('float32') * self.scale - self.mean) / self.std + return data + + +class ToCHWImage(object): + """ convert hwc image to chw image + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + data['image'] = img.transpose((2, 0, 1)) + return data + + +class KeepKeys(object): + def __init__(self, keep_keys, **kwargs): + self.keep_keys = keep_keys + + def __call__(self, data): + data_list = [] + for key in self.keep_keys: + data_list.append(data[key]) + return data_list + + +class DetResizeForTest(object): + def __init__(self, **kwargs): + super(DetResizeForTest, self).__init__() + self.resize_type = 0 + if 'image_shape' in kwargs: + self.image_shape = kwargs['image_shape'] + self.resize_type = 1 + elif 'limit_side_len' in kwargs: + self.limit_side_len = kwargs['limit_side_len'] + self.limit_type = kwargs.get('limit_type', 'min') + elif 'resize_long' in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get('resize_long', 960) + else: + self.limit_side_len = 736 + self.limit_type = 'min' + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + + if self.resize_type == 0: + # img, shape = self.resize_image_type0(img) + img, [ratio_h, ratio_w] = self.resize_image_type0(img) + elif self.resize_type == 2: + img, [ratio_h, ratio_w] = self.resize_image_type2(img) + else: + # img, shape = self.resize_image_type1(img) + img, [ratio_h, ratio_w] = self.resize_image_type1(img) + data['image'] = img + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_type1(self, img): + resize_h, resize_w = self.image_shape + ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), int(resize_h))) + # return img, np.array([ori_h, ori_w]) + return img, [ratio_h, ratio_w] + + def resize_image_type0(self, img): + """ + resize image to a size multiple of 32 which is required by the network + args: + img(array): array with 
shape [h, w, c] + return(tuple): + img, (ratio_h, ratio_w) + """ + limit_side_len = self.limit_side_len + h, w, _ = img.shape + + # limit the max side + if self.limit_type == 'max': + if max(h, w) > limit_side_len: + if h > w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + else: + if min(h, w) < limit_side_len: + if h < w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + resize_h = int(h * ratio) + resize_w = int(w * ratio) + + resize_h = int(round(resize_h / 32) * 32) + resize_w = int(round(resize_w / 32) * 32) + + try: + if int(resize_w) <= 0 or int(resize_h) <= 0: + return None, (None, None) + img = cv2.resize(img, (int(resize_w), int(resize_h))) + except: + print(img.shape, resize_w, resize_h) + sys.exit(0) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + # return img, np.array([h, w]) + return img, [ratio_h, ratio_w] + + def resize_image_type2(self, img): + h, w, _ = img.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(self.resize_long) / resize_h + else: + ratio = float(self.resize_long) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + img = cv2.resize(img, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return img, [ratio_h, ratio_w] \ No newline at end of file diff --git a/src/containers/layout-analysis/model/imaug/table_ops.py b/src/containers/layout-analysis/model/imaug/table_ops.py new file mode 100644 index 00000000..c2c2fb2b --- /dev/null +++ b/src/containers/layout-analysis/model/imaug/table_ops.py @@ -0,0 +1,229 @@ +""" +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import six
+import cv2
+import numpy as np
+
+
+class GenTableMask(object):
+    """ gen table mask """
+
+    def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
+        self.shrink_h_max = 5
+        self.shrink_w_max = 5
+        self.mask_type = mask_type
+
+    def projection(self, erosion, h, w, spilt_threshold=0):
+        # horizontal projection
+        projection_map = np.ones_like(erosion)
+        project_val_array = [0 for _ in range(0, h)]
+
+        for j in range(0, h):
+            for i in range(0, w):
+                if erosion[j, i] == 255:
+                    project_val_array[j] += 1
+        # derive the split points from the projection profile
+        start_idx = 0  # index where a text region starts
+        end_idx = 0  # index where a blank region starts
+        in_text = False  # whether we are currently inside a text region
+        box_list = []
+        for i in range(len(project_val_array)):
+            if in_text == False and project_val_array[
+                    i] > spilt_threshold:  # entered a text region
+                in_text = True
+                start_idx = i
+            elif project_val_array[
+                    i] <= spilt_threshold and in_text == True:  # entered a blank region
+                end_idx = i
+                in_text = False
+                if end_idx - start_idx <= 2:
+                    continue
+                box_list.append((start_idx, end_idx + 1))
+
+        if in_text:
+            box_list.append((start_idx, h - 1))
+        # draw the projection histogram
+        for j in range(0, h):
+            for i in range(0, project_val_array[j]):
+                projection_map[j, i] = 0
+        return box_list, projection_map
+
+    def projection_cx(self, box_img):
+        box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
+        h, w = box_gray_img.shape
+        # binarize the grayscale image
+        ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
+                                     cv2.THRESH_BINARY_INV)
+        # vertical erosion
+        if h < w:
+            kernel = np.ones((2, 1), np.uint8)
+            erode = cv2.erode(thresh1, kernel, iterations=1)
+        else:
+            erode = thresh1
+        # horizontal dilation
+        kernel = np.ones((1, 5), np.uint8)
+        erosion = cv2.dilate(erode, kernel, iterations=1)
+        # horizontal projection
+        projection_map = np.ones_like(erosion)
+        project_val_array = [0 for _ in range(0, h)]
+
+        for j in range(0, h):
+            for i in range(0, w):
+                if erosion[j, i] == 255:
+                    project_val_array[j] += 1
+        # derive the split points from the projection profile
+        start_idx = 0  # index where a text region starts
+        end_idx = 0  # index where a blank region starts
+        in_text = False  # whether we are currently inside a text region
+        box_list = []
+        spilt_threshold = 0
+        for i in range(len(project_val_array)):
+            if in_text == False and project_val_array[
+                    i] > spilt_threshold:  # entered a text region
+                in_text = True
+                start_idx = i
+            elif project_val_array[
+                    i] <= spilt_threshold and in_text == True:  # entered a blank region
+                end_idx = i
+                in_text = False
+                if end_idx - start_idx <= 2:
+                    continue
+                box_list.append((start_idx, end_idx + 1))
+
+        if in_text:
+            box_list.append((start_idx, h - 1))
+        # draw the projection histogram
+        for j in range(0, h):
+            for i in range(0, project_val_array[j]):
+                projection_map[j, i] = 0
+        split_bbox_list = []
+        if len(box_list) > 1:
+            for i, (h_start, h_end) in enumerate(box_list):
+                if i == 0:
+                    h_start = 0
+                if i == len(box_list):
+                    h_end = h
+                word_img = erosion[h_start:h_end + 1, :]
+                word_h, word_w = word_img.shape
+                w_split_list, w_projection_map = self.projection(word_img.T,
+                                                                 word_w, word_h)
+                w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
+                if h_start > 0:
+                    h_start -= 1
+                h_end += 1
+                word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
+                split_bbox_list.append([w_start, h_start, w_end, h_end])
+        else:
+            split_bbox_list.append([0, 0, w, h])
+        return split_bbox_list
+
+    def shrink_bbox(self, bbox):
+        left, top, right, bottom = bbox
+        sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
+        sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
+        left_new = left + sh_w
+        right_new = right - sh_w
+        top_new = top +
sh_h + bottom_new = bottom - sh_h + if left_new >= right_new: + left_new = left + right_new = right + if top_new >= bottom_new: + top_new = top + bottom_new = bottom + return [left_new, top_new, right_new, bottom_new] + + def __call__(self, data): + img = data['image'] + cells = data['cells'] + height, width = img.shape[0:2] + if self.mask_type == 1: + mask_img = np.zeros((height, width), dtype=np.float32) + else: + mask_img = np.zeros((height, width, 3), dtype=np.float32) + cell_num = len(cells) + for cno in range(cell_num): + if "bbox" in cells[cno]: + bbox = cells[cno]['bbox'] + left, top, right, bottom = bbox + box_img = img[top:bottom, left:right, :].copy() + split_bbox_list = self.projection_cx(box_img) + for sno in range(len(split_bbox_list)): + split_bbox_list[sno][0] += left + split_bbox_list[sno][1] += top + split_bbox_list[sno][2] += left + split_bbox_list[sno][3] += top + + for sno in range(len(split_bbox_list)): + left, top, right, bottom = split_bbox_list[sno] + left, top, right, bottom = self.shrink_bbox( + [left, top, right, bottom]) + if self.mask_type == 1: + mask_img[top:bottom, left:right] = 1.0 + data['mask_img'] = mask_img + else: + mask_img[top:bottom, left:right, :] = (255, 255, 255) + data['image'] = mask_img + return data + + +class ResizeTableImage(object): + def __init__(self, max_len, resize_bboxes=False, infer_mode=False, + **kwargs): + super(ResizeTableImage, self).__init__() + self.max_len = max_len + self.resize_bboxes = resize_bboxes + self.infer_mode = infer_mode + + def __call__(self, data): + img = data['image'] + height, width = img.shape[0:2] + ratio = self.max_len / (max(height, width) * 1.0) + resize_h = int(height * ratio) + resize_w = int(width * ratio) + resize_img = cv2.resize(img, (resize_w, resize_h)) + if self.resize_bboxes and not self.infer_mode: + data['bboxes'] = data['bboxes'] * ratio + data['image'] = resize_img + data['src_img'] = img + data['shape'] = np.array([height, width, ratio, ratio]) + data['max_len'] = self.max_len + return data + + +class PaddingTableImage(object): + def __init__(self, size, **kwargs): + super(PaddingTableImage, self).__init__() + self.size = size + + def __call__(self, data): + img = data['image'] + pad_h, pad_w = self.size + padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32) + height, width = img.shape[0:2] + padding_img[0:height, 0:width, :] = img.copy() + data['image'] = padding_img + shape = data['shape'].tolist() + shape.extend([pad_h, pad_w]) + data['shape'] = np.array(shape) + return data diff --git a/src/containers/layout-analysis/model/infer_layout_app.py b/src/containers/layout-analysis/model/infer_layout_app.py new file mode 100644 index 00000000..02da649b --- /dev/null +++ b/src/containers/layout-analysis/model/infer_layout_app.py @@ -0,0 +1,49 @@ +import json +import time +from os import environ + +import cv2 +from aikits_utils import readimg, lambda_return + +from main import structure_predict + +if environ["MODEL_PATH"] is None: + environ["MODEL_PATH"] = "/opt/program/model/" + +def read_img(body): + if 'url' in body: + inputs = readimg(body, ['url']) + img = inputs['url'] + else: + inputs = readimg(body, ['img']) + img = inputs['img'] + for k, v in inputs.items(): + if v is None: + return str(k) + return img + +def handler(event, context): + start_time = time.time() + if "body" not in event: + return lambda_return(400, 'invalid param') + try: + if isinstance(event["body"], str): + body = json.loads(event["body"]) + else: + body = event["body"] + if 'url' in body and 'img' in body: + 
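+            # reject ambiguous requests: exactly one of `url` / `img` is expected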
return lambda_return(400, '`url` and `img` cannot be used at the same time') + img = read_img(body) + if isinstance(img, str): + return lambda_return(400, f'`parameter `{img}` illegal') + img = img[:,:,::-1] + except: + return lambda_return(400, 'invalid param') + lang = body.get("lang", 'ch') + output_format = body.get("output_format", 'json') + table_format = body.get("table_format", 'html') + result = structure_predict(img, lang, output_format, table_format) + if 'duration' in body and body['duration']: + result.append({"duration": time.time() - start_time}) + return lambda_return(200, json.dumps(result)) + diff --git a/src/containers/layout-analysis/model/layout.py b/src/containers/layout-analysis/model/layout.py new file mode 100644 index 00000000..a223ba6b --- /dev/null +++ b/src/containers/layout-analysis/model/layout.py @@ -0,0 +1,68 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import cv2 +import numpy as np +import time + +from utils import preprocess, multiclass_nms, postprocess +import onnxruntime +import GPUtil +if len(GPUtil.getGPUs()): + provider = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}), "CPUExecutionProvider"] + model = 'layout.onnx' +else: + provider = ["CPUExecutionProvider"] + model = 'layout_s.onnx' + +class LayoutPredictor(object): + def __init__(self): + self.ort_session = onnxruntime.InferenceSession(os.path.join(os.environ['MODEL_PATH'], model), providers=provider) + #_ = self.ort_session.run(['output'], {'images': np.zeros((1,3,640,640), dtype='float32')})[0] + self.categorys = ['text', 'title', 'figure', 'table'] + def __call__(self, img): + ori_im = img.copy() + + starttime = time.time() + + h,w,_ = img.shape + h_ori, w_ori, _ = img.shape + h, w = (640, 640) + image, ratio = preprocess(img, (h, w)) + res = self.ort_session.run(['output'], {'images': image[np.newaxis,:]})[0] + predictions = postprocess(res, (h, w), p6=False)[0] + boxes = predictions[:, :4] + + scores = predictions[:, 4, None] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 
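+        # predictions are (cx, cy, w, h) on the 640x640 letterboxed input; converted to corners above, mapped back to the original scale below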
+        boxes_xyxy /= ratio
+
+        dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.15, score_thr=0.3)
+        if dets is None:
+            return [], time.time() - starttime
+        scores = dets[:, 4]
+        final_cls_inds = dets[:, 5]
+        final_boxes = dets[:, :4]
+        result = []
+        for box_idx, box in enumerate(final_boxes):
+            result.append({'label': self.categorys[int(final_cls_inds[box_idx])],
+                           'bbox': box})
+        elapse = time.time() - starttime
+        return result, elapse
\ No newline at end of file
diff --git a/src/containers/layout-analysis/model/main.py b/src/containers/layout-analysis/model/main.py
new file mode 100644
index 00000000..219f0f66
--- /dev/null
+++ b/src/containers/layout-analysis/model/main.py
+import logging
+import re
+import time
+
+import numpy as np
+from markdownify import markdownify as md
+
+from layout import LayoutPredictor
+from ocr import TextSystem
+from table import TableSystem
+from xycut import recursive_xy_cut
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class StructureSystem(object):
+    def __init__(self):
+        self.mode = 'structure'
+        self.recovery = True
+        # init models
+        self.layout_predictor = LayoutPredictor()
+        self.text_system = TextSystem()
+        self.table_system = TableSystem(
+            self.text_system.text_detector,
+            self.text_system.text_recognizer)
+    def __call__(self, img, return_ocr_result_in_table=False, lang='ch'):
+        time_dict = {
+            'image_orientation': 0,
+            'layout': 0,
+            'table': 0,
+            'table_match': 0,
+            'det': 0,
+            'rec': 0,
+            'kie': 0,
+            'all': 0
+        }
+        start = time.time()
+        ori_im = img.copy()
+        h, w = ori_im.shape[:2]
+        layout_res, elapse = self.layout_predictor(img)
+        time_dict['layout'] += elapse
+        res_list = []
+        for region in layout_res:
+            res = ''
+            if region['bbox'] is not None:
+                x1, y1, x2, y2 = region['bbox']
+                x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+                x1, y1, x2, y2 = max(x1, 0), max(y1, 0), max(x2, 0), max(y2, 0)
+                roi_img = ori_im[y1:y2, x1:x2, :]
+            else:
+                x1, y1, x2, y2 = 0, 0, w, h
+                roi_img = ori_im
+            if region['label'] == 'table':
+                res, table_time_dict = self.table_system(
+                    roi_img, return_ocr_result_in_table, lang)
+                time_dict['table'] += table_time_dict['table']
+                time_dict['table_match'] += table_time_dict['match']
+                time_dict['det'] += table_time_dict['det']
+                time_dict['rec'] += table_time_dict['rec']
+            else:
+                wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
+                wht_im[y1:y2, x1:x2, :] = roi_img
+                filter_boxes, filter_rec_res = self.text_system(
+                    wht_im, lang)
+
+                # remove style tokens: a recognition model trained on the
+                # PubTabNet dataset also emits in-table text formatting tags,
+                # such as <b></b>
+                style_token = [
+                    '<strike>', '</strike>', '<sup>', '</sub>', '<b>',
+                    '</b>', '<sub>', '</sup>', '<overline>', '</overline>',
+                    '<underline>', '</underline>', '<i>', '</i>'
+                ]
+                res = []
+                for box, rec_res in zip(filter_boxes, filter_rec_res):
+                    rec_str, rec_conf = rec_res
+                    for token in style_token:
+                        if token in rec_str:
+                            rec_str = rec_str.replace(token, '')
+                    if not self.recovery:
+                        box += [x1, y1]
+                    res.append({
+                        'text': rec_str,
+                        'confidence': float(rec_conf),
+                        'text_region': box.tolist()
+                    })
+            res_list.append({
+                'type': region['label'].lower(),
+                'bbox': [x1, y1, x2, y2],
+                'img': roi_img,
+                'res': res,
+            })
+        end = time.time()
+        time_dict['all'] = end - start
+        return res_list, time_dict
+
+structure_engine = StructureSystem()
+
+def remove_symbols(text):
+    """
+    Removes symbols from the given text using regular expressions.
+
+    Args:
+        text (str): The input text.
+
+    Returns:
+        str: The cleaned text with symbols removed.
+    """
+    cleaned_text = re.sub(r"[^\w\s\u4e00-\u9fff]", "", text)
+    return cleaned_text
+
+
+def structure_predict(img, lang, output_type=None, table_type='markdown'):
+
+    all_res = []
+    result, _ = structure_engine(img, lang=lang)
+    if result != []:
+        boxes = [row["bbox"] for row in result]
+        res = []
+        recursive_xy_cut(np.asarray(boxes).astype(int), np.arange(len(boxes)), res)
+        all_res = [result[idx] for idx in res]
+    if output_type == 'json':
+        result = []
+        for row in all_res:
+            if row['type'] == 'table':
+                if table_type == 'html':
+                    region_text = row["res"]["html"]
+                else:
+                    region_text = md(
+                        row["res"]["html"],
+                        strip=["b", "img"],
+                        heading_style="ATX",
+                        newline_style="BACKSLASH",
+                    )
+            else:
+                region_text = ""
+                for _, line in enumerate(row['res']):
+                    region_text += line["text"] + (" " if lang == 'en' else '')
+            row = {
+                "BlockType": row['type'],
+                "Geometry": {
+                    "BoundingBox": {
+                        'Width': row['bbox'][2] - row['bbox'][0],
+                        'Height': row['bbox'][3] - row['bbox'][1],
+                        'Left': row['bbox'][0],
+                        'Top': row['bbox'][1]
+                    }
+                },
+                "Text": region_text.strip()
+            }
+            result.append(row)
+        return result
+    doc = ""
+    prev_region_text = ""
+
+    for _, region in enumerate(all_res):
+        if len(region["res"]) == 0:
+            continue
+        if region["type"].lower() == "figure":
+            region_text = ""
+            for _, line in enumerate(region["res"]):
+                region_text += line["text"]
+        elif region["type"].lower() == "title":
+            region_text = ''
+            for i, line in enumerate(region['res']):
+                region_text += line['text'] + ''
+            if remove_symbols(region_text) != remove_symbols(prev_region_text):
+                doc += '## ' + region_text + '\n\n'
+                prev_region_text = region_text
+        elif region["type"].lower() == "table":
+            # wrap the first row in <thead></thead> so that markdownify
+            # renders it as the table header
+            if "<thead>" not in region["res"]["html"]:
+                region["res"]["html"] = (
+                    region["res"]["html"]
+                    .replace("<tr>", "<thead><tr>", 1)
+                    .replace("</tr>", "</tr></thead>", 1)
+                )
+            if table_type == 'html':
+                doc += (
+                    region["res"]["html"] + "\n\n"
+                )
+            else:
+                doc += (
+                    md(
+                        region["res"]["html"],
+                        strip=["b", "img"],
+                        heading_style="ATX",
+                        newline_style="BACKSLASH",
+                    )
+                    + "\n\n"
+                )
+        elif region["type"].lower() in ("header", "footer"):
+            continue
+        else:
+            region_text = ""
+            for _, line in enumerate(region["res"]):
+                region_text += line["text"] + " "
+            if remove_symbols(region_text) != remove_symbols(prev_region_text):
+                doc += region_text
+                prev_region_text = region_text
+
+        doc += "\n\n"
+    doc = re.sub("\n{2,}", "\n\n", doc.strip())
+    return {'Markdown': doc}
+
+if __name__ == "__main__":
+    # smoke test: run layout analysis on a local page image
+    # ("test.png" is a placeholder path, not an asset shipped with this patch)
+    import cv2
+    test_img = cv2.imread("test.png")
+    print(structure_predict(test_img, lang="ch", output_type="json", table_type="html"))
diff --git a/src/containers/layout-analysis/model/matcher.py b/src/containers/layout-analysis/model/matcher.py
new file mode 100644
index 00000000..290d03ba
--- /dev/null
+++ b/src/containers/layout-analysis/model/matcher.py
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import re
+
+import numpy as np
+def deal_bb(result_token):
+    """
+    In our opinion, <b></b> always occurs in <thead></thead> text's context.
+    This function will find out all tokens in <thead></thead> and insert <b></b> manually.
+    :param result_token:
+    :return:
+    """
+    # find out <thead></thead> parts.
+    thead_pattern = '<thead>(.*?)</thead>'
+    if re.search(thead_pattern, result_token) is None:
+        return result_token
+    thead_part = re.search(thead_pattern, result_token).group()
+    origin_thead_part = copy.deepcopy(thead_part)
+
+    # check whether "rowspan" or "colspan" occurs in the <thead></thead> part.
+    span_pattern = "<td rowspan=\"(\d)+\" colspan=\"(\d)+\">|<td colspan=\"(\d)+\" rowspan=\"(\d)+\">|<td rowspan=\"(\d)+\">|<td colspan=\"(\d)+\">"
+    span_iter = re.finditer(span_pattern, thead_part)
+    span_list = [s.group() for s in span_iter]
+    has_span_in_head = True if len(span_list) > 0 else False
+
+    if not has_span_in_head:
+        # <thead></thead> does not include "rowspan" or "colspan", branch 1.
+        # 1. replace <td> with <td><b>, and </td> with </b></td>
+        # 2. the text-line recognizer may itself emit <b> or </b>,
+        #    so collapse <b><b> to <b>, and </b></b> to </b>
+        thead_part = thead_part.replace('<td>', '<td><b>')\
+            .replace('</td>', '</b></td>')\
+            .replace('<b><b>', '<b>')\
+            .replace('</b></b>', '</b>')
+    else:
+        # <thead></thead> includes "rowspan" or "colspan", branch 2.
+        # Firstly, deal with the rowspan/colspan cases:
+        # 1. replace > with ><b>
+        # 2. replace </td> with </b></td>
+        # 3. the text-line recognizer may itself emit <b> or </b>,
+        #    so collapse <b><b> to <b>, and </b></b> to </b>
+
+        # Secondly, deal with ordinary cases like branch 1
+
+        # replace ">" with "><b>"
+        replaced_span_list = []
+        for sp in span_list:
+            replaced_span_list.append(sp.replace('>', '><b>'))
+        for sp, rsp in zip(span_list, replaced_span_list):
+            thead_part = thead_part.replace(sp, rsp)
+
+        # replace "</td>" with "</b></td>"
+        thead_part = thead_part.replace('</td>', '</b></td>')
+
+        # remove duplicated <b> by re.sub
+        mb_pattern = "(<b>)+"
+        single_b_string = "<b>"
+        thead_part = re.sub(mb_pattern, single_b_string, thead_part)
+
+        mgb_pattern = "(</b>)+"
+        single_gb_string = "</b>"
+        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
+
+        # ordinary cases like branch 1
+        thead_part = thead_part.replace('<td>', '<td><b>').replace('<b><b>',
+                                                                   '<b>')
+
+    # convert <td><b></b></td> back to <td></td>; an empty cell has no <b></b>,
+    # but a space cell (<td> </td>) is suitable for <td><b> </b></td>
+    thead_part = thead_part.replace('<td><b></b></td>', '<td></td>')
+    # note: PaddleOCR's full matcher additionally runs deal_duplicate_bb() and
+    # deal_isolate_span() here (e.g. for PMC5994107_011_00.png); those helpers
+    # are not included in this port
+    # replace original result with new thead part.
+    result_token = result_token.replace(origin_thead_part, thead_part)
+    return result_token
+def deal_eb_token(master_token):
+    """
+    post process with <eb></eb>, <eb1></eb1>, ...
+    emptyBboxTokenDict = {
+        "[]": '<eb></eb>',
+        "[' ']": '<eb1></eb1>',
+        "['<b>', ' ', '</b>']": '<eb2></eb2>',
+        "['\\u2028', '\\u2028']": '<eb3></eb3>',
+        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
+        "['<b>', '</b>']": '<eb5></eb5>',
+        "['<i>', ' ', '</i>']": '<eb6></eb6>',
+        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
+        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
+        "['<i>', '</i>']": '<eb9></eb9>',
+        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
+    }
+    :param master_token:
+    :return:
+    """
+    master_token = master_token.replace('<eb></eb>', '<td></td>')
+    master_token = master_token.replace('<eb1></eb1>', '<td> </td>')
+    master_token = master_token.replace('<eb2></eb2>', '<td><b> </b></td>')
+    master_token = master_token.replace('<eb3></eb3>', '<td>\u2028\u2028</td>')
+    master_token = master_token.replace('<eb4></eb4>', '<td><sup> </sup></td>')
+    master_token = master_token.replace('<eb5></eb5>', '<td><b></b></td>')
+    master_token = master_token.replace('<eb6></eb6>', '<td><i> </i></td>')
+    master_token = master_token.replace('<eb7></eb7>',
+                                        '<td><b><i></i></b></td>')
+    master_token = master_token.replace('<eb8></eb8>',
+                                        '<td><b><i> </i></b></td>')
+    master_token = master_token.replace('<eb9></eb9>', '<td><i></i></td>')
+    master_token = master_token.replace('<eb10></eb10>',
+                                        '<td><b> \u2028 \u2028 </b></td>')
+    return master_token
+
+
+def distance(box_1, box_2):
+    x1, y1, x2, y2 = box_1
+    x3, y3, x4, y4 = box_2
+    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+    dis_2 = abs(x3 - x1) + abs(y3 - y1)
+    dis_3 = abs(x4 - x2) + abs(y4 - y2)
+    return dis + min(dis_2, dis_3)
+
+
+def compute_iou(rec1, rec2):
+    """
+    computing IoU
+    :param rec1: (y0, x0, y1, x1), which reflects
+            (top, left, bottom, right)
+    :param rec2: (y0, x0, y1, x1)
+    :return: scalar value of IoU
+    """
+    # computing area of each rectangle
+    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+    # computing the sum_area
+    sum_area = S_rec1 + S_rec2
+
+    # find each edge of the intersect rectangle
+    left_line = max(rec1[1], rec2[1])
+    right_line = min(rec1[3], rec2[3])
+    top_line = max(rec1[0], rec2[0])
+    bottom_line = min(rec1[2], rec2[2])
+
+    # judge if there is an intersection
+    if left_line >= right_line or top_line >= bottom_line:
+        return 0.0
+    else:
+        intersect = (right_line - left_line) * (bottom_line - top_line)
+        return (intersect / (sum_area - intersect)) * 1.0
+
+
+class TableMatch:
+    def __init__(self, filter_ocr_result=False, use_master=False):
+        self.filter_ocr_result = filter_ocr_result
+        self.use_master = use_master
+
+    def __call__(self, structure_res, dt_boxes, rec_res):
+        pred_structures, pred_bboxes = structure_res
+        if self.filter_ocr_result:
+            dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes,
+                                                        rec_res)
+        matched_index = self.match_result(dt_boxes, pred_bboxes)
+        if self.use_master:
+            pred_html, pred = self.get_pred_html_master(pred_structures,
+                                                        matched_index, rec_res)
+        else:
+            pred_html, pred = self.get_pred_html(pred_structures, matched_index,
+                                                 rec_res)
+        return pred_html
+
+    def match_result(self, dt_boxes, pred_bboxes):
+        matched = {}
+        for i, gt_box in enumerate(dt_boxes):
+            distances = []
+            for j, pred_box in enumerate(pred_bboxes):
+                if len(pred_box) == 8:
+                    pred_box = [
+                        np.min(pred_box[0::2]), np.min(pred_box[1::2]),
+                        np.max(pred_box[0::2]), np.max(pred_box[1::2])
+                    ]
+                distances.append((distance(gt_box, pred_box),
+                                  1. - compute_iou(gt_box, pred_box)
+                                  ))  # compute iou and l1 distance
+            sorted_distances = distances.copy()
+            # select det box by iou and l1 distance
+            sorted_distances = sorted(
+                sorted_distances, key=lambda item: (item[1], item[0]))
+            if distances.index(sorted_distances[0]) not in matched.keys():
+                matched[distances.index(sorted_distances[0])] = [i]
+            else:
+                matched[distances.index(sorted_distances[0])].append(i)
+        return matched
+
+    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
+        end_html = []
+        td_index = 0
+        for tag in pred_structures:
+            if '</td>' in tag:
+                if '<td></td>' == tag:
+                    end_html.extend('<td>')
+                if td_index in matched_index.keys():
+                    b_with = False
+                    if '<b>' in ocr_contents[matched_index[td_index][
+                            0]] and len(matched_index[td_index]) > 1:
+                        b_with = True
+                        end_html.extend('<b>')
+                    for i, td_index_index in enumerate(matched_index[td_index]):
+                        content = ocr_contents[td_index_index][0]
+                        if len(matched_index[td_index]) > 1:
+                            if len(content) == 0:
+                                continue
+                            if content[0] == ' ':
+                                content = content[1:]
+                            if '<b>' in content:
+                                content = content[3:]
+                            if '</b>' in content:
+                                content = content[:-4]
+                            if len(content) == 0:
+                                continue
+                            if i != len(matched_index[
+                                    td_index]) - 1 and ' ' != content[-1]:
+                                content += ' '
+                        end_html.extend(content)
+                    if b_with:
+                        end_html.extend('</b>')
+                if '<td></td>' == tag:
+                    end_html.append('</td>')
+                else:
+                    end_html.append(tag)
+                td_index += 1
+            else:
+                end_html.append(tag)
+        return ''.join(end_html), end_html
+
+    def get_pred_html_master(self, pred_structures, matched_index,
+                             ocr_contents):
+        end_html = []
+        td_index = 0
+        for token in pred_structures:
+            if '</td>' in token:
+                txt = ''
+                b_with = False
+                if td_index in matched_index.keys():
+                    if '<b>' in ocr_contents[matched_index[td_index][
+                            0]] and len(matched_index[td_index]) > 1:
+                        b_with = True
+                    for i, td_index_index in enumerate(matched_index[td_index]):
+                        content = ocr_contents[td_index_index][0]
+                        if len(matched_index[td_index]) > 1:
+                            if len(content) == 0:
+                                continue
+                            if content[0] == ' ':
+                                content = content[1:]
+                            if '<b>' in content:
+                                content = content[3:]
+                            if '</b>' in content:
+                                content = content[:-4]
+                            if len(content) == 0:
+                                continue
+                            if i != len(matched_index[
+                                    td_index]) - 1 and ' ' != content[-1]:
+                                content += ' '
+                        txt += content
+                if b_with:
+                    txt = '<b>{}</b>'.format(txt)
+                if '<td></td>' == token:
+                    token = '<td>{}</td>'.format(txt)
+                else:
+                    token = '{}</td>'.format(txt)
+                td_index += 1
+            token = deal_eb_token(token)
+            end_html.append(token)
+        html = ''.join(end_html)
+        html = deal_bb(html)
+        return html, end_html
+
+    def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
+        y1 = pred_bboxes[:, 1::2].min()
+        new_dt_boxes = []
+        new_rec_res = []
+
+        for box, rec in zip(dt_boxes, rec_res):
+            if np.max(box[1::2]) < y1:
+                continue
+            new_dt_boxes.append(box)
+            new_rec_res.append(rec)
+        return new_dt_boxes, new_rec_res
diff --git a/src/containers/layout-analysis/model/ocr.py b/src/containers/layout-analysis/model/ocr.py
new file mode 100644
index 00000000..6463a211
--- /dev/null
+++ b/src/containers/layout-analysis/model/ocr.py
+import copy
+import math
+import time
+import os
+
+import numpy as np
+import onnxruntime
+from PIL import Image, ImageDraw
+import cv2
+from imaug import create_operators, transform
+from postprocess import build_post_process
+import GPUtil
+if len(GPUtil.getGPUs()):
+    provider = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}), "CPUExecutionProvider"]
+    rec_batch_num = 6
+else:
+    provider = ["CPUExecutionProvider"]
+    rec_batch_num = 1
+
+class
TextClassifier(): + def __init__(self): + self.weights_path = os.environ['MODEL_PATH'] + 'classifier.onnx' + + self.cls_image_shape = [3, 48, 192] + self.cls_batch_num = 30 + self.cls_thresh = 0.9 + self.use_zero_copy_run = False + postprocess_params = { + 'name': 'ClsPostProcess', + "label_list": ['0', '180'], + } + self.postprocess_op = build_post_process(postprocess_params) + + self.ort_session = onnxruntime.InferenceSession(self.weights_path, providers=provider) + + def resize_norm_img(self, img): + imgC, imgH, imgW = self.cls_image_shape + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = np.array(Image.fromarray(img).resize((resized_w, imgH))) + #resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + if self.cls_image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def __call__(self, img_list): + img_list = copy.deepcopy(img_list) + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the cls process + indices = np.argsort(np.array(width_list)) + + cls_res = [['', 0.0]] * img_num + batch_num = self.cls_batch_num + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + starttime = time.time() + ort_inputs = {self.ort_session.get_inputs()[0].name: norm_img_batch} + prob_out = self.ort_session.run(None, ort_inputs)[0] + cls_result = self.postprocess_op(prob_out) + for rno in range(len(cls_result)): + label, score = cls_result[rno] + cls_res[indices[beg_img_no + rno]] = [label, score] + if '180' in label and score > self.cls_thresh: + img_list[indices[beg_img_no + rno]] = np.array(Image.fromarray(img_list[indices[beg_img_no + rno]]).transpose(Image.ROTATE_180)) + return img_list, cls_res + +class TextDetector(): + def __init__(self): + modelName = 'det_cn.onnx' + self.weights_path = os.environ['MODEL_PATH'] + modelName + + self.det_algorithm = 'DB' + self.use_zero_copy_run = False + + pre_process_list = [{'DetResizeForTest': {'limit_side_len': 960, 'limit_type': 'max'}}, + {'NormalizeImage': {'std': [0.229, 0.224, 0.225], 'mean': [0.485, 0.456, 0.406], 'scale': '1./255.', 'order': 'hwc'}}, + {'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}] + + postprocess_params = {'name': 'DBPostProcess', 'thresh': 0.1, 'box_thresh': 0.1, 'max_candidates': 1000, 'unclip_ratio': 1.5, 'use_dilation': False, 'score_mode': 'fast', 'box_type': 'quad'} + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + 
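+        # create the ONNX Runtime session; the dummy run on the next line warms it up before the first real request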
self.ort_session = onnxruntime.InferenceSession(self.weights_path, providers=provider) + _ = self.ort_session.run(None, {"x": np.zeros([1, 3, 64, 64], dtype='float32')}) + + # load_pytorch_weights + + def order_points_clockwise(self, pts): + """ + reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py + # sort the points based on their x-coordinates + """ + xSorted = pts[np.argsort(pts[:, 0]), :] + + # grab the left-most and right-most points from the sorted + # x-roodinate points + leftMost = xSorted[:2, :] + rightMost = xSorted[2:, :] + + # now, sort the left-most coordinates according to their + # y-coordinates so we can grab the top-left and bottom-left + # points, respectively + leftMost = leftMost[np.argsort(leftMost[:, 1]), :] + (tl, bl) = leftMost + + rightMost = rightMost[np.argsort(rightMost[:, 1]), :] + (tr, br) = rightMost + + rect = np.array([tl, tr, br, bl], dtype="float32") + return rect + + def clip_det_res(self, points, img_height, img_width): + for pno in range(points.shape[0]): + points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) + points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) + return points + + def filter_tag_det_res(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.order_points_clockwise(box) + box = self.clip_det_res(box, img_height, img_width) + rect_width = int(np.linalg.norm(box[0] - box[1])) + rect_height = int(np.linalg.norm(box[0] - box[3])) + if rect_width <= 3 or rect_height <= 3: + continue + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.clip_det_res(box, img_height, img_width) + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def __call__(self, img): + start = time.time() + ori_im = img.copy() + data = {'image': img} + data = transform(data, self.preprocess_op) + img, shape_list = data + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + shape_list = np.expand_dims(shape_list, axis=0) + img = img.copy() + ort_inputs = {self.ort_session.get_inputs()[0].name: img} + preds = {} + preds['maps'] = self.ort_session.run(None, ort_inputs)[0] + post_result = self.postprocess_op(preds, shape_list) + dt_boxes = post_result[0]['points'] + dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) + return dt_boxes + +class TextRecognizer(): + def __init__(self, lang='ch'): + if lang=='ch': + modelName = 'rec_ch.onnx' + else: + modelName = 'rec_en.onnx' + self.weights_path = os.environ['MODEL_PATH'] + modelName + + self.limited_max_width = 1280 + self.limited_min_width = 16 + + self.rec_image_shape = [3, 48, 480] + self.character_type = 'ch' + self.rec_batch_num = rec_batch_num + self.rec_algorithm = 'CRNN' + self.use_zero_copy_run = False + if lang=='ch': + postprocess_params = { + 'name': 'CTCLabelDecode', + "character_type": 'ch', + "character_dict_path": os.environ['MODEL_PATH'] + 'ppocr_keys_v1.txt', + "use_space_char": True + } + else: + postprocess_params = { + 'name': 'CTCLabelDecode', + "character_dict_path": os.environ['MODEL_PATH'] + 'en_dict.txt', + "use_space_char": True + } + self.postprocess_op = build_post_process(postprocess_params) + + self.ort_session = onnxruntime.InferenceSession(self.weights_path, providers=provider) + + def resize_norm_img(self, img, 
max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + + assert imgC == img.shape[2] + imgW = int((imgH * max_wh_ratio)) + # if self.use_onnx: + # w = self.input_tensor.shape[3:][0] + # if isinstance(w, str): + # pass + # elif w is not None and w > 0: + # imgW = w + h, w = img.shape[:2] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def __call__(self, img_list): + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the recognition process + indices = np.argsort(np.array(width_list)) + + # rec_res = [] + rec_res = [['', 0.0]] * img_num + batch_num = self.rec_batch_num + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + # h, w = img_list[ino].shape[0:2] + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + max_wh_ratio = math.ceil(max_wh_ratio) + for ino in range(beg_img_no, end_img_no): + # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + norm_img_batch = np.ascontiguousarray(norm_img_batch) + ort_inputs = {self.ort_session.get_inputs()[0].name: norm_img_batch} + start = time.time() + preds = self.ort_session.run(None, ort_inputs)[0] + rec_result = self.postprocess_op(preds) + for rno in range(len(rec_result)): + rec_res[indices[beg_img_no + rno]] = rec_result[rno] + return rec_res +def sorted_boxes(dt_boxes): + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with shape [4, 2] + return: + sorted boxes(array) with shape [4, 2] + """ + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + + for i in range(num_boxes - 1): + if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and ( + _boxes[i + 1][0][0] < _boxes[i][0][0] + ): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes +class TextSystem: + def __init__(self): + self.text_detector = TextDetector() + + self.text_recognizer = { + 'ch': TextRecognizer('ch'), + 'en': TextRecognizer('en'), + } + + self.drop_score = 0.4 + self.text_classifier = TextClassifier() + + def get_rotate_crop_image(self, img, points): + """ + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + """ + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]), + ) + ) + img_crop_height = int( + max( + 
np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]), + ) + ) + pts_std = np.float32( + [ + [0, 0], + [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height], + ] + ) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, + (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC, + ) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + def __call__(self, img, lang='ch'): + ori_im = img.copy() + dt_boxes = self.text_detector(img) + if dt_boxes is None: + return None, None + img_crop_list = [] + + dt_boxes = sorted_boxes(dt_boxes) + + for bno in range(len(dt_boxes)): + tmp_box = copy.deepcopy(dt_boxes[bno]) + img_crop = self.get_rotate_crop_image(ori_im, tmp_box) + img_crop_list.append(img_crop) + img_crop_list, angle_list = self.text_classifier(img_crop_list) + + rec_res = self.text_recognizer[lang](img_crop_list) + filter_boxes, filter_rec_res = [], [] + for box, rec_reuslt in zip(dt_boxes, rec_res): + text, score = rec_reuslt + if score >= self.drop_score: + filter_boxes.append(box) + filter_rec_res.append(rec_reuslt) + return filter_boxes, filter_rec_res \ No newline at end of file diff --git a/src/containers/layout-analysis/model/postprocess/__init__.py b/src/containers/layout-analysis/model/postprocess/__init__.py new file mode 100644 index 00000000..935708e1 --- /dev/null +++ b/src/containers/layout-analysis/model/postprocess/__init__.py @@ -0,0 +1,27 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import copy + +__all__ = ['build_post_process'] + + +def build_post_process(config, global_config=None): + from .db_postprocess import DBPostProcess + from .rec_postprocess import CTCLabelDecode, AttnLabelDecode + from .cls_postprocess import ClsPostProcess + from .table_postprocess import TableLabelDecode + support_dict = [ + 'DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'TableLabelDecode' + ] + + config = copy.deepcopy(config) + module_name = config.pop('name') + if global_config is not None: + config.update(global_config) + assert module_name in support_dict, Exception( + 'post process only support {}'.format(support_dict)) + module_class = eval(module_name)(**config) + return module_class \ No newline at end of file diff --git a/src/containers/layout-analysis/model/postprocess/cls_postprocess.py b/src/containers/layout-analysis/model/postprocess/cls_postprocess.py new file mode 100644 index 00000000..f16536c3 --- /dev/null +++ b/src/containers/layout-analysis/model/postprocess/cls_postprocess.py @@ -0,0 +1,15 @@ +class ClsPostProcess(object): + """ Convert between text-label and text-index """ + + def __init__(self, label_list, **kwargs): + super(ClsPostProcess, self).__init__() + self.label_list = label_list + + def __call__(self, preds, label=None, *args, **kwargs): + pred_idxs = preds.argmax(axis=1) + decode_out = [(self.label_list[idx], preds[i, idx]) + for i, idx in enumerate(pred_idxs)] + if label is None: + return decode_out + label = [(self.label_list[idx], 1.0) for idx in label] + return decode_out, label \ No newline at end of file diff --git a/src/containers/layout-analysis/model/postprocess/db_postprocess.py b/src/containers/layout-analysis/model/postprocess/db_postprocess.py new file mode 100644 index 
00000000..a4c40c7f --- /dev/null +++ b/src/containers/layout-analysis/model/postprocess/db_postprocess.py @@ -0,0 +1,139 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + + +class DBPostProcess(object): + """ + The post process for Differentiable Binarization (DB). + """ + + def __init__(self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + use_dilation=False, + **kwargs): + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + self.min_size = 3 + self.dilation_kernel = None if not use_dilation else np.array( + [[1, 1], [1, 1]]) + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + ''' + + bitmap = _bitmap + height, width = bitmap.shape + + outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, + cv2.CHAIN_APPROX_SIMPLE) + if len(outs) == 3: + img, contours, _ = outs[0], outs[1], outs[2] + elif len(outs) == 2: + contours, _ = outs[0], outs[1] + + num_contours = min(len(contours), self.max_candidates) + + boxes = [] + scores = [] + for index in range(num_contours): + contour = contours[index] + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + score = self.box_score_fast(pred, points.reshape(-1, 2)) + if self.box_thresh > score: + continue + + box = self.unclip(points).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box.astype(np.int16)) + scores.append(score) + return np.array(boxes, dtype=np.int16), scores + + def unclip(self, box): + unclip_ratio = self.unclip_ratio + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [ + points[index_1], points[index_2], points[index_3], points[index_4] + ] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + def __call__(self, outs_dict, shape_list): 
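+        # threshold the probability map into a binary mask, then trace contours into quadrilateral text boxes per image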
+        pred = outs_dict['maps']
+        pred = pred[:, 0, :, :]
+        segmentation = pred > self.thresh
+
+        boxes_batch = []
+        for batch_index in range(pred.shape[0]):
+            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
+            if self.dilation_kernel is not None:
+                mask = cv2.dilate(
+                    np.array(segmentation[batch_index]).astype(np.uint8),
+                    self.dilation_kernel)
+            else:
+                mask = segmentation[batch_index]
+            boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
+                                                   src_w, src_h)
+
+            boxes_batch.append({'points': boxes})
+        return boxes_batch
\ No newline at end of file
diff --git a/src/containers/layout-analysis/model/postprocess/rec_postprocess.py b/src/containers/layout-analysis/model/postprocess/rec_postprocess.py
new file mode 100644
index 00000000..86433e0b
--- /dev/null
+++ b/src/containers/layout-analysis/model/postprocess/rec_postprocess.py
@@ -0,0 +1,216 @@
+import numpy as np
+
+
+class BaseRecLabelDecode(object):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False):
+        support_character_type = [
+            'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
+        ]
+        assert character_type in support_character_type, "Only {} are supported now but got {}".format(
+            support_character_type, character_type)
+
+        if character_type == "en":
+            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            dict_character = list(self.character_str)
+        elif character_type in ["ch", "french", "german", "japan", "korean"]:
+            self.character_str = ""
+            assert character_dict_path is not None, "character_dict_path must not be None when character_type is {}".format(character_type)
+            with open(character_dict_path, "rb") as fin:
+                lines = fin.readlines()
+                for line in lines:
+                    line = line.decode('utf-8').strip("\n").strip("\r\n")
+                    self.character_str += line
+            if use_space_char:
+                self.character_str += " "
+            dict_character = list(self.character_str)
+        elif character_type == "en_sensitive":
+            # same as the ASTER setting (94 printable characters)
+            import string
+            self.character_str = string.printable[:-6]
+            dict_character = list(self.character_str)
+        else:
+            raise NotImplementedError
+        self.character_type = character_type
+        dict_character = self.add_special_char(dict_character)
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+
+    def add_special_char(self, dict_character):
+        return dict_character
+
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list))) + return result_list + + def get_ignored_tokens(self): + return [0] # for ctc blank + + +class CTCLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='ch', + use_space_char=False, + **kwargs): + super(CTCLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def add_special_char(self, dict_character): + dict_character = ['blank'] + dict_character + return dict_character + + +class AttnLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='ch', + use_space_char=False, + **kwargs): + super(AttnLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = [self.beg_str, self.end_str] + dict_character + return dict_character + + def __call__(self, text): + text = self.decode(text) + return text + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx +class AttnLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(AttnLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = dict_character + dict_character = [self.beg_str] + dict_character + [self.end_str] + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + [beg_idx, end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + """ + text = self.decode(text) + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx \ No newline at end of file diff --git a/src/containers/layout-analysis/model/postprocess/table_postprocess.py b/src/containers/layout-analysis/model/postprocess/table_postprocess.py new file mode 100644 index 00000000..b37c32a1 --- /dev/null +++ b/src/containers/layout-analysis/model/postprocess/table_postprocess.py @@ -0,0 +1,120 @@ +import numpy as np + +from .rec_postprocess import AttnLabelDecode + + +class TableLabelDecode(AttnLabelDecode): + """ """ + + def __init__(self, + character_dict_path, + merge_no_span_structure=False, + **kwargs): + dict_character = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + dict_character.append(line) + + if merge_no_span_structure: + if "" not in dict_character: + dict_character.append("") + if "" in dict_character: + dict_character.remove("") + + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + self.td_token = ['', ''] + + def __call__(self, preds, batch=None): + structure_probs = preds['structure_probs'] + bbox_preds = preds['loc_preds'] + shape_list = batch[-1] + result = self.decode(structure_probs, bbox_preds, shape_list) + if len(batch) == 1: # only contains shape + return result + + label_decode_result = self.decode_label(batch) + return result, label_decode_result + + def decode(self, structure_probs, bbox_preds, shape_list): + """convert text-label into text-index. 
+ """ + ignored_tokens = self.get_ignored_tokens() + end_idx = self.dict[self.end_str] + + structure_idx = structure_probs.argmax(axis=2) + structure_probs = structure_probs.max(axis=2) + + structure_batch_list = [] + bbox_batch_list = [] + batch_size = len(structure_idx) + for batch_idx in range(batch_size): + structure_list = [] + bbox_list = [] + score_list = [] + for idx in range(len(structure_idx[batch_idx])): + char_idx = int(structure_idx[batch_idx][idx]) + if idx > 0 and char_idx == end_idx: + break + if char_idx in ignored_tokens: + continue + text = self.character[char_idx] + if text in self.td_token: + bbox = bbox_preds[batch_idx, idx] + bbox = self._bbox_decode(bbox, shape_list[batch_idx]) + bbox_list.append(bbox) + structure_list.append(text) + score_list.append(structure_probs[batch_idx, idx]) + structure_batch_list.append([structure_list, np.mean(score_list)]) + bbox_batch_list.append(np.array(bbox_list)) + result = { + 'bbox_batch_list': bbox_batch_list, + 'structure_batch_list': structure_batch_list, + } + return result + + def decode_label(self, batch): + """convert text-label into text-index. + """ + structure_idx = batch[1] + gt_bbox_list = batch[2] + shape_list = batch[-1] + ignored_tokens = self.get_ignored_tokens() + end_idx = self.dict[self.end_str] + + structure_batch_list = [] + bbox_batch_list = [] + batch_size = len(structure_idx) + for batch_idx in range(batch_size): + structure_list = [] + bbox_list = [] + for idx in range(len(structure_idx[batch_idx])): + char_idx = int(structure_idx[batch_idx][idx]) + if idx > 0 and char_idx == end_idx: + break + if char_idx in ignored_tokens: + continue + structure_list.append(self.character[char_idx]) + + bbox = gt_bbox_list[batch_idx][idx] + if bbox.sum() != 0: + bbox = self._bbox_decode(bbox, shape_list[batch_idx]) + bbox_list.append(bbox) + structure_batch_list.append(structure_list) + bbox_batch_list.append(bbox_list) + result = { + 'bbox_batch_list': bbox_batch_list, + 'structure_batch_list': structure_batch_list, + } + return result + + def _bbox_decode(self, bbox, shape): + h, w, ratio_h, ratio_w, pad_h, pad_w = shape + bbox[0::2] *= w + bbox[1::2] *= h + return bbox diff --git a/src/containers/layout-analysis/model/requirements.txt b/src/containers/layout-analysis/model/requirements.txt new file mode 100644 index 00000000..e79b1ba5 --- /dev/null +++ b/src/containers/layout-analysis/model/requirements.txt @@ -0,0 +1,13 @@ +boto3==1.28.85 +#torch==2.1.0 +opencv-contrib-python-headless==4.8.1.78 +#transformers==0.1.17 +onnxruntime-gpu==1.16.0 +Pillow==8.4.0 +pyclipper==1.3.0 +Shapely==1.7.1 +PyMuPDF<1.21.0 +markdownify +flask +gevent +GPUtil \ No newline at end of file diff --git a/src/containers/layout-analysis/model/sm_predictor.py b/src/containers/layout-analysis/model/sm_predictor.py new file mode 100644 index 00000000..fdec84b6 --- /dev/null +++ b/src/containers/layout-analysis/model/sm_predictor.py @@ -0,0 +1,39 @@ +from gevent import pywsgi +import flask +import json + +import infer_layout_app + +app = flask.Flask(__name__) + +@app.route('/ping', methods=['GET']) +def ping(): + """ + Determine if the container is working and healthy. In this sample container, we declare + it healthy if we can load the model successfully. + :return: + """ + status = 200 + return flask.Response(response='Flask app is activated.', status=status, mimetype='application/json') + +@app.route('/invocations', methods=['POST']) +def transformation(): + """ + Do an inference on a single batch of data. 
diff --git a/src/containers/layout-analysis/model/sm_predictor.py b/src/containers/layout-analysis/model/sm_predictor.py
new file mode 100644
index 00000000..fdec84b6
--- /dev/null
+++ b/src/containers/layout-analysis/model/sm_predictor.py
@@ -0,0 +1,39 @@
+from gevent import pywsgi
+import flask
+import json
+
+import infer_layout_app
+
+app = flask.Flask(__name__)
+
+@app.route('/ping', methods=['GET'])
+def ping():
+    """
+    Health check for the SageMaker container. The container is considered
+    healthy once the Flask app is up and able to respond.
+    :return:
+    """
+    status = 200
+    return flask.Response(response='Flask app is activated.', status=status, mimetype='application/json')
+
+@app.route('/invocations', methods=['POST'])
+def transformation():
+    """
+    Do an inference on a single batch of data. In this sample server, we take
+    image data in base64 format, decode it for internal use and then convert
+    the predictions to JSON.
+    :return:
+    """
+    if flask.request.content_type == 'application/json':
+        request_body = flask.request.data.decode('utf-8')
+        body = json.loads(request_body)
+        req = infer_layout_app.handler({'body': body}, None)
+        return flask.Response(
+            response=req['body'],
+            status=req['statusCode'], mimetype='application/json')
+    else:
+        return flask.Response(
+            response='Only supports application/json data',
+            status=415, mimetype='application/json')
+
+server = pywsgi.WSGIServer(('0.0.0.0', 8080), app)
+server.serve_forever()
\ No newline at end of file
diff --git a/src/containers/layout-analysis/model/table.py b/src/containers/layout-analysis/model/table.py
new file mode 100644
index 00000000..06f49546
--- /dev/null
+++ b/src/containers/layout-analysis/model/table.py
@@ -0,0 +1,158 @@
+import copy
+import os
+import time
+
+import numpy as np
+import onnxruntime as ort
+
+from imaug import create_operators, transform
+from matcher import TableMatch
+from postprocess import build_post_process
+
+sess_options = ort.SessionOptions()
+sess_options.intra_op_num_threads = 8
+sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+
+def sorted_boxes(dt_boxes):
+    """
+    Sort text boxes in order from top to bottom, left to right.
+    args:
+        dt_boxes(array): detected text boxes with shape [N, 4, 2]
+    return:
+        sorted boxes(list of [4, 2] arrays)
+    """
+    num_boxes = dt_boxes.shape[0]
+    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+    _boxes = list(sorted_boxes)
+
+    # one adjacent-swap pass: boxes whose top edges differ by less than 10 px
+    # are treated as the same text line and re-ordered by x
+    for i in range(num_boxes - 1):
+        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and (
+                _boxes[i + 1][0][0] < _boxes[i][0][0]):
+            tmp = _boxes[i]
+            _boxes[i] = _boxes[i + 1]
+            _boxes[i + 1] = tmp
+    return _boxes
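A quick illustration of the reading-order heuristic in sorted_boxes (coordinates invented): the two top boxes sit on one visual line with a 2 px difference in their top edges, so they are re-ordered left to right:

    import numpy as np

    boxes = np.array([
        [[100, 12], [160, 12], [160, 30], [100, 30]],   # line 1, right
        [[ 10, 10], [ 60, 10], [ 60, 28], [ 10, 28]],   # line 1, left
        [[ 10, 50], [ 60, 50], [ 60, 68], [ 10, 68]],   # line 2
    ], dtype=np.float32)

    for box in sorted_boxes(boxes):
        print(box[0])  # top-left corners in reading order: [10,10], [100,12], [10,50]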
+class TableStructurer(object):
+    def __init__(self):
+        self.use_onnx = True  # args.use_onnx
+        pre_process_list = [
+            {'ResizeTableImage': {'max_len': 488}},
+            {'NormalizeImage': {
+                'std': [0.229, 0.224, 0.225],
+                'mean': [0.485, 0.456, 0.406],
+                'scale': '1./255.',
+                'order': 'hwc'}},
+            {'PaddingTableImage': {'size': [488, 488]}},
+            {'ToCHWImage': None},
+            {'KeepKeys': {'keep_keys': ['image', 'shape']}},
+        ]
+
+        postprocess_params = {
+            'name': 'TableLabelDecode',
+            "character_dict_path": os.environ['MODEL_PATH'] + 'table_structure_dict_ch.txt',
+            'merge_no_span_structure': True
+        }
+        self.preprocess_op = create_operators(pre_process_list)
+        self.postprocess_op = build_post_process(postprocess_params)
+
+        sess = ort.InferenceSession(os.environ['MODEL_PATH'] + 'table_sim.onnx', providers=['CPUExecutionProvider'])  # sess_options=sess_options, providers=[("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})]
+        # warm the session up once so the first real request is not penalized
+        _ = sess.run(None, {'x': np.zeros((1, 3, 488, 488), dtype='float32')})
+        self.predictor, self.input_tensor, self.output_tensors, self.config = sess, sess.get_inputs()[0], None, None
+
+    def __call__(self, img):
+        starttime = time.time()
+        data = {'image': img}
+        data = transform(data, self.preprocess_op)
+        img = data[0]
+        if img is None:
+            return None, 0
+        img = np.expand_dims(img, axis=0)
+        img = img.copy()
+        input_dict = {self.input_tensor.name: img}
+        outputs = self.predictor.run(self.output_tensors, input_dict)
+        preds = {'structure_probs': outputs[1], 'loc_preds': outputs[0]}
+        shape_list = np.expand_dims(data[-1], axis=0)
+        post_result = self.postprocess_op(preds, [shape_list])
+
+        structure_str_list = post_result['structure_batch_list'][0]
+        bbox_list = post_result['bbox_batch_list'][0]
+        structure_str_list = structure_str_list[0]
+        structure_str_list = [
+            '<html>', '<body>', '<table>'
+        ] + structure_str_list + ['</table>', '</body>', '</html>']
+        elapse = time.time() - starttime
+        return (structure_str_list, bbox_list), elapse
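A sketch of driving the structurer directly, assuming MODEL_PATH points at the unpacked layout_weight bundle (table_sim.onnx plus table_structure_dict_ch.txt), as in the Dockerfiles; the blank image is only a stand-in for a real table crop:

    import os
    import numpy as np

    os.environ.setdefault('MODEL_PATH', '/opt/ml/model/')

    structurer = TableStructurer()
    dummy = np.full((300, 400, 3), 255, dtype=np.uint8)  # blank white "table"
    (structure_tokens, cell_bboxes), elapse = structurer(dummy)
    print(structure_tokens[:5], 'decoded in %.3fs' % elapse)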
+def expand(pix, det_box, shape):
+    """Pad a detection box by `pix` pixels, clamped to the image bounds."""
+    x0, y0, x1, y1 = det_box
+    h, w, c = shape
+    tmp_x0 = x0 - pix
+    tmp_x1 = x1 + pix
+    tmp_y0 = y0 - pix
+    tmp_y1 = y1 + pix
+    x0_ = tmp_x0 if tmp_x0 >= 0 else 0
+    x1_ = tmp_x1 if tmp_x1 <= w else w
+    y0_ = tmp_y0 if tmp_y0 >= 0 else 0
+    y1_ = tmp_y1 if tmp_y1 <= h else h
+    return x0_, y0_, x1_, y1_
+
+
+class TableSystem(object):
+    def __init__(self, text_detector=None, text_recognizer=None):
+        self.text_detector = text_detector
+        self.text_recognizer = text_recognizer
+
+        self.table_structurer = TableStructurer()
+        self.match = TableMatch(filter_ocr_result=True)
+
+    def __call__(self, img, return_ocr_result_in_table=False, lang='ch'):
+        result = dict()
+        time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0}
+        start = time.time()
+        structure_res, elapse = self._structure(copy.deepcopy(img))
+        result['cell_bbox'] = structure_res[1].tolist()
+        time_dict['table'] = elapse
+
+        dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
+            copy.deepcopy(img), lang)
+        time_dict['det'] = det_elapse
+        time_dict['rec'] = rec_elapse
+
+        if return_ocr_result_in_table:
+            result['boxes'] = [x.tolist() for x in dt_boxes]
+            result['rec_res'] = rec_res
+
+        tic = time.time()
+        pred_html = self.match(structure_res, dt_boxes, rec_res)
+        toc = time.time()
+        time_dict['match'] = toc - tic
+        result['html'] = pred_html
+        end = time.time()
+        time_dict['all'] = end - start
+        return result, time_dict
+
+    def _structure(self, img):
+        structure_res, elapse = self.table_structurer(copy.deepcopy(img))
+        return structure_res, elapse
+
+    def _ocr(self, img, lang):
+        h, w = img.shape[:2]
+        dt_boxes = self.text_detector(copy.deepcopy(img))
+        dt_boxes = sorted_boxes(dt_boxes)
+
+        # collapse each quad to an axis-aligned rectangle, padded by 1 px
+        r_boxes = []
+        for box in dt_boxes:
+            x_min = max(0, box[:, 0].min() - 1)
+            x_max = min(w, box[:, 0].max() + 1)
+            y_min = max(0, box[:, 1].min() - 1)
+            y_max = min(h, box[:, 1].max() + 1)
+            box = [x_min, y_min, x_max, y_max]
+            r_boxes.append(box)
+        dt_boxes = np.array(r_boxes)
+
+        if dt_boxes is None:
+            return None, None, 0, 0
+
+        img_crop_list = []
+        for i in range(len(dt_boxes)):
+            det_box = dt_boxes[i]
+            x0, y0, x1, y1 = expand(2, det_box, img.shape)
+            text_rect = img[int(y0):int(y1), int(x0):int(x1), :]
+            img_crop_list.append(text_rect)
+        rec_res = self.text_recognizer[lang](img_crop_list)
+
+        return dt_boxes, rec_res, 0, 0
\ No newline at end of file
diff --git a/src/containers/layout-analysis/model/utils.py b/src/containers/layout-analysis/model/utils.py
new file mode 100644
index 00000000..3736a165
--- /dev/null
+++ b/src/containers/layout-analysis/model/utils.py
@@ -0,0 +1,168 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+import cv2 +import os + +import numpy as np + +__all__ = ["preprocess", "nms", "multiclass_nms", "postprocess"] + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def nms(boxes, scores, nms_thr): + """Single class NMS implemented in Numpy.""" + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + + +def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True): + """Multiclass NMS implemented in Numpy""" + if class_agnostic: + nms_method = multiclass_nms_class_agnostic + else: + nms_method = multiclass_nms_class_aware + return nms_method(boxes, scores, nms_thr, score_thr) + + +def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. Class-aware version.""" + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate( + [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 + ) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + + +def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. 
Class-agnostic version."""
+    cls_inds = scores.argmax(1)
+    cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
+
+    valid_score_mask = cls_scores > score_thr
+    if valid_score_mask.sum() == 0:
+        return None
+    valid_scores = cls_scores[valid_score_mask]
+    valid_boxes = boxes[valid_score_mask]
+    valid_cls_inds = cls_inds[valid_score_mask]
+    keep = nms(valid_boxes, valid_scores, nms_thr)
+    if keep:
+        dets = np.concatenate(
+            [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1
+        )
+        return dets
+
+
+def postprocess(outputs, img_size, p6=False):
+
+    grids = []
+    expanded_strides = []
+
+    if not p6:
+        strides = [8, 16, 32]
+    else:
+        strides = [8, 16, 32, 64]
+
+    hsizes = [img_size[0] // stride for stride in strides]
+    wsizes = [img_size[1] // stride for stride in strides]
+
+    for hsize, wsize, stride in zip(hsizes, wsizes, strides):
+        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+        grids.append(grid)
+        shape = grid.shape[:2]
+        expanded_strides.append(np.full((*shape, 1), stride))
+
+    grids = np.concatenate(grids, 1)
+    expanded_strides = np.concatenate(expanded_strides, 1)
+    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
+    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+
+    return outputs
+
+def check_and_read(img_path):
+    if os.path.basename(img_path)[-3:].lower() == 'gif':
+        gif = cv2.VideoCapture(img_path)
+        ret, frame = gif.read()
+        if not ret:
+            import logging
+            logger = logging.getLogger('ppocr')
+            logger.info("Cannot read {}. This gif image may be corrupted.".format(img_path))
+            return None, False, False
+        if len(frame.shape) == 2 or frame.shape[-1] == 1:
+            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
+        imgvalue = frame[:, :, ::-1]
+        return imgvalue, True, False
+    elif os.path.basename(img_path)[-3:].lower() == 'pdf':
+        import fitz
+        from PIL import Image
+        imgs = []
+        with fitz.open(img_path) as pdf:
+            for pg in range(0, pdf.page_count):
+                page = pdf[pg]
+                mat = fitz.Matrix(2, 2)
+                pm = page.get_pixmap(matrix=mat, alpha=False)
+
+                # if width or height > 2000 pixels, don't enlarge the image
+                if pm.width > 2000 or pm.height > 2000:
+                    pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
+
+                img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
+                img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+                imgs.append(img)
+            return imgs, False, True
+    return None, False, False
diff --git a/src/containers/layout-analysis/model/xycut.py b/src/containers/layout-analysis/model/xycut.py
new file mode 100644
index 00000000..2e6de809
--- /dev/null
+++ b/src/containers/layout-analysis/model/xycut.py
@@ -0,0 +1,125 @@
+from typing import List
+import cv2
+import numpy as np
+
+
+def projection_by_bboxes(boxes: np.array, axis: int) -> np.ndarray:
+    """
+    Build a per-pixel projection histogram from a set of bboxes.
+
+    Args:
+        boxes: [N, 4]
+        axis: 0 projects the x coordinates onto the horizontal axis,
+              1 projects the y coordinates onto the vertical axis
+
+    Returns:
+        A 1-D projection histogram whose length is the maximum coordinate in
+        the projection direction (the real image size is not needed, since we
+        only look for the gaps between text boxes).
+    """
+    assert axis in [0, 1]
+    length = np.max(boxes[:, axis::2])
+    res = np.zeros(length, dtype=int)
+    # TODO: how to remove the for loop?
+    for start, end in boxes[:, axis::2]:
+        res[start:end] += 1
+    return res
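To make the projection-and-split contract concrete, a small sketch (box coordinates invented; split_projection_profile is defined just below):

    import numpy as np

    # Two boxes side by side with a horizontal gap between x=60 and x=80.
    boxes = np.array([[10, 0,  60, 30],
                      [80, 0, 140, 30]])
    prof = projection_by_bboxes(boxes, axis=0)   # project onto the x axis
    print(prof[55], prof[70], prof[90])          # 1 0 1: the gap shows up as zeros
    print(split_projection_profile(prof, 0, 1))  # (array([10, 80]), array([ 60, 140]))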
+
+# from: https://dothinking.github.io/2021-06-19-%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1%E5%88%86%E5%89%B2%E7%AE%97%E6%B3%95/#:~:text=%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1%E5%88%86%E5%89%B2%EF%BC%88Recursive%20XY,%EF%BC%8C%E5%8F%AF%E4%BB%A5%E5%88%92%E5%88%86%E6%AE%B5%E8%90%BD%E3%80%81%E8%A1%8C%E3%80%82
+def split_projection_profile(arr_values: np.array, min_value: float, min_gap: float):
+    """Split projection profile:
+
+    ```
+                           ┌──┐
+      arr_values           │  │       ┌─┐───
+          ┌──┐             │  │       │ │ |
+          │  │             │  │ ┌───┐ │ │ min_value
+          │  │<- min_gap ->│  │ │   │ │ │ |
+     ─────┴──┴─────────────┴──┴─┴───┴─┴─┴─────
+      0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
+    ```
+
+    Args:
+        arr_values (np.array): 1-d array representing the projection profile.
+        min_value (float): Ignore the profile if `arr_value` is less than `min_value`.
+        min_gap (float): Ignore the gap if less than this value.
+
+    Returns:
+        tuple: Start indexes and end indexes of split groups.
+    """
+    # all indexes with projection height exceeding the threshold
+    arr_index = np.where(arr_values > min_value)[0]
+    if not len(arr_index):
+        return
+
+    # find zero intervals between adjacent projections
+    # |  |                    ||
+    # ||||<- zero-interval -> |||||
+    arr_diff = arr_index[1:] - arr_index[0:-1]
+    arr_diff_index = np.where(arr_diff > min_gap)[0]
+    arr_zero_intvl_start = arr_index[arr_diff_index]
+    arr_zero_intvl_end = arr_index[arr_diff_index + 1]
+
+    # convert to index of projection range:
+    # the start index of zero interval is the end index of projection
+    arr_start = np.insert(arr_zero_intvl_end, 0, arr_index[0])
+    arr_end = np.append(arr_zero_intvl_start, arr_index[-1])
+    arr_end += 1  # end index will be excluded as index slice
+
+    return arr_start, arr_end
+
+
+def recursive_xy_cut(boxes: np.ndarray, indices: np.ndarray, res: List[int]):
+    """Recursive XY-cut: alternately split the boxes along y and x gaps.
+
+    Args:
+        boxes: (N, 4)
+        indices: indices of the boxes in the original array, kept in step
+            with `boxes` throughout the recursion
+        res: output list collecting the final reading order
+    """
+    # project onto the y axis
+    assert len(boxes) == len(indices)
+
+    _indices = boxes[:, 1].argsort()
+    y_sorted_boxes = boxes[_indices]
+    y_sorted_indices = indices[_indices]
+
+    # debug_vis(y_sorted_boxes, y_sorted_indices)
+
+    y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
+    pos_y = split_projection_profile(y_projection, 0, 1)
+    if not pos_y:
+        return
+
+    arr_y0, arr_y1 = pos_y
+    for r0, r1 in zip(arr_y0, arr_y1):
+        # [r0, r1] is a horizontal band that contains boxes; each band is
+        # split again vertically
+        _indices = (r0 <= y_sorted_boxes[:, 1]) & (y_sorted_boxes[:, 1] < r1)
+
+        y_sorted_boxes_chunk = y_sorted_boxes[_indices]
+        y_sorted_indices_chunk = y_sorted_indices[_indices]
+
+        _indices = y_sorted_boxes_chunk[:, 0].argsort()
+        x_sorted_boxes_chunk = y_sorted_boxes_chunk[_indices]
+        x_sorted_indices_chunk = y_sorted_indices_chunk[_indices]
+
+        # project onto the x axis
+        x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
+        pos_x = split_projection_profile(x_projection, 0, 1)
+        if not pos_x:
+            continue
+
+        arr_x0, arr_x1 = pos_x
+        if len(arr_x0) == 1:
+            # no split possible along x: emit this chunk in x order
+            res.extend(x_sorted_indices_chunk)
+            continue
+
+        # the band splits along x: recurse into each column
+        for c0, c1 in zip(arr_x0, arr_x1):
+            _indices = (c0 <= x_sorted_boxes_chunk[:, 0]) & (
+                x_sorted_boxes_chunk[:, 0] < c1
+            )
+            recursive_xy_cut(
+                x_sorted_boxes_chunk[_indices], x_sorted_indices_chunk[_indices], res
+            )
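A toy end-to-end run of the cut (layout invented): two columns that overlap vertically, above a full-width footer. Note that indices must be a numpy array so the fancy indexing inside the recursion works:

    import numpy as np

    boxes = np.array([
        [ 10,  10,  90,  40],   # left column, top
        [110,  10, 190,  80],   # right column (tall)
        [ 10,  50,  90,  80],   # left column, bottom
        [ 10, 100, 190, 130],   # full-width footer
    ])
    order = []
    recursive_xy_cut(boxes, np.arange(len(boxes)), order)
    print(order)  # [0, 2, 1, 3]: left column first, then right, then the footer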