File size: 15,708 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
# Copyright (c) OpenMMLab. All rights reserved.
import math

import mmcv
import numpy as np
import torch
import torchvision.transforms.functional as TF
from mmcv.runner.dist_utils import get_dist_info
from mmdet.datasets.builder import PIPELINES
from PIL import Image
from shapely.geometry import Polygon
from shapely.geometry import box as shapely_box

import mmocr.utils as utils
from mmocr.datasets.pipelines.crop import warp_img


@PIPELINES.register_module()
class ResizeOCR:
    """Resize an image to a target recognition height, optionally padding
    its width.

    Args:
        height (int | tuple(int)): Image height after resizing.
        min_width (none | int | tuple(int)): Image minimum width
            after resizing.
        max_width (none | int | tuple(int)): Image maximum width
            after resizing.
        keep_aspect_ratio (bool): Keep image aspect ratio if True
            during resizing, Otherwise resize to the size height *
            max_width.
        img_pad_value (int): Scalar to fill padding area.
        width_downsample_ratio (float): Downsample ratio in horizontal
            direction from input image to output feature.
        backend (str | None): The image resize backend type. Options are `cv2`,
            `pillow`, `None`. If backend is None, the global imread_backend
            specified by ``mmcv.use_backend()`` will be used. Default: None.
    """

    def __init__(self,
                 height,
                 min_width=None,
                 max_width=None,
                 keep_aspect_ratio=True,
                 img_pad_value=0,
                 width_downsample_ratio=1.0 / 16,
                 backend=None):
        assert isinstance(height, (int, tuple))
        assert utils.is_none_or_type(min_width, (int, tuple))
        assert utils.is_none_or_type(max_width, (int, tuple))
        if not keep_aspect_ratio:
            assert max_width is not None, ('"max_width" must assigned '
                                           'if "keep_aspect_ratio" is False')
        assert isinstance(img_pad_value, int)
        if isinstance(height, tuple):
            # Multi-scale config: all three must be aligned tuples.
            assert isinstance(min_width, tuple)
            assert isinstance(max_width, tuple)
            assert len(height) == len(min_width) == len(max_width)

        self.height = height
        self.min_width = min_width
        self.max_width = max_width
        self.keep_aspect_ratio = keep_aspect_ratio
        self.img_pad_value = img_pad_value
        self.width_downsample_ratio = width_downsample_ratio
        self.backend = backend

    def _target_size(self, rank):
        """Pick the (height, min_width, max_width) triple for this rank."""
        if isinstance(self.height, int):
            return self.height, self.min_width, self.max_width
        # Multi-scale resize used in distributed training: each rank id
        # selects its own (height, width) pair.
        idx = rank % len(self.height)
        return self.height[idx], self.min_width[idx], self.max_width[idx]

    def _aligned_width(self, dst_height, ori_height, ori_width,
                       dst_min_width):
        """Width that preserves aspect ratio, snapped to the feature stride."""
        new_width = math.ceil(float(dst_height) / ori_height * ori_width)
        width_divisor = int(1 / self.width_downsample_ratio)
        # Make sure new_width is an integral multiple of width_divisor.
        if new_width % width_divisor != 0:
            new_width = round(new_width / width_divisor) * width_divisor
        if dst_min_width is not None:
            new_width = max(dst_min_width, new_width)
        return new_width

    def __call__(self, results):
        rank, _ = get_dist_info()
        dst_height, dst_min_width, dst_max_width = self._target_size(rank)

        ori_height, ori_width = results['img_shape'][:2]
        valid_ratio = 1.0

        if not self.keep_aspect_ratio:
            # Hard resize to (dst_max_width, dst_height).
            img_resize = mmcv.imresize(
                results['img'], (dst_max_width, dst_height),
                backend=self.backend)
            resize_shape = img_resize.shape
            pad_shape = img_resize.shape
        elif dst_max_width is None:
            # Free width: scale to keep aspect ratio, no cap.
            new_width = self._aligned_width(dst_height, ori_height, ori_width,
                                            dst_min_width)
            img_resize = mmcv.imresize(
                results['img'], (new_width, dst_height), backend=self.backend)
            resize_shape = img_resize.shape
            pad_shape = img_resize.shape
        else:
            # Capped width: clamp to dst_max_width and pad the remainder.
            new_width = self._aligned_width(dst_height, ori_height, ori_width,
                                            dst_min_width)
            valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)
            img_resize = mmcv.imresize(
                results['img'], (min(dst_max_width, new_width), dst_height),
                backend=self.backend)
            resize_shape = img_resize.shape
            pad_shape = img_resize.shape
            if new_width < dst_max_width:
                img_resize = mmcv.impad(
                    img_resize,
                    shape=(dst_height, dst_max_width),
                    pad_val=self.img_pad_value)
                pad_shape = img_resize.shape

        results['img'] = img_resize
        results['img_shape'] = resize_shape
        results['resize_shape'] = resize_shape
        results['pad_shape'] = pad_shape
        results['valid_ratio'] = valid_ratio

        return results


