MMOCR / mmocr /datasets /base_dataset.py
tomofi's picture
Add application file
2366e36
raw
history blame
No virus
5.38 kB
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.utils import print_log
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.pipelines import Compose
from torch.utils.data import Dataset
from mmocr.datasets.builder import build_loader
@DATASETS.register_module()
class BaseDataset(Dataset):
    """Custom dataset for text detection, text recognition, and their
    downstream tasks.

    1. The text detection annotation format is as follows:
       The `annotations` field is optional for testing
       (this is one line of anno_file, with line-json-str
       converted to dict for visualizing only).

        {
            "file_name": "sample.jpg",
            "height": 1080,
            "width": 960,
            "annotations":
                [
                    {
                        "iscrowd": 0,
                        "category_id": 1,
                        "bbox": [357.0, 667.0, 804.0, 100.0],
                        "segmentation": [[361, 667, 710, 670,
                                          72, 767, 357, 763]]
                    }
                ]
        }

    2. The two text recognition annotation formats are as follows:
       The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop
       augmentation during training.

        format1: sample.jpg hello
        format2: sample.jpg 20 20 100 20 100 40 20 40 hello

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        loader (dict): Dictionary to construct loader
            to load annotation infos. It is not modified; the loader is
            built from a copy with ``ann_file`` filled in.
        img_prefix (str, optional): Image prefix to generate full
            image path.
        test_mode (bool, optional): If set True, try...except will
            be turned off in __getitem__.
        max_refetch (int, optional): Maximum number of extra samples
            ``__getitem__`` tries in training mode after a sample fails
            to be prepared, before giving up with ``RuntimeError``.
            Must be non-negative. Default: 1000.
    """

    def __init__(self,
                 ann_file,
                 loader,
                 pipeline,
                 img_prefix='',
                 test_mode=False,
                 max_refetch=1000):
        super().__init__()
        self.test_mode = test_mode
        self.img_prefix = img_prefix
        self.ann_file = ann_file
        self.max_refetch = max_refetch

        # load annotations
        # Work on a shallow copy so the caller's config dict is left
        # untouched (it may be shared between several datasets).
        loader = dict(loader, ann_file=ann_file)
        self.data_infos = build_loader(loader)

        # processing pipeline
        self.pipeline = Compose(pipeline)

        # set group flag and class, no meaning
        # for text detect and recognize
        self._set_group_flag()
        self.CLASSES = 0

    def __len__(self):
        """Return the number of loaded annotation entries."""
        return len(self.data_infos)

    def _set_group_flag(self):
        """Set ``self.flag`` to an all-zero uint8 array with one entry per
        sample."""
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline.

        Args:
            results (dict): Sample dict; ``img_prefix`` is written into it
                in place.
        """
        results['img_prefix'] = self.img_prefix

    def prepare_train_img(self, index):
        """Get training data and annotations from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_info = self.data_infos[index]
        results = dict(img_info=img_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, img_info):
        """Get testing data from pipeline.

        Args:
            img_info (int): Index of data. Despite its name, this is the
                integer index that ``__getitem__`` passes in; it is
                forwarded unchanged to ``prepare_train_img``.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
                pipeline.
        """
        return self.prepare_train_img(img_info)

    def _log_error_index(self, index):
        """Logging data info of bad index.

        Args:
            index (int): Index of the sample that failed to be prepared.
        """
        try:
            data_info = self.data_infos[index]
            img_prefix = self.img_prefix
            print_log(f'Warning: skip broken file {data_info} '
                      f'with img_prefix {img_prefix}')
        except Exception as e:
            # Looking up the data info can itself fail (e.g. the index is
            # out of range); logging must never mask the original failure.
            print_log(f'load index {index} with error {e}')

    def _get_next_index(self, index):
        """Get next index from dataset.

        Logs the failing ``index`` and returns the following index,
        wrapping around at the end of the dataset.

        Args:
            index (int): Index of the sample that failed.

        Returns:
            int: The next index to try.
        """
        self._log_error_index(index)
        index = (index + 1) % len(self)
        return index

    def __getitem__(self, index):
        """Get training/test data from pipeline.

        In test mode the sample is prepared directly and any error
        propagates. In training mode a sample that raises, or for which the
        pipeline returns ``None``, is logged and skipped in favour of the
        next index, for at most ``max_refetch`` extra attempts.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training/test data.

        Raises:
            RuntimeError: In training mode, if no sample could be prepared
                within ``max_refetch + 1`` attempts.
        """
        if self.test_mode:
            return self.prepare_test_img(index)

        last_error = None
        # Bounded retry: an unbounded loop hangs forever when every sample
        # fails (for instance because of a wrong img_prefix).
        for _ in range(self.max_refetch + 1):
            try:
                data = self.prepare_train_img(index)
                if data is None:
                    raise Exception('prepared train data empty')
                return data
            except Exception as e:
                # Deliberately broad: any broken sample is skipped on a
                # best-effort basis rather than aborting training.
                print_log(f'prepare index {index} with error {e}')
                last_error = e
                index = self._get_next_index(index)

        raise RuntimeError(
            f'Cannot prepare a valid training sample after '
            f'{self.max_refetch + 1} attempts. Please check img_prefix '
            f'{self.img_prefix!r}, ann_file {self.ann_file!r} and the '
            f'pipeline.') from last_error

    def format_results(self, results, **kwargs):
        """Placeholder to format result to dataset-specific output."""

    def evaluate(self, results, metric=None, logger=None, **kwargs):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str: float]

        Raises:
            NotImplementedError: Always; subclasses must override this
                method.
        """
        raise NotImplementedError