File size: 2,209 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.formating import DefaultFormatBundle

from mmocr.core.visualize import overlay_mask_img, show_feature


@PIPELINES.register_module()
class CustomFormatBundle(DefaultFormatBundle):
    """Custom formatting bundle.

    It formats common fields such as 'img' and 'proposals' as done in
    DefaultFormatBundle, while the fields listed in ``keys`` (e.g.
    'gt_kernels' and 'gt_effective_mask') are formatted to DC as follows:

    - gt_kernels: to DataContainer (cpu_only=True)
    - gt_effective_mask: to DataContainer (cpu_only=True)

    Args:
        keys (list[str] | None): Fields to be formatted to DC only.
            ``None`` (the default) means no extra fields.
        call_super (bool): If True, format common fields
            by DefaultFormatBundle, else format fields in keys above only.
        visualize (dict | None): Debug visualization options. If
            ``flag`` is True, the image and every mask of the fields listed
            in ``results['mask_fields']`` are passed to ``show_feature``
            before formatting. If ``boundary_key`` is not None, the first
            mask of ``results[boundary_key]`` is overlaid on the image
            first. Options that are not given default to ``flag=False``
            and ``boundary_key=None``.
    """

    def __init__(self, keys=None, call_super=True, visualize=None):
        super().__init__()
        # Build fresh containers per instance: a literal default such as
        # `keys=[]` is evaluated once and would be shared by all instances.
        self.keys = [] if keys is None else list(keys)
        self.call_super = call_super
        # Merge user options over the defaults so a partial config such as
        # dict(flag=True) does not fail later on a missing 'boundary_key'.
        self.visualize = dict(flag=False, boundary_key=None)
        if visualize is not None:
            self.visualize.update(visualize)

    def __call__(self, results):
        """Format a pipeline results dict.

        Args:
            results (dict): Results dict produced by earlier pipeline steps.

        Returns:
            dict: The results (formatted by DefaultFormatBundle when
            ``call_super`` is True) with every field in ``self.keys``
            wrapped as ``DC(value, cpu_only=True)``.
        """
        if self.visualize.get('flag', False):
            self._show_debug_features(results)

        if self.call_super:
            results = super().__call__(results)

        for k in self.keys:
            results[k] = DC(results[k], cpu_only=True)

        return results

    def _show_debug_features(self, results):
        """Show the image and all ground-truth masks of ``results``."""
        img = results['img'].astype(np.uint8)
        boundary_key = self.visualize.get('boundary_key')
        if boundary_key is not None:
            img = overlay_mask_img(img, results[boundary_key].masks[0])

        features = [img]
        names = ['img']
        # 1 -> convert to uint8 for display; masks are shown as-is (0).
        to_uint8 = [1]

        for field in results.get('mask_fields', []):
            for idx, mask in enumerate(results[field].masks):
                features.append(mask)
                names.append(f'{field}{idx}')
                to_uint8.append(0)
        show_feature(features, names, to_uint8)

    def __repr__(self):
        """Return the class name (constructor options are not included)."""
        return self.__class__.__name__