File size: 3,367 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv import rescale_size
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.formating import DefaultFormatBundle, to_tensor


@PIPELINES.register_module()
class ResizeNoImg:
    """Image resizing without img.

    Used for KIE. Only the original width/height stored in
    ``results['img_info']`` are read; no pixel data is touched. The target
    size and per-axis scale factors are written back into ``results`` so
    that later pipeline steps (e.g. ``KIEFormatBundle``, which reads
    ``results['scale_factor']``) can rescale geometric annotations.

    Args:
        img_scale (tuple[int] | int | float): Target scale. When
            ``keep_ratio`` is True it is passed to ``mmcv.rescale_size``
            together with the original ``(w, h)``; otherwise it is unpacked
            directly as ``(new_w, new_h)``.
        keep_ratio (bool): Whether to keep the original aspect ratio.
            Defaults to True.
    """

    def __init__(self, img_scale, keep_ratio=True):
        self.img_scale = img_scale
        self.keep_ratio = keep_ratio

    def __call__(self, results):
        """Fill in resize metadata for one sample.

        Args:
            results (dict): Result dict; must contain
                ``results['img_info']['width']`` and
                ``results['img_info']['height']``.

        Returns:
            dict: ``results``, updated in place with ``img_shape`` set to
            ``(new_h, new_w, 1)``, ``scale_factor`` set to a float32 array
            ``[w_scale, h_scale, w_scale, h_scale]`` and ``keep_ratio`` set
            to the mode this transform was configured with.
        """
        w, h = results['img_info']['width'], results['img_info']['height']
        if self.keep_ratio:
            new_w, new_h = rescale_size((w, h),
                                        self.img_scale,
                                        return_scale=False)
        else:
            new_w, new_h = self.img_scale

        # Scale factors are derived once from the final target size,
        # whichever branch produced it.
        w_scale = new_w / w
        h_scale = new_h / h
        results['img_shape'] = (new_h, new_w, 1)
        results['scale_factor'] = np.array(
            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
        # Record the mode that was actually used (previously always True).
        results['keep_ratio'] = self.keep_ratio

        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(img_scale={self.img_scale}, '
                f'keep_ratio={self.keep_ratio})')


@PIPELINES.register_module()
class KIEFormatBundle(DefaultFormatBundle):
    """Key information extraction formatting bundle.

    Based on the DefaultFormatBundle, it simplifies the pipeline of formatting
    common fields, including "img", "proposals", "gt_bboxes", "gt_labels",
    "gt_masks", "gt_semantic_seg", "relations" and "texts".
    These fields are formatted as follows.

    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)
    - proposals: (1) to tensor, (2) to DataContainer
    - gt_bboxes: (1) to tensor, (2) to DataContainer
    - gt_bboxes_ignore: (1) to tensor, (2) to DataContainer
    - gt_labels: (1) to tensor, (2) to DataContainer
    - gt_masks: (1) to tensor, (2) to DataContainer (cpu_only=True)
    - gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor,
                       (3) to DataContainer (stack=True)
    - relations: (1) scale, (2) to tensor, (3) to DataContainer
    - texts: (1) to tensor, (2) to DataContainer
    """

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        After the default formatting, ``relations`` and ``texts`` are read
        from ``results['ann_info']``, wrapped as tensors in DataContainers
        and stored at the top level of ``results``. If ``scale_factor`` is
        present, ``relations`` is first multiplied by
        ``[sx, sy, sx / sy, 1, sx / sy]``.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data that is formatted with
                default bundle.
        """
        super().__call__(results)
        if 'ann_info' in results:
            for key in ['relations', 'texts']:
                value = results['ann_info'][key]
                if key == 'relations' and 'scale_factor' in results:
                    scale_factor = results['scale_factor']
                    # Any scalar (python int/float, numpy scalar, 0-d array)
                    # means isotropic scaling; otherwise take (sx, sy) from
                    # the first two entries.
                    if np.ndim(scale_factor) == 0:
                        sx = sy = scale_factor
                    else:
                        sx, sy = scale_factor[:2]
                    r = sx / sy
                    factor = np.array([sx, sy, r, 1, r]).astype(np.float32)
                    # factor[None, None] broadcasts over the last axis;
                    # NOTE(review): assumes relations has 5 features on its
                    # last axis (presumably (N, N, 5)) — confirm with loader.
                    value = value * factor[None, None]
                results[key] = DC(to_tensor(value))
        return results

    def __repr__(self):
        return self.__class__.__name__