# Copyright (c) OpenMMLab. All rights reserved.
import copy

from mmdet.datasets import DATASETS, ConcatDataset, build_dataset

from mmocr.utils import is_2dlist, is_type_list
@DATASETS.register_module()
class UniformConcatDataset(ConcatDataset):
    """A wrapper of ConcatDataset that supports dataset pipeline assignment
    and replacement.

    Args:
        datasets (list[dict] | list[list[dict]]): A list of dataset configs.
        separate_eval (bool): Whether to evaluate the results separately
            if it is used as a validation dataset.
            Defaults to True.
        pipeline (None | list[dict] | list[list[dict]]): If ``None``,
            each dataset in ``datasets`` uses its own pipeline;
            if ``list[dict]``, it will be assigned to every dataset in
            ``datasets`` whose pipeline is ``None``;
            if ``list[list[dict]]``, the pipeline of each dataset whose
            pipeline is ``None`` will be replaced by the corresponding
            pipeline in the list.
        force_apply (bool): If True, apply the pipeline above to each dataset
            even if it has its own pipeline. Default: False.
    """

    def __init__(self,
                 datasets,
                 separate_eval=True,
                 pipeline=None,
                 force_apply=False,
                 **kwargs):
        new_datasets = []
        if pipeline is not None:
            assert isinstance(
                pipeline,
                list), 'pipeline must be list[dict] or list[list[dict]].'
            if is_type_list(pipeline, dict):
                # A single shared pipeline: assign it to every dataset config.
                self._apply_pipeline(datasets, pipeline, force_apply)
                new_datasets = datasets
            elif is_2dlist(pipeline):
                # One pipeline per group of dataset configs.
                assert is_2dlist(datasets)
                assert len(datasets) == len(pipeline)
                for sub_datasets, tmp_pipeline in zip(datasets, pipeline):
                    self._apply_pipeline(sub_datasets, tmp_pipeline,
                                         force_apply)
                    new_datasets.extend(sub_datasets)
        else:
            if is_2dlist(datasets):
                for sub_datasets in datasets:
                    new_datasets.extend(sub_datasets)
            else:
                new_datasets = datasets
        datasets = [build_dataset(c, kwargs) for c in new_datasets]
        super().__init__(datasets, separate_eval)
    @staticmethod
    def _apply_pipeline(datasets, pipeline, force_apply=False):
        """Copy ``pipeline`` into every dataset config whose own pipeline is
        ``None`` (or into all of them if ``force_apply`` is True)."""
        from_cfg = all(isinstance(x, dict) for x in datasets)
        assert from_cfg, 'datasets should be config dicts'
        assert all(isinstance(x, dict) for x in pipeline)
        for dataset in datasets:
            if dataset['pipeline'] is None or force_apply:
                dataset['pipeline'] = copy.deepcopy(pipeline)
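

if __name__ == '__main__':
    # Usage sketch (not part of the upstream module): shows how a shared
    # pipeline is copied into dataset configs whose ``pipeline`` is None,
    # which is the same step ``UniformConcatDataset`` performs when it is
    # given a flat ``list[dict]`` pipeline. The ``type`` names below are
    # illustrative placeholders, not guaranteed registry entries.
    shared_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadTextAnnotations'),
    ]
    dataset_cfgs = [
        dict(type='OCRDataset', ann_file='train1.txt', pipeline=None),
        dict(type='OCRDataset', ann_file='train2.txt', pipeline=None),
    ]
    # Both configs have ``pipeline=None``, so each receives a deep copy of
    # ``shared_pipeline``; the originals are left untouched.
    UniformConcatDataset._apply_pipeline(dataset_cfgs, shared_pipeline)
    print(dataset_cfgs[0]['pipeline'] == shared_pipeline)  # True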