# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.utils import print_log
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.pipelines import Compose
from torch.utils.data import Dataset

from mmocr.datasets.builder import build_loader


@DATASETS.register_module()
class BaseDataset(Dataset):
    """Custom dataset for text detection, text recognition, and their
    downstream tasks.

    1. The text detection annotation format is as follows:
       The `annotations` field is optional for testing
       (this is one line of anno_file, with line-json-str
       converted to dict for visualizing only).

        {
            "file_name": "sample.jpg",
            "height": 1080,
            "width": 960,
            "annotations":
                [
                    {
                        "iscrowd": 0,
                        "category_id": 1,
                        "bbox": [357.0, 667.0, 804.0, 100.0],
                        "segmentation": [[361, 667, 710, 670,
                                          72, 767, 357, 763]]
                    }
                ]
        }

    2. The two text recognition annotation formats are as follows:
       The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop
       augmentation during training.

        format1: sample.jpg hello
        format2: sample.jpg 20 20 100 20 100 40 20 40 hello

    Args:
        ann_file (str): Annotation file path.
        loader (dict): Dictionary to construct loader
            to load annotation infos. It is not modified; a shallow copy
            with ``ann_file`` added is passed to ``build_loader``.
        pipeline (list[dict]): Processing pipeline.
        img_prefix (str, optional): Image prefix to generate full
            image path.
        test_mode (bool, optional): If set True, try...except will
            be turned off in __getitem__.
    """

    def __init__(self,
                 ann_file,
                 loader,
                 pipeline,
                 img_prefix='',
                 test_mode=False):
        super().__init__()
        self.test_mode = test_mode
        self.img_prefix = img_prefix
        self.ann_file = ann_file

        # Load annotations. Work on a shallow copy so that the caller's
        # config dict is not mutated as a side effect of construction.
        loader_cfg = loader.copy()
        loader_cfg.update(ann_file=ann_file)
        self.data_infos = build_loader(loader_cfg)

        # Processing pipeline.
        self.pipeline = Compose(pipeline)

        # Group flag and CLASSES carry no meaning for text detection and
        # recognition; they are set to placeholder values only.
        self._set_group_flag()
        self.CLASSES = 0

    def __len__(self):
        """Return the number of entries in ``self.data_infos``."""
        return len(self.data_infos)

    def _set_group_flag(self):
        """Set ``self.flag`` to a zero-filled uint8 array, one entry per
        sample."""
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline.

        Args:
            results (dict): Dict updated in place with ``img_prefix``.
        """
        results['img_prefix'] = self.img_prefix

    def prepare_train_img(self, index):
        """Get training data and annotations from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_info = self.data_infos[index]
        results = dict(img_info=img_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, img_info):
        """Get testing data from pipeline.

        Args:
            img_info (int): Index of data. Despite its name, this argument
                is used as an index: it is forwarded unchanged to
                :meth:`prepare_train_img`. The name is kept for backward
                compatibility.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
                pipeline.
        """
        return self.prepare_train_img(img_info)

    def _log_error_index(self, index):
        """Logging data info of bad index.

        Args:
            index (int): Index of the sample that failed to load.
        """
        # Best effort: looking up the data info may itself fail (e.g. for
        # an invalid index), in which case that failure is logged instead.
        try:
            data_info = self.data_infos[index]
            img_prefix = self.img_prefix
            print_log(f'Warning: skip broken file {data_info} '
                      f'with img_prefix {img_prefix}')
        except Exception as e:
            print_log(f'load index {index} with error {e}')

    def _get_next_index(self, index):
        """Get next index from dataset.

        Logs the current (bad) index, then advances by one, wrapping around
        to 0 at the end of the dataset.

        Args:
            index (int): Index of the sample that failed.

        Returns:
            int: The following index, modulo the dataset length.
        """
        self._log_error_index(index)
        index = (index + 1) % len(self)
        return index

    def __getitem__(self, index):
        """Get training/test data from pipeline.

        In test mode the sample is prepared directly and any exception
        propagates. Otherwise, a sample that raises or yields ``None`` is
        logged and skipped in favour of the next index.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training/test data.
        """
        if self.test_mode:
            return self.prepare_test_img(index)

        # NOTE: this loop only exits once some sample is prepared
        # successfully; if every sample keeps failing it does not terminate.
        while True:
            try:
                data = self.prepare_train_img(index)
                if data is None:
                    raise Exception('prepared train data empty')
                break
            except Exception as e:
                print_log(f'prepare index {index} with error {e}')
                index = self._get_next_index(index)
        return data

    def format_results(self, results, **kwargs):
        """Placeholder to format result to dataset-specific output."""

    def evaluate(self, results, metric=None, logger=None, **kwargs):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str, float]

        Raises:
            NotImplementedError: Always; this base implementation does not
                provide an evaluation.
        """
        raise NotImplementedError