File size: 7,760 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES

import mmocr.utils.check_argument as check_argument
from mmocr.models.builder import build_convertor


@PIPELINES.register_module()
class OCRSegTargets:
    """Generate gt shrunk kernels for segmentation based OCR framework.

    For every annotated character a shrunk polygon is rasterised into two
    maps (a binary "attention" map and a per-character index map), and a
    third map marks which pixels belong to the resized image rather than
    to padding.

    Args:
        label_convertor (dict): Dictionary to construct label_convertor
            to convert char to index. Must not be None.
        attn_shrink_ratio (float): Shrink ratio, in the open interval
            (0, 1), used for the binary attention kernels.
        seg_shrink_ratio (float): Shrink ratio, in the open interval
            (0, 1), used for the segmentation (char index) kernels.
        box_type (str): Character box type, should be either
            'char_rects' or 'char_quads', with 'char_rects'
            for rectangle with ``xyxy`` style and 'char_quads'
            for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
        pad_val (int): Value written into the padded region of the
            generated kernels (columns at or beyond the resized width).
    """

    def __init__(self,
                 label_convertor=None,
                 attn_shrink_ratio=0.5,
                 seg_shrink_ratio=0.25,
                 box_type='char_rects',
                 pad_val=255):

        assert isinstance(attn_shrink_ratio, float)
        assert isinstance(seg_shrink_ratio, float)
        assert 0. < attn_shrink_ratio < 1.0
        assert 0. < seg_shrink_ratio < 1.0
        assert label_convertor is not None
        assert box_type in ('char_rects', 'char_quads')

        self.attn_shrink_ratio = attn_shrink_ratio
        self.seg_shrink_ratio = seg_shrink_ratio
        self.label_convertor = build_convertor(label_convertor)
        self.box_type = box_type
        self.pad_val = pad_val

    def shrink_char_quad(self, char_quad, shrink_ratio):
        """Shrink char box in style of quadrangle.

        Every vertex is moved along both of its adjacent edges, towards the
        neighbouring vertices, by ``shrink_ratio`` times the length of the
        shorter of those two edges. The result is rounded to integers.

        Args:
            char_quad (list[float]): Char box with format
                [x1, y1, x2, y2, x3, y3, x4, y4].
            shrink_ratio (float): Fraction of the shorter adjacent edge
                by which each vertex is pulled inwards.

        Returns:
            ndarray: Array of shape (4, 2) holding the shrunk vertices.
        """
        points = [[char_quad[0], char_quad[1]], [char_quad[2], char_quad[3]],
                  [char_quad[4], char_quad[5]], [char_quad[6], char_quad[7]]]
        shrink_points = []
        for p_idx, point in enumerate(points):
            # Previous and next vertex in the (cyclic) vertex order.
            p1 = points[(p_idx + 3) % 4]
            p2 = points[(p_idx + 1) % 4]

            dist1 = self.l2_dist_two_points(p1, point)
            dist2 = self.l2_dist_two_points(p2, point)
            min_dist = min(dist1, dist2)

            # Edge vectors pointing from this vertex to its neighbours.
            v1 = [p1[0] - point[0], p1[1] - point[1]]
            v2 = [p2[0] - point[0], p2[1] - point[1]]

            # min_dist != 0 implies both edge lengths are non-zero, so the
            # divisions are safe; a degenerate vertex is left where it is.
            temp_dist1 = (shrink_ratio * min_dist /
                          dist1) if min_dist != 0 else 0.
            temp_dist2 = (shrink_ratio * min_dist /
                          dist2) if min_dist != 0 else 0.

            # Rescale both edge vectors to length shrink_ratio * min_dist.
            v1 = [temp * temp_dist1 for temp in v1]
            v2 = [temp * temp_dist2 for temp in v2]

            shrink_point = [
                round(point[0] + v1[0] + v2[0]),
                round(point[1] + v1[1] + v2[1])
            ]
            shrink_points.append(shrink_point)

        poly = np.array(shrink_points)

        return poly

    def shrink_char_rect(self, char_rect, shrink_ratio):
        """Shrink char box in style of rectangle.

        The returned rectangle shares its centre with ``char_rect`` and has
        width and height scaled by ``shrink_ratio`` (before rounding).

        Args:
            char_rect (list[float]): Char box with format
                [x_min, y_min, x_max, y_max].
            shrink_ratio (float): Factor applied to the width and the
                height of the box.

        Returns:
            ndarray: Array of shape (4, 2) with the corners ordered
            (x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max).
        """
        x_min, y_min, x_max, y_max = char_rect
        w = x_max - x_min
        h = y_max - y_min
        # (min + max) / 2 is the centre; +/- (side * ratio) / 2 are the new
        # half extents.
        x_min_s = round((x_min + x_max - w * shrink_ratio) / 2)
        y_min_s = round((y_min + y_max - h * shrink_ratio) / 2)
        x_max_s = round((x_min + x_max + w * shrink_ratio) / 2)
        y_max_s = round((y_min + y_max + h * shrink_ratio) / 2)
        poly = np.array([[x_min_s, y_min_s], [x_max_s, y_min_s],
                         [x_max_s, y_max_s], [x_min_s, y_max_s]])

        return poly

    def generate_kernels(self,
                         resize_shape,
                         pad_shape,
                         char_boxes,
                         char_inds,
                         shrink_ratio=0.5,
                         binary=True):
        """Generate char instance kernels for one shrink ratio.

        Args:
            resize_shape (tuple(int, int)): Image size (height, width)
                after resizing.
            pad_shape (tuple(int, int)):  Image size (height, width)
                after padding.
            char_boxes (list[list[float]]): The list of char polygons,
                in the format selected by ``self.box_type``.
            char_inds (list[int]): List of char indexes.
            shrink_ratio (float): The shrink ratio of kernel.
            binary (bool): If True, every char is filled with 1;
                otherwise each char is filled with its char index.

        Returns:
            char_kernel (ndarray): The int32 text kernel mask of shape
            ``pad_shape``. Besides the fill values it contains
            ``self.pad_val`` in the padded columns.
        """
        assert isinstance(resize_shape, tuple)
        assert isinstance(pad_shape, tuple)
        assert check_argument.is_2dlist(char_boxes)
        assert check_argument.is_type_list(char_inds, int)
        assert isinstance(shrink_ratio, float)
        assert isinstance(binary, bool)

        char_kernel = np.zeros(pad_shape, dtype=np.int32)
        # Mark the padded area to the right of the resized image.
        char_kernel[:resize_shape[0], resize_shape[1]:] = self.pad_val

        for i, char_box in enumerate(char_boxes):
            if self.box_type == 'char_rects':
                poly = self.shrink_char_rect(char_box, shrink_ratio)
            elif self.box_type == 'char_quads':
                poly = self.shrink_char_quad(char_box, shrink_ratio)

            fill_value = 1 if binary else char_inds[i]
            cv2.fillConvexPoly(char_kernel, poly.astype(np.int32),
                               (fill_value))

        return char_kernel

    def l2_dist_two_points(self, p1, p2):
        """Return the Euclidean distance between 2D points p1 and p2."""
        return ((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5

    def __call__(self, results):
        """Build the gt kernels for one sample and store them in results.

        Reads ``results['img_shape']``, ``results['resize_shape']``,
        ``results['pad_shape']`` and, from ``results['ann_info']``, the
        boxes under ``self.box_type`` and the characters under ``'chars'``.

        The annotated boxes are rescaled from ``img_shape`` to
        ``resize_shape`` on copies, so ``results['ann_info']`` is left
        untouched and calling the transform again does not compound the
        scaling.

        Args:
            results (dict): Sample dictionary of the data pipeline.

        Returns:
            dict: The same dictionary, with ``'gt_kernels'`` set to a
            ``BitmapMasks`` of three maps (binary kernels, char index
            kernels, valid-pixel mask) and ``'mask_fields'`` set to
            ``['gt_kernels']``.
        """
        img_shape = results['img_shape']
        resize_shape = results['resize_shape']

        h_scale = 1.0 * resize_shape[0] / img_shape[0]
        w_scale = 1.0 * resize_shape[1] / img_shape[1]

        ann_info = results['ann_info']
        # Rects carry 2 (x, y) pairs, quads carry 4.
        num_points = 2 if self.box_type == 'char_rects' else 4

        char_boxes, char_inds = [], []
        for i, char_box in enumerate(ann_info[self.box_type]):
            # Work on a copy: the annotation lists belong to the caller.
            scaled_box = list(char_box)
            for j in range(num_points):
                scaled_box[j * 2] = round(scaled_box[j * 2] * w_scale)
                scaled_box[j * 2 + 1] = round(scaled_box[j * 2 + 1] * h_scale)
            char_boxes.append(scaled_box)
            char = ann_info['chars'][i]
            # NOTE(review): str2idx is assumed to return one index list per
            # input string; the first index of the first list is used.
            char_ind = self.label_convertor.str2idx([char])[0][0]
            char_inds.append(char_ind)

        resize_shape = tuple(results['resize_shape'][:2])
        pad_shape = tuple(results['pad_shape'][:2])
        binary_target = self.generate_kernels(
            resize_shape,
            pad_shape,
            char_boxes,
            char_inds,
            shrink_ratio=self.attn_shrink_ratio,
            binary=True)

        seg_target = self.generate_kernels(
            resize_shape,
            pad_shape,
            char_boxes,
            char_inds,
            shrink_ratio=self.seg_shrink_ratio,
            binary=False)

        # 1 where the pixel is usable, 0 in the padded columns.
        mask = np.ones(pad_shape, dtype=np.int32)
        mask[:resize_shape[0], resize_shape[1]:] = 0

        results['gt_kernels'] = BitmapMasks([binary_target, seg_target, mask],
                                            pad_shape[0], pad_shape[1])
        results['mask_fields'] = ['gt_kernels']

        return results