@PIPELINES.register_module()
class ToTensorOCR:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor."""

    def __init__(self):
        pass

    def __call__(self, results):
        # Copy first so the tensor does not share memory with the source.
        img = results['img'].copy()
        results['img'] = TF.to_tensor(img)
        return results


@PIPELINES.register_module()
class NormalizeOCR:
    """Normalize a tensor image with mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, results):
        mean, std = self.mean, self.std
        results['img'] = TF.normalize(results['img'], mean, std)
        # Record the normalization config for downstream consumers.
        results['img_norm_cfg'] = dict(mean=mean, std=std)
        return results


@PIPELINES.register_module()
class OnlineCropOCR:
    """Crop text areas from whole image with bounding box jitter. If no bbox is
    given, return directly.

    Args:
        box_keys (list[str]): Keys in results which correspond to RoI bbox.
        jitter_prob (float): The probability of box jitter.
        max_jitter_ratio_x (float): Maximum horizontal jitter ratio
            relative to height.
        max_jitter_ratio_y (float): Maximum vertical jitter ratio
            relative to height.
    """

    def __init__(self,
                 box_keys=None,
                 jitter_prob=0.5,
                 max_jitter_ratio_x=0.05,
                 max_jitter_ratio_y=0.02):
        # Default resolved here to avoid a shared mutable default argument.
        if box_keys is None:
            box_keys = ['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4']
        assert utils.is_type_list(box_keys, str)
        assert 0 <= jitter_prob <= 1
        assert 0 <= max_jitter_ratio_x <= 1
        assert 0 <= max_jitter_ratio_y <= 1

        self.box_keys = box_keys
        self.jitter_prob = jitter_prob
        self.max_jitter_ratio_x = max_jitter_ratio_x
        self.max_jitter_ratio_y = max_jitter_ratio_y

    def __call__(self, results):
        """Crop ``results['img']`` to the (possibly jittered) RoI box.

        Returns ``results`` unmodified when ``img_info`` or any box key
        is missing.
        """
        if 'img_info' not in results:
            return results

        # Collect all box coordinates; bail out if any key is missing.
        box = []
        for key in self.box_keys:
            if key not in results['img_info']:
                return results
            box.append(float(results['img_info'][key]))

        # BUGFIX: the original used ``random() > jitter_prob``, which
        # jitters with probability 1 - jitter_prob and contradicts the
        # documented "probability of box jitter" semantics.
        jitter_flag = np.random.random() < self.jitter_prob

        kwargs = dict(
            jitter_flag=jitter_flag,
            jitter_ratio_x=self.max_jitter_ratio_x,
            jitter_ratio_y=self.max_jitter_ratio_y)
        crop_img = warp_img(results['img'], box, **kwargs)

        results['img'] = crop_img
        results['img_shape'] = crop_img.shape

        return results


