import copy
import warnings

from mmcv.cnn import VGG
from mmcv.runner.hooks import HOOKS, Hook

from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile
from mmdet.models.dense_heads import GARPNHead, RPNHead
from mmdet.models.roi_heads.mask_heads import FusedSemanticHead


def replace_ImageToTensor(pipelines):
    """Replace the ImageToTensor transform in a data pipeline to
    DefaultFormatBundle, which is normally useful in batch inference.

    The input list is deep-copied first, so the caller's configs are never
    modified. Transforms nested inside a ``MultiScaleFlipAug`` entry are
    processed recursively.

    Args:
        pipelines (list[dict]): Data pipeline configs.

    Returns:
        list: The new pipeline list with all ImageToTensor replaced by
            DefaultFormatBundle.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='ImageToTensor', keys=['img']),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='DefaultFormatBundle'),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
    """
    # Work on a private copy so the caller's config is left untouched.
    pipelines = copy.deepcopy(pipelines)
    for i, pipeline in enumerate(pipelines):
        if pipeline['type'] == 'MultiScaleFlipAug':
            assert 'transforms' in pipeline
            # The wrapper holds its own transform list; recurse into it.
            pipeline['transforms'] = replace_ImageToTensor(
                pipeline['transforms'])
        elif pipeline['type'] == 'ImageToTensor':
            warnings.warn(
                '"ImageToTensor" pipeline is replaced by '
                '"DefaultFormatBundle" for batch inference. It is '
                'recommended to manually replace it in the test '
                'data pipeline in your config file.', UserWarning)
            # The replacement carries no extra keys (e.g. ``keys`` is dropped).
            pipelines[i] = {'type': 'DefaultFormatBundle'}
    return pipelines


def get_loading_pipeline(pipeline):
    """Only keep loading image and annotations related configuration.

    Args:
        pipeline (list[dict]): Data pipeline configs.

    Returns:
        list[dict]: The new pipeline list with only keep
            loading image and annotations related configuration.

    Raises:
        AssertionError: If the number of retained loading configs is not
            exactly two.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True),
        ...    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
        ...    dict(type='RandomFlip', flip_ratio=0.5),
        ...    dict(type='Normalize', **img_norm_cfg),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='DefaultFormatBundle'),
        ...    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True)
        ...    ]
        >>> assert expected_pipelines == get_loading_pipeline(pipelines)
    """
    loading_pipeline_cfg = []
    for cfg in pipeline:
        # Resolve the registered class for this config's ``type`` string;
        # ``None`` means the type is not registered in PIPELINES.
        obj_cls = PIPELINES.get(cfg['type'])
        # TODO: use more elegant way to distinguish loading modules
        if obj_cls is not None and obj_cls in (LoadImageFromFile,
                                               LoadAnnotations):
            loading_pipeline_cfg.append(cfg)
    assert len(loading_pipeline_cfg) == 2, \
        'The data pipeline in your config file must include ' \
        'loading image and annotations related pipeline.'
    return loading_pipeline_cfg


@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Hook that checks head ``num_classes`` against the dataset ``CLASSES``.

    The check runs before every training epoch and every validation epoch.
    """

    def _check_head(self, runner):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        If ``dataset.CLASSES`` is ``None`` only a warning is logged and no
        comparison is made.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.

        Raises:
            AssertionError: If ``dataset.CLASSES`` is a plain string, or if a
                checked module's ``num_classes`` differs from
                ``len(dataset.CLASSES)``.
        """
        model = runner.model
        dataset = runner.data_loader.dataset
        if dataset.CLASSES is None:
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset.__class__.__name__} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
        else:
            # A bare string would make ``len(CLASSES)`` count characters
            # rather than classes, so reject it up front.
            assert not isinstance(dataset.CLASSES, str), \
                (f'`CLASSES` in {dataset.__class__.__name__} '
                 f'should be a tuple of str. '
                 f'Add comma if number of classes is 1 as '
                 f'CLASSES = ({dataset.CLASSES},)')
            for _, module in model.named_modules():
                # These module types are excluded from the comparison even
                # though they expose ``num_classes`` (presumably because
                # their value is unrelated to the dataset classes -- confirm
                # against the respective implementations).
                if hasattr(module, 'num_classes') and not isinstance(
                        module, (RPNHead, VGG, FusedSemanticHead, GARPNHead)):
                    assert module.num_classes == len(dataset.CLASSES), \
                        (f'The `num_classes` ({module.num_classes}) in '
                         f'{module.__class__.__name__} of '
                         f'{model.__class__.__name__} does not match '
                         f'the length of `CLASSES` '
                         f'({len(dataset.CLASSES)}) in '
                         f'{dataset.__class__.__name__}')

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)

    def before_val_epoch(self, runner):
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)