File size: 5,963 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
from mmdet.core import bbox2roi
from torch import nn
from torch.nn import functional as F

from mmocr.core import imshow_edge, imshow_node
from mmocr.models.builder import DETECTORS, build_roi_extractor
from mmocr.models.common.detectors import SingleStageDetector
from mmocr.utils import list_from_file


@DETECTORS.register_module()
class SDMGR(SingleStageDetector):
    """The implementation of the paper: Spatial Dual-Modality Graph Reasoning
    for Key Information Extraction. https://arxiv.org/abs/2103.14470.

    Args:
        backbone: Forwarded unchanged to the parent constructor.
        neck: Forwarded unchanged to the parent constructor.
        bbox_head: Forwarded unchanged to the parent constructor.
        extractor (dict): Config of the RoI extractor. It is only built when
            ``visual_modality`` is True; ``out_channels`` is filled in from
            ``self.backbone.base_channels`` and
            ``extractor['roi_layer']['output_size']`` is used as the kernel
            size of the max-pooling applied to the extracted RoI features.
        visual_modality (bool): Whether use the visual modality.
        train_cfg: Forwarded unchanged to the parent constructor.
        test_cfg: Forwarded unchanged to the parent constructor.
        class_list (None | str): Mapping file of class index to
            class name. If None, class index will be shown in
            `show_results`, else class name.
        init_cfg: Forwarded unchanged to the parent constructor.
        openset (bool): If True, ``show_result`` draws with ``imshow_edge``;
            otherwise it draws with ``imshow_node``.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 extractor=dict(
                     type='mmdet.SingleRoIExtractor',
                     roi_layer=dict(type='RoIAlign', output_size=7),
                     featmap_strides=[1]),
                 visual_modality=False,
                 train_cfg=None,
                 test_cfg=None,
                 class_list=None,
                 init_cfg=None,
                 openset=False):
        super().__init__(
            backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg)
        self.visual_modality = visual_modality
        if visual_modality:
            # Build from a merged copy so the shared default ``extractor``
            # dict is never mutated.
            self.extractor = build_roi_extractor({
                **extractor, 'out_channels':
                self.backbone.base_channels
            })
            self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])
        else:
            self.extractor = None
        self.class_list = class_list
        self.openset = openset

    def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
                      gt_labels):
        """
        Args:
            img (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                contains: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details of the values of these keys,
                please see :class:`mmdet.datasets.pipelines.Collect`.
            relations (list[tensor]): Relations between bboxes.
            texts (list[tensor]): Texts in bboxes.
            gt_bboxes (list[tensor]): Each item is the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[tensor]): Class indices corresponding to each box.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img, gt_bboxes)
        node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
        return self.bbox_head.loss(node_preds, edge_preds, gt_labels)

    def forward_test(self,
                     img,
                     img_metas,
                     relations,
                     texts,
                     gt_bboxes,
                     rescale=False):
        """Run the head on one batch and return softmax-normalised scores.

        Args:
            img (tensor): Input images, passed to ``extract_feat``.
            img_metas (list[dict]): Image info dicts; returned unchanged in
                the result.
            relations (list[tensor]): Relations between bboxes.
            texts (list[tensor]): Texts in bboxes.
            gt_bboxes (list[tensor]): Boxes used to pool visual features.
            rescale (bool): Accepted for interface compatibility; not used
                by this method.

        Returns:
            list[dict]: A single-element list whose dict has the keys
            ``img_metas``, ``nodes`` and ``edges``; ``nodes`` and ``edges``
            are the head predictions after a softmax over the last dim.
        """
        x = self.extract_feat(img, gt_bboxes)
        node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
        return [
            dict(
                img_metas=img_metas,
                nodes=F.softmax(node_preds, -1),
                edges=F.softmax(edge_preds, -1))
        ]

    def extract_feat(self, img, gt_bboxes):
        """Pool one visual feature vector per box, if enabled.

        Args:
            img (tensor): Input images.
            gt_bboxes (list[tensor]): Boxes converted to RoIs via
                ``bbox2roi``.

        Returns:
            tensor | None: ``None`` when ``visual_modality`` is False.
            Otherwise the RoI features of the last map returned by the
            parent ``extract_feat``, max-pooled and flattened to 2 dims
            (first dim preserved).
        """
        if self.visual_modality:
            x = super().extract_feat(img)[-1]
            feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
            return feats.view(feats.size(0), -1)
        return None

    def _load_idx_to_cls(self):
        """Read ``self.class_list`` into an ``{index: class name}`` dict.

        Each non-blank line is split on the first run of whitespace into an
        index token and a class name, so class names may themselves contain
        whitespace. Keys are kept as the strings read from the file.

        Returns:
            dict[str, str]: Empty when ``self.class_list`` is None.

        Raises:
            ValueError: If a non-blank line has no whitespace separating an
                index from a class name.
        """
        idx_to_cls = {}
        if self.class_list is None:
            return idx_to_cls
        for line in list_from_file(self.class_list):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            class_idx, class_label = line.split(maxsplit=1)
            idx_to_cls[class_idx] = class_label
        return idx_to_cls

    def show_result(self,
                    img,
                    result,
                    boxes,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    **kwargs):
        """Draw `result` on `img`.

        Args:
            img (str or tensor): The image to be displayed.
            result (dict): The results to draw on `img`.
            boxes (list): Bbox of img.
            win_name (str): The window name.
            wait_time (int): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The output filename.
                Default: None.

        Returns:
            The image returned by ``imshow_edge`` (openset) or
            ``imshow_node``. It is always returned; a warning is emitted
            when neither `show` nor `out_file` is set.
        """
        img = mmcv.imread(img)
        img = img.copy()

        idx_to_cls = self._load_idx_to_cls()

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        if self.openset:
            img = imshow_edge(
                img,
                result,
                boxes,
                show=show,
                win_name=win_name,
                wait_time=wait_time,
                out_file=out_file)
        else:
            img = imshow_node(
                img,
                result,
                boxes,
                idx_to_cls=idx_to_cls,
                show=show,
                win_name=win_name,
                wait_time=wait_time,
                out_file=out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')

        return img