@PIPELINES.register_module()
class FancyPCA:
    """Implementation of PCA based image augmentation, proposed in the paper
    ``Imagenet Classification With Deep Convolutional Neural Networks``.

    It alters the intensities of RGB values along the principal components of
    ImageNet dataset.
    """

    def __init__(self, eig_vec=None, eig_val=None):
        if eig_vec is None:
            eig_vec = torch.Tensor([
                [-0.5675, +0.7192, +0.4009],
                [-0.5808, -0.0045, -0.8140],
                [-0.5836, -0.6948, +0.4203],
            ]).t()
        if eig_val is None:
            eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])
        self.eig_val = eig_val  # 1*3 eigenvalues
        self.eig_vec = eig_vec  # 3*3 eigenvectors

    def pca(self, tensor):
        """Add a random perturbation along the principal components."""
        assert tensor.size(0) == 3
        # Per-component coefficients drawn from N(0, 1), scaled by 0.1.
        alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
        # Project the scaled eigenvalues back into RGB space.
        offset = torch.mm(self.eig_val * alpha, self.eig_vec).view(3, 1, 1)
        return tensor + offset

    def __call__(self, results):
        results['img'] = self.pca(results['img'])
        return results

    def __repr__(self):
        return self.__class__.__name__


@PIPELINES.register_module()
class RandomPaddingOCR:
    """Pad the given image on all sides, as well as modify the coordinates of
    character bounding box in image.

    Args:
        max_ratio (list[float]): Maximum padding ratios
            ``[left, top, right, bottom]``; left/right are relative to
            image width, top/bottom to image height. (The original
            docstring said ``list[int]``, but the assert below requires
            floats and the defaults are floats.)
        box_type (None|str): Character box type. If not none,
            should be either 'char_rects' or 'char_quads', with
            'char_rects' for rectangle with ``xyxy`` style and
            'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
    """

    def __init__(self, max_ratio=None, box_type=None):
        if max_ratio is None:
            max_ratio = [0.1, 0.2, 0.1, 0.2]
        else:
            assert utils.is_type_list(max_ratio, float)
            assert len(max_ratio) == 4
        assert box_type is None or box_type in ('char_rects', 'char_quads')

        self.max_ratio = max_ratio
        self.box_type = box_type

    def __call__(self, results):

        ori_height, ori_width = results['img_shape'][:2]

        # Draw an independent padding amount for each side; left/right
        # scale with width, top/bottom with height.
        dims = (ori_width, ori_height, ori_width, ori_height)
        padding = tuple(
            round(np.random.uniform(0, ratio) * dim)
            for ratio, dim in zip(self.max_ratio, dims))

        # 'edge' mode replicates the border pixels instead of a constant.
        img = mmcv.impad(results['img'], padding=padding, padding_mode='edge')

        results['img'] = img
        results['img_shape'] = img.shape

        if self.box_type is not None:
            # Shift every character box by the left/top padding offsets.
            num_points = 2 if self.box_type == 'char_rects' else 4
            for char_box in results['ann_info'][self.box_type]:
                for j in range(num_points):
                    char_box[j * 2] += padding[0]
                    char_box[j * 2 + 1] += padding[1]

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str


