Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
import mmcv | |
from ..builder import PIPELINES | |
from .compose import Compose | |
@PIPELINES.register_module()
class MultiScaleFlipAug:
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        img_scale=[(1333, 400), (1333, 800)],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    After MultiScaleFlipAug with the above configuration, the results are
    wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)],
            flip=[False, True, False, True],
            ...
        )

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
            They are wrapped in a single ``Compose`` pipeline.
        img_scale (tuple | list[tuple] | None): Image scales for resizing.
            Exactly one of ``img_scale`` and ``scale_factor`` must be set.
        scale_factor (float | list[float] | None): Scale factors for
            resizing. Exactly one of ``img_scale`` and ``scale_factor``
            must be set.
        flip (bool): Whether to apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal", "vertical" and "diagonal". If
            flip_direction is a list, multiple flip augmentations will be
            applied. It has no effect when flip == False. Default:
            "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale=None,
                 scale_factor=None,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)
        assert (img_scale is None) ^ (scale_factor is None), (
            'Exactly one of img_scale and scale_factor must be set')
        if img_scale is not None:
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
            # Key written into each per-augmentation results dict.
            self.scale_key = 'scale'
            assert mmcv.is_list_of(self.img_scale, tuple)
        else:
            # Scale factors are stored in ``img_scale`` too; ``scale_key``
            # records which results key they are written to.
            self.img_scale = scale_factor if isinstance(
                scale_factor, list) else [scale_factor]
            self.scale_key = 'scale_factor'
        # An empty list would otherwise only fail later in __call__ with an
        # opaque IndexError on ``aug_data[0]``.
        assert len(self.img_scale) > 0, (
            'img_scale / scale_factor must not be an empty list')

        self.flip = flip
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if self.flip and not any(
                self._is_random_flip(t) for t in transforms):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    @staticmethod
    def _is_random_flip(transform):
        """Return True if ``transform`` denotes a ``RandomFlip`` transform.

        ``transform`` may be a config dict or an already-built transform
        object (mmdet's ``Compose`` accepts both). A config's ``type`` may be
        either a registered name or, in mmcv registries, the class itself.
        Indexing ``transform['type']`` directly would raise ``TypeError`` for
        non-dict transforms.
        """
        if isinstance(transform, dict):
            transform_type = transform.get('type')
        else:
            transform_type = type(transform)
        if isinstance(transform_type, type):
            transform_type = transform_type.__name__
        return transform_type == 'RandomFlip'

    def __call__(self, results):
        """Call function to apply test time augment transforms on results.

        Every scale in ``self.img_scale`` is combined with every flip
        setting. Scales form the outer loop and flips the inner loop. The
        first flip setting is always "no flip"; one entry per
        ``flip_direction`` follows when ``self.flip`` is True.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
                into a list with one entry per (scale, flip) combination.
        """
        aug_data = []
        flip_args = [(False, None)]
        if self.flip:
            flip_args += [(True, direction)
                          for direction in self.flip_direction]
        for scale in self.img_scale:
            for flip, direction in flip_args:
                # Shallow copy so each augmentation gets its own top-level
                # keys without mutating the caller's dict.
                _results = results.copy()
                _results[self.scale_key] = scale
                _results['flip'] = flip
                _results['flip_direction'] = direction
                data = self.transforms(_results)
                aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'flip_direction={self.flip_direction})'
        return repr_str