|
import copy |
|
import warnings |
|
|
|
from mmcv.cnn import VGG |
|
from mmcv.runner.hooks import HOOKS, Hook |
|
|
|
from mmdet.datasets.builder import PIPELINES |
|
from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile |
|
from mmdet.models.dense_heads import GARPNHead, RPNHead |
|
from mmdet.models.roi_heads.mask_heads import FusedSemanticHead |
|
|
|
|
|
def replace_ImageToTensor(pipelines):
    """Swap every ``ImageToTensor`` step for ``DefaultFormatBundle``.

    This is mostly handy for batch inference. The input is never modified;
    the work is done on a deep copy. Steps nested in the ``transforms`` of a
    ``MultiScaleFlipAug`` entry are handled recursively. Every replacement
    emits a ``UserWarning`` advising that the config be fixed by hand.

    Args:
        pipelines (list[dict]): Data pipeline configs.

    Returns:
        list: A copy of the pipeline in which each ImageToTensor entry has
            become DefaultFormatBundle.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='ImageToTensor', keys=['img']),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='DefaultFormatBundle'),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
    """
    result = copy.deepcopy(pipelines)
    for index, step in enumerate(result):
        step_type = step['type']
        if step_type == 'ImageToTensor':
            warnings.warn(
                '"ImageToTensor" pipeline is replaced by '
                '"DefaultFormatBundle" for batch inference. It is '
                'recommended to manually replace it in the test '
                'data pipeline in your config file.', UserWarning)
            # Overwriting the current slot is safe while enumerating.
            result[index] = {'type': 'DefaultFormatBundle'}
        elif step_type == 'MultiScaleFlipAug':
            assert 'transforms' in step
            step['transforms'] = replace_ImageToTensor(step['transforms'])
    return result
|
|
|
|
|
def get_loading_pipeline(pipeline):
    """Only keep loading image and annotations related configuration.

    Each config is looked up in the ``PIPELINES`` registry by its ``type``.
    It is kept only if the registered class is exactly ``LoadImageFromFile``
    or ``LoadAnnotations``. Subclasses do not match, and unregistered types
    are dropped.

    Args:
        pipeline (list[dict]): Data pipeline configs.

    Returns:
        list[dict]: The new pipeline list that keeps only the image and
            annotation loading configurations, in their original order.

    Raises:
        AssertionError: If exactly two loading configs are not found.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True),
        ...    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
        ...    dict(type='RandomFlip', flip_ratio=0.5),
        ...    dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='DefaultFormatBundle'),
        ...    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True)
        ...    ]
        >>> assert expected_pipelines == get_loading_pipeline(pipelines)
    """
    loading_classes = (LoadImageFromFile, LoadAnnotations)
    # PIPELINES.get returns None for unknown types; None is never a member
    # of loading_classes, so no separate None check is needed.
    loading_pipeline_cfg = [
        cfg for cfg in pipeline
        if PIPELINES.get(cfg['type']) in loading_classes
    ]
    assert len(loading_pipeline_cfg) == 2, \
        'The data pipeline in your config file must include ' \
        'loading image and annotations related pipeline.'
    return loading_pipeline_cfg
|
|
|
|
|
@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Verify that heads' ``num_classes`` agree with the dataset ``CLASSES``.

    The check runs before every training epoch and every validation epoch.
    """

    def _check_head(self, runner):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        If the dataset has no ``CLASSES``, only a warning is logged.
        Otherwise every submodule that has a ``num_classes`` attribute is
        checked, except those of the excluded head/backbone types.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.

        Raises:
            AssertionError: If ``CLASSES`` is a plain string, or if a
                checked module's ``num_classes`` differs from
                ``len(CLASSES)``.
        """
        model = runner.model
        dataset = runner.data_loader.dataset
        dataset_name = dataset.__class__.__name__

        if dataset.CLASSES is None:
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset_name} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        # A bare string would make len() count characters, not classes.
        assert not isinstance(dataset.CLASSES, str), \
            (f'`CLASSES` in {dataset_name} '
             f'should be a tuple of str. '
             f'Add comma if number of classes is 1 as '
             f'CLASSES = ({dataset.CLASSES!r},)')

        num_dataset_classes = len(dataset.CLASSES)
        # These module types are skipped even though they may define
        # `num_classes`. Presumably their value does not denote the
        # dataset's categories (TODO confirm).
        skipped_types = (RPNHead, VGG, FusedSemanticHead, GARPNHead)
        for _, module in model.named_modules():
            if not hasattr(module, 'num_classes') or isinstance(
                    module, skipped_types):
                continue
            assert module.num_classes == num_dataset_classes, \
                (f'The `num_classes` ({module.num_classes}) in '
                 f'{module.__class__.__name__} of '
                 f'{model.__class__.__name__} does not match '
                 f'the length of `CLASSES` '
                 f'({num_dataset_classes}) in '
                 f'{dataset_name}')

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)

    def before_val_epoch(self, runner):
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)
|
|