# (Page banner text captured with this file, presumably not part of the module: "Spaces: Runtime error / Runtime error")
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
import torch | |
from ..builder import DETECTORS | |
from .single_stage import SingleStageDetector | |
class DETR(SingleStageDetector): | |
r"""Implementation of `DETR: End-to-End Object Detection with | |
Transformers <https://arxiv.org/pdf/2005.12872>`_""" | |
def __init__(self, | |
backbone, | |
bbox_head, | |
train_cfg=None, | |
test_cfg=None, | |
pretrained=None, | |
init_cfg=None): | |
super(DETR, self).__init__(backbone, None, bbox_head, train_cfg, | |
test_cfg, pretrained, init_cfg) | |
# over-write `forward_dummy` because: | |
# the forward of bbox_head requires img_metas | |
def forward_dummy(self, img): | |
"""Used for computing network flops. | |
See `mmdetection/tools/analysis_tools/get_flops.py` | |
""" | |
warnings.warn('Warning! MultiheadAttention in DETR does not ' | |
'support flops computation! Do not use the ' | |
'results in your papers!') | |
batch_size, _, height, width = img.shape | |
dummy_img_metas = [ | |
dict( | |
batch_input_shape=(height, width), | |
img_shape=(height, width, 3)) for _ in range(batch_size) | |
] | |
x = self.extract_feat(img) | |
outs = self.bbox_head(x, dummy_img_metas) | |
return outs | |
# over-write `onnx_export` because: | |
# (1) the forward of bbox_head requires img_metas | |
# (2) the different behavior (e.g. construction of `masks`) between | |
# torch and ONNX model, during the forward of bbox_head | |
def onnx_export(self, img, img_metas): | |
"""Test function for exporting to ONNX, without test time augmentation. | |
Args: | |
img (torch.Tensor): input images. | |
img_metas (list[dict]): List of image information. | |
Returns: | |
tuple[Tensor, Tensor]: dets of shape [N, num_det, 5] | |
and class labels of shape [N, num_det]. | |
""" | |
x = self.extract_feat(img) | |
# forward of this head requires img_metas | |
outs = self.bbox_head.forward_onnx(x, img_metas) | |
# get shape as tensor | |
img_shape = torch._shape_as_tensor(img)[2:] | |
img_metas[0]['img_shape_for_onnx'] = img_shape | |
det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas) | |
return det_bboxes, det_labels | |