[Refactor] Refactor TextRecogVisualizer

Parent: 7e7a526
Commit: c78be99

Showing 4 changed files with 213 additions and 1 deletion.
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .textrecog_visualizer import TextRecogLocalVisualizer

__all__ = ['TextRecogLocalVisualizer']
@@ -0,0 +1,140 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple, Union

import cv2
import mmcv
import numpy as np
from mmengine import Visualizer

from mmocr.core import TextRecogDataSample
from mmocr.registry import VISUALIZERS


@VISUALIZERS.register_module()
class TextRecogLocalVisualizer(Visualizer):
"""MMOCR Text Detection Local Visualizer. | ||
Args: | ||
name (str): Name of the instance. Defaults to 'visualizer'. | ||
image (np.ndarray, optional): The origin image to draw. The format | ||
should be RGB. Defaults to None. | ||
vis_backends (list, optional): Visual backend config list. | ||
Defaults to None. | ||
save_dir (str, optional): Save file dir for all storage backends. | ||
If it is None, the backend storage will not save any data. | ||
gt_color (str or tuple[int, int, int]): Colors of GT text. The tuple of | ||
color should be in RGB order. Or using an abbreviation of color, | ||
such as `'g'` for `'green'`. Defaults to 'g'. | ||
pred_color (str or tuple[int, int, int]): Colors of Predicted text. | ||
The tuple of color should be in RGB order. Or using an abbreviation | ||
of color, such as `'r'` for `'red'`. Defaults to 'r'. | ||
""" | ||

    def __init__(self,
                 name: str = 'visualizer',
                 image: Optional[np.ndarray] = None,
                 vis_backends: Optional[Dict] = None,
                 save_dir: Optional[str] = None,
                 gt_color: Optional[Union[str, Tuple[int, int, int]]] = 'g',
                 pred_color: Optional[Union[str, Tuple[int, int,
                                                       int]]] = 'r') -> None:
        super().__init__(
            name=name,
            image=image,
            vis_backends=vis_backends,
            save_dir=save_dir)
        self.gt_color = gt_color
        self.pred_color = pred_color

    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       gt_sample: Optional['TextRecogDataSample'] = None,
                       pred_sample: Optional['TextRecogDataSample'] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       show: bool = False,
                       wait_time: int = 0,
                       out_file: Optional[str] = None,
                       step: int = 0) -> None:
"""Visualize datasample and save to all backends. | ||
- If GT and prediction are plotted at the same time, they are | ||
displayed in a stitched image where the left image is the | ||
ground truth and the right image is the prediction. | ||
- If ``show`` is True, all storage backends are ignored, and | ||
the images will be displayed in a local window. | ||
- If ``out_file`` is specified, the drawn image will be | ||
saved to ``out_file``. This is usually used when the display | ||
is not available. | ||
Args: | ||
name (str): The image title. Defaults to 'image'. | ||
image (np.ndarray): The image to draw. | ||
gt_sample (:obj:`TextRecogDataSample`, optional): GT | ||
TextRecogDataSample. Defaults to None. | ||
pred_sample (:obj:`TextRecogDataSample`, optional): Predicted | ||
TextRecogDataSample. Defaults to None. | ||
draw_gt (bool): Whether to draw GT TextRecogDataSample. | ||
Defaults to True. | ||
draw_pred (bool): Whether to draw Predicted TextRecogDataSample. | ||
Defaults to True. | ||
show (bool): Whether to display the drawn image. Defaults to False. | ||
wait_time (float): The interval of show (s). Defaults to 0. | ||
out_file (str): Path to output file. Defaults to None. | ||
step (int): Global step value to record. Defaults to 0. | ||
""" | ||
        gt_img_data = None
        pred_img_data = None
        height, width = image.shape[:2]
        resize_height = 64
        resize_width = int(1.0 * width / height * resize_height)
        image = cv2.resize(image, (resize_width, resize_height))
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        if draw_gt and gt_sample is not None and 'gt_text' in gt_sample:
            gt_text = gt_sample.gt_text.item
            empty_img = np.full_like(image, 255)
            self.set_image(empty_img)
            font_size = 0.5 * resize_width / len(gt_text)
            self.draw_texts(
                gt_text,
                np.array([resize_width / 2, resize_height / 2]),
                colors=self.gt_color,
                font_sizes=font_size,
                vertical_alignments='center',
                horizontal_alignments='center')
            gt_text_image = self.get_image()
            gt_img_data = np.concatenate((image, gt_text_image), axis=0)

        if (draw_pred and pred_sample is not None
                and 'pred_text' in pred_sample):
            pred_text = pred_sample.pred_text.item
            empty_img = np.full_like(image, 255)
            self.set_image(empty_img)
            font_size = 0.5 * resize_width / len(pred_text)
            self.draw_texts(
                pred_text,
                np.array([resize_width / 2, resize_height / 2]),
                colors=self.pred_color,
                font_sizes=font_size,
                vertical_alignments='center',
                horizontal_alignments='center')
            pred_text_image = self.get_image()
            pred_img_data = np.concatenate((image, pred_text_image), axis=0)

        if gt_img_data is not None and pred_img_data is not None:
            # Stack the text-only prediction strip below the GT result so
            # the raw image appears only once (image + GT text + pred text).
            drawn_img = np.concatenate((gt_img_data, pred_text_image), axis=0)
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        else:
            drawn_img = pred_img_data

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)
        else:
            self.add_image(name, drawn_img, step)

        if out_file is not None:
            # ``drawn_img`` is RGB; flip channels to BGR before writing.
            mmcv.imwrite(drawn_img[..., ::-1], out_file)
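For reference, here is a minimal usage sketch of the new visualizer (not part of the diff). It mirrors the unit test below and assumes a dev-1.x environment where mmengine and MMOCR's TextRecogDataSample are importable; the image size, text strings, and output path are arbitrary.

import numpy as np
from mmengine.data import LabelData

from mmocr.core import TextRecogDataSample
from mmocr.core.visualization import TextRecogLocalVisualizer

# A fake 64x128 RGB crop standing in for a cropped text instance.
image = np.random.randint(0, 256, size=(64, 128, 3), dtype=np.uint8)

gt_text = LabelData()
gt_text.item = 'mmocr'
gt_sample = TextRecogDataSample()
gt_sample.gt_text = gt_text

pred_text = LabelData()
pred_text.item = 'MMOCR'
pred_sample = TextRecogDataSample()
pred_sample.pred_text = pred_text

visualizer = TextRecogLocalVisualizer(gt_color='g', pred_color='r')
# Stacks the image, GT text, and predicted text vertically and writes the
# result to disk (RGB-to-BGR conversion is handled internally).
visualizer.add_datasample(
    'demo',
    image,
    gt_sample=gt_sample,
    pred_sample=pred_sample,
    out_file='vis_demo.jpg')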
tests/test_core/test_visualization/test_textrecog_visualizer.py (68 additions, 0 deletions)
@@ -0,0 +1,68 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import unittest

import cv2
import numpy as np
from mmengine.data import LabelData

from mmocr.core import TextRecogDataSample
from mmocr.core.visualization import TextRecogLocalVisualizer


class TestTextRecogLocalVisualizer(unittest.TestCase):

    def test_add_datasample(self):
        # The 64x128 input keeps its size after the visualizer's internal
        # resize (target height is 64), so the GT-only output stacks two
        # 64-pixel strips and the GT + prediction output stacks three.
        h, w = 64, 128
        image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')

        # test gt_text
        gt_recog_data_sample = TextRecogDataSample()
        img_meta = dict(img_shape=(12, 10, 3))
        gt_text = LabelData(metainfo=img_meta)
        gt_text.item = 'mmocr'
        gt_recog_data_sample.gt_text = gt_text

        recog_local_visualizer = TextRecogLocalVisualizer()
        recog_local_visualizer.add_datasample('image', image,
                                              gt_recog_data_sample)

        # test gt_text and pred_text
        pred_recog_data_sample = TextRecogDataSample()
        pred_text = LabelData(metainfo=img_meta)
        pred_text.item = 'MMOCR'
        pred_recog_data_sample.pred_text = pred_text

        with tempfile.TemporaryDirectory() as tmp_dir:
            # test out_file
            out_file = osp.join(tmp_dir, 'out_file.jpg')

            # draw_gt = True + gt_sample
            recog_local_visualizer.add_datasample(
                'image', image, gt_recog_data_sample, out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

            # draw_gt = True + gt_sample + pred_sample
            recog_local_visualizer.add_datasample(
                'image',
                image,
                gt_recog_data_sample,
                pred_recog_data_sample,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 3, w, 3))

            # draw_gt = False + gt_sample + pred_sample
            recog_local_visualizer.add_datasample(
                'image',
                image,
                gt_recog_data_sample,
                pred_recog_data_sample,
                draw_gt=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

    def _assert_image_and_shape(self, out_file, out_shape):
        self.assertTrue(osp.exists(out_file))
        drawn_img = cv2.imread(out_file)
        self.assertTrue(drawn_img.shape == out_shape)
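Since the class is decorated with @VISUALIZERS.register_module() and re-exported from the package's __init__.py, it can also be constructed from a config dict through the registry instead of being imported directly. A minimal sketch, assuming the standard mmengine Registry.build() interface; the config values shown are arbitrary:

from mmocr.registry import VISUALIZERS

# Hypothetical config fragment; 'type' must match the registered class name.
visualizer_cfg = dict(
    type='TextRecogLocalVisualizer',
    name='visualizer',
    gt_color='g',
    pred_color='r')

# Builds a TextRecogLocalVisualizer instance from the config.
recog_visualizer = VISUALIZERS.build(visualizer_cfg)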