Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) 2019 Western Digital Corporation or its affiliates.
import torch | |
from ..builder import DETECTORS | |
from .single_stage import SingleStageDetector | |
# Register the class in the DETECTORS registry so it can be built by name.
# Without this decorator the module's `DETECTORS` import is unused and the
# class is never registered.
@DETECTORS.register_module()
class YOLOV3(SingleStageDetector):
    """YOLOv3 detector.

    All construction, training and inference behavior is inherited from
    ``SingleStageDetector``. This subclass only adds an ONNX export path.

    Args:
        backbone: Backbone config, forwarded unchanged to
            ``SingleStageDetector``.
        neck: Neck config, forwarded unchanged.
        bbox_head: Bbox head config, forwarded unchanged.
        train_cfg: Training config, forwarded unchanged. Defaults to None.
        test_cfg: Testing config, forwarded unchanged. Defaults to None.
        pretrained: Forwarded unchanged to the parent. Defaults to None.
        init_cfg: Initialization config, forwarded unchanged.
            Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        # Arguments are passed positionally in the parent's expected order.
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained, init_cfg)

    def onnx_export(self, img, img_metas):
        """Test function for exporting to ONNX, without test time augmentation.

        Args:
            img (torch.Tensor): Input images.
            img_metas (list[dict]): List of image information. The first
                dict is modified in place: its ``'img_shape_for_onnx'`` key
                is set to ``img``'s shape from dimension 2 onward, as a
                tensor.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
                and class labels of shape [N, num_det].
        """
        feats = self.extract_feat(img)
        outs = self.bbox_head.forward(feats)
        # Read the shape as a tensor rather than Python ints. Presumably this
        # keeps the exported graph's input size dynamic instead of baking
        # constants in during tracing.
        # [2:] drops the first two dims; this assumes img is (N, C, H, W),
        # giving (H, W). TODO: confirm.
        img_shape = torch._shape_as_tensor(img)[2:]
        # The bbox head's ONNX path reads the shape back from img_metas[0].
        img_metas[0]['img_shape_for_onnx'] = img_shape
        det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas)
        return det_bboxes, det_labels