@PIPELINES.register_module()
class RandomRotateImageBox:
    """Rotate augmentation for segmentation based text recognition.

    Rotates the image by a random angle about its center and rotates each
    character box accordingly, dropping boxes that fall mostly outside the
    image after rotation.

    Args:
        min_angle (int): Minimum rotation angle for image and box.
        max_angle (int): Maximum rotation angle for image and box.
        box_type (str): Character box type, should be either
            'char_rects' or 'char_quads', with 'char_rects'
            for rectangle with ``xyxy`` style and 'char_quads'
            for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
    """

    def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'):
        assert box_type in ('char_rects', 'char_quads')

        self.min_angle = min_angle
        self.max_angle = max_angle
        self.box_type = box_type

    def __call__(self, results):
        # NOTE(review): in_img appears to be a PIL Image — ``.size`` is
        # unpacked as (width, height) and it is fed to ``TF.rotate`` —
        # confirm against the pipeline ordering (e.g. after OpencvToPil).
        in_img = results['img']
        in_chars = results['ann_info']['chars']
        in_boxes = results['ann_info'][self.box_type]

        img_width, img_height = in_img.size
        rotate_center = [img_width / 2., img_height / 2.]

        # Cap the sampled angle at arctan(height/width), i.e. the angle of
        # the image diagonal; presumably this bounds how far content can
        # swing out of frame — TODO confirm intent.
        tan_temp_max_angle = rotate_center[1] / rotate_center[0]
        temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi

        random_angle = np.random.uniform(
            max(self.min_angle, -temp_max_angle),
            min(self.max_angle, temp_max_angle))
        random_angle_radian = random_angle * np.pi / 180.

        # Axis-aligned rectangle covering the whole image, used to filter
        # rotated boxes below.
        img_box = shapely_box(0, 0, img_width, img_height)

        # ``resample=False`` / ``expand=False``: nearest-neighbour resampling,
        # output canvas keeps the original size.
        out_img = TF.rotate(
            in_img,
            random_angle,
            resample=False,
            expand=False,
            center=rotate_center)

        out_boxes, out_chars = self.rotate_bbox(in_boxes, in_chars,
                                                random_angle_radian,
                                                rotate_center, img_box)

        results['img'] = out_img
        results['ann_info']['chars'] = out_chars
        results['ann_info'][self.box_type] = out_boxes

        return results

    @staticmethod
    def rotate_bbox(boxes, chars, angle, center, img_box):
        """Rotate each box by ``angle`` (radians) about ``center``.

        A box (and its character, filtered in lockstep) is kept only when
        its rotated polygon is valid, properly intersects the image
        rectangle (not merely touching its border), and at least 70% of
        its area remains inside the image.
        """
        out_boxes = []
        out_chars = []
        for idx, bbox in enumerate(boxes):
            # Rotate every (x, y) vertex of the flat coordinate list.
            temp_bbox = []
            for i in range(len(bbox) // 2):
                point = [bbox[2 * i], bbox[2 * i + 1]]
                temp_bbox.append(
                    RandomRotateImageBox.rotate_point(point, angle, center))
            # buffer(0) repairs self-intersections in the rotated polygon.
            poly_temp_bbox = Polygon(temp_bbox).buffer(0)
            if poly_temp_bbox.is_valid:
                if img_box.intersects(poly_temp_bbox) and (
                        not img_box.touches(poly_temp_bbox)):
                    temp_bbox_area = poly_temp_bbox.area

                    intersect_area = img_box.intersection(poly_temp_bbox).area
                    intersect_ratio = intersect_area / temp_bbox_area

                    # Keep boxes that stay at least 70% inside the image.
                    if intersect_ratio >= 0.7:
                        out_box = []
                        for p in temp_bbox:
                            out_box.extend(p)
                        out_boxes.append(out_box)
                        out_chars.append(chars[idx])

        return out_boxes, out_chars

    @staticmethod
    def rotate_point(point, angle, center):
        """Rotate ``point`` by ``-angle`` radians about ``center``.

        The negated angle presumably matches the direction ``TF.rotate``
        turns the image — TODO confirm the sign convention.
        """
        cos_theta = math.cos(-angle)
        sin_theta = math.sin(-angle)
        c_x = center[0]
        c_y = center[1]
        new_x = (point[0] - c_x) * cos_theta - (point[1] -
                                                c_y) * sin_theta + c_x
        new_y = (point[0] - c_x) * sin_theta + (point[1] -
                                                c_y) * cos_theta + c_y

        return [new_x, new_y]


@PIPELINES.register_module()
class OpencvToPil:
    """Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb)."""

    def __init__(self, **kwargs):
        pass

    def __call__(self, results):
        # Reverse the channel axis: BGR -> RGB.
        rgb = results['img'][..., ::-1]
        results['img'] = Image.fromarray(rgb)
        return results

    def __repr__(self):
        return self.__class__.__name__


@PIPELINES.register_module()
class PilToOpencv:
    """Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr)."""

    def __init__(self, **kwargs):
        pass

    def __call__(self, results):
        # Materialize as an array, then reverse the channel axis: RGB -> BGR.
        arr = np.asarray(results['img'])
        results['img'] = arr[..., ::-1]
        return results

    def __repr__(self):
        return self.__class__.__name__