# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS
from .two_stage import TwoStageDetector


@DETECTORS.register_module()
class FastRCNN(TwoStageDetector):
    """Implementation of `Fast R-CNN <https://arxiv.org/abs/1504.08083>`_

    A two-stage detector whose region proposals are supplied by the caller
    (see ``forward_test``) rather than produced inside this class.
    """

    def __init__(self,
                 backbone,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None,
                 init_cfg=None):
        """Forward every argument unchanged to ``TwoStageDetector``.

        Args:
            backbone: Backbone config, passed through to the base class.
            roi_head: RoI head config, passed through to the base class.
            train_cfg: Training config, passed through to the base class.
            test_cfg: Testing config, passed through to the base class.
            neck: Optional neck config. Defaults to None.
            pretrained: Optional pretrained-weights argument, passed through
                to the base class. Defaults to None.
            init_cfg: Optional initialization config. Defaults to None.
        """
        super(FastRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)

    def forward_test(self, imgs, img_metas, proposals, **kwargs):
        """Run inference with externally provided proposals.

        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            proposals (List[List[Tensor]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. The Tensor should have a shape Px4, where
                P is the number of proposals.
            **kwargs: Extra keyword arguments forwarded to ``simple_test``.

        Returns:
            Whatever ``self.simple_test`` returns for the single augmentation
            (``simple_test`` is not defined in this class; presumably it is
            inherited from ``TwoStageDetector``).

        Raises:
            TypeError: If ``imgs`` or ``img_metas`` is not a list.
            ValueError: If ``imgs`` and ``img_metas`` differ in length.
            NotImplementedError: If more than one augmentation is given.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(imgs)}) '
                             f'!= num of image meta ({len(img_metas)})')

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], proposals[0],
                                    **kwargs)

        # TODO: support test-time augmentation
        # This must be ``raise``: ``assert NotImplementedError`` always passes
        # (the exception class is truthy), which made this path silently
        # return None instead of reporting the unsupported mode.
        raise NotImplementedError(
            'test-time augmentation is not supported by FastRCNN')