tomofi's picture
Add application file
2366e36
raw
history blame
No virus
5.96 kB
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import mmcv
from mmdet.core import bbox2roi
from torch import nn
from torch.nn import functional as F
from mmocr.core import imshow_edge, imshow_node
from mmocr.models.builder import DETECTORS, build_roi_extractor
from mmocr.models.common.detectors import SingleStageDetector
from mmocr.utils import list_from_file
@DETECTORS.register_module()
class SDMGR(SingleStageDetector):
"""The implementation of the paper: Spatial Dual-Modality Graph Reasoning
for Key Information Extraction. https://arxiv.org/abs/2103.14470.
Args:
visual_modality (bool): Whether use the visual modality.
class_list (None | str): Mapping file of class index to
class name. If None, class index will be shown in
`show_results`, else class name.
"""
def __init__(self,
backbone,
neck=None,
bbox_head=None,
extractor=dict(
type='mmdet.SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7),
featmap_strides=[1]),
visual_modality=False,
train_cfg=None,
test_cfg=None,
class_list=None,
init_cfg=None,
openset=False):
super().__init__(
backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg)
self.visual_modality = visual_modality
if visual_modality:
self.extractor = build_roi_extractor({
**extractor, 'out_channels':
self.backbone.base_channels
})
self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])
else:
self.extractor = None
self.class_list = class_list
self.openset = openset
def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
gt_labels):
"""
Args:
img (tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A list of image info dict where each dict
contains: 'img_shape', 'scale_factor', 'flip', and may also
contain 'filename', 'ori_shape', 'pad_shape', and
'img_norm_cfg'. For details of the values of these keys,
please see :class:`mmdet.datasets.pipelines.Collect`.
relations (list[tensor]): Relations between bboxes.
texts (list[tensor]): Texts in bboxes.
gt_bboxes (list[tensor]): Each item is the truth boxes for each
image in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[tensor]): Class indices corresponding to each box.
Returns:
dict[str, tensor]: A dictionary of loss components.
"""
x = self.extract_feat(img, gt_bboxes)
node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
return self.bbox_head.loss(node_preds, edge_preds, gt_labels)
def forward_test(self,
img,
img_metas,
relations,
texts,
gt_bboxes,
rescale=False):
x = self.extract_feat(img, gt_bboxes)
node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
return [
dict(
img_metas=img_metas,
nodes=F.softmax(node_preds, -1),
edges=F.softmax(edge_preds, -1))
]
def extract_feat(self, img, gt_bboxes):
if self.visual_modality:
x = super().extract_feat(img)[-1]
feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
return feats.view(feats.size(0), -1)
return None
def show_result(self,
img,
result,
boxes,
win_name='',
show=False,
wait_time=0,
out_file=None,
**kwargs):
"""Draw `result` on `img`.
Args:
img (str or tensor): The image to be displayed.
result (dict): The results to draw on `img`.
boxes (list): Bbox of img.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The output filename.
Default: None.
Returns:
img (tensor): Only if not `show` or `out_file`.
"""
img = mmcv.imread(img)
img = img.copy()
idx_to_cls = {}
if self.class_list is not None:
for line in list_from_file(self.class_list):
class_idx, class_label = line.strip().split()
idx_to_cls[class_idx] = class_label
# if out_file specified, do not show image in window
if out_file is not None:
show = False
if self.openset:
img = imshow_edge(
img,
result,
boxes,
show=show,
win_name=win_name,
wait_time=wait_time,
out_file=out_file)
else:
img = imshow_node(
img,
result,
boxes,
idx_to_cls=idx_to_cls,
show=show,
win_name=win_name,
wait_time=wait_time,
out_file=out_file)
if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
return img