Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import numpy as np | |
from mmdet.datasets.builder import DATASETS | |
from mmocr.core.evaluation.hmean import eval_hmean | |
from mmocr.datasets.base_dataset import BaseDataset | |
class TextDetDataset(BaseDataset):
    """Dataset that feeds per-image text annotations to a pipeline.

    Each entry of ``self.data_infos`` is read as a mapping with the keys
    ``file_name``, ``height``, ``width`` and ``annotations``. Results are
    scored through ``eval_hmean``.
    """

    # NOTE(review): ``DATASETS`` is imported by this file but is not used in
    # the visible code - confirm whether a registry decorator is expected on
    # this class.

    def _parse_anno_info(self, annotations):
        """Split the annotations of one image into kept and ignored targets.

        Args:
            annotations (iterable of dict): Annotations of one image. Every
                item must provide ``bbox``; items that are not flagged with
                a truthy ``iscrowd`` must also provide ``category_id``.
                ``segmentation`` is optional and defaults to ``None``.

        Returns:
            dict: A dict with the keys ``bboxes``, ``labels``,
                ``bboxes_ignore``, ``masks_ignore`` and ``masks``.
                ``bboxes`` and ``bboxes_ignore`` are float32 arrays,
                ``labels`` is an int64 array, and ``masks`` /
                ``masks_ignore`` are lists holding the raw
                ``segmentation`` entries.
        """
        boxes, crowd_boxes = [], []
        polygons, crowd_polygons = [], []
        category_ids = []

        for item in annotations:
            is_crowd = item.get('iscrowd', False)
            box = item['bbox']
            if is_crowd:
                # Crowd regions are only collected into the "ignore" lists.
                crowd_boxes.append(box)
                crowd_polygons.append(item.get('segmentation', None))
                continue
            boxes.append(box)
            category_ids.append(item['category_id'])
            polygons.append(item.get('segmentation', None))

        # Empty inputs still yield arrays with a fixed (0, 4) box shape.
        bboxes = (
            np.array(boxes, dtype=np.float32)
            if boxes else np.zeros((0, 4), dtype=np.float32))
        labels = (
            np.array(category_ids, dtype=np.int64)
            if boxes else np.array([], dtype=np.int64))
        bboxes_ignore = (
            np.array(crowd_boxes, dtype=np.float32)
            if crowd_boxes else np.zeros((0, 4), dtype=np.float32))

        return {
            'bboxes': bboxes,
            'labels': labels,
            'bboxes_ignore': bboxes_ignore,
            'masks_ignore': crowd_polygons,
            'masks': polygons,
        }

    def prepare_train_img(self, index):
        """Build the training sample at ``index`` and run it through the
        pipeline.

        Args:
            index (int): Index into ``self.data_infos``.

        Returns:
            The output of ``self.pipeline`` applied to the prepared results
            dict (image info, parsed annotations and empty ``bbox_fields``,
            ``mask_fields`` and ``seg_fields`` lists).
        """
        record = self.data_infos[index]
        img_info = dict(
            filename=record['file_name'],
            height=record['height'],
            width=record['width'])
        results = {
            'img_info': img_info,
            'ann_info': self._parse_anno_info(record['annotations']),
            'bbox_fields': [],
            'mask_fields': [],
            'seg_fields': [],
        }
        self.pre_pipeline(results)
        return self.pipeline(results)

    def evaluate(self,
                 results,
                 metric='hmean-iou',
                 score_thr=0.3,
                 rank_list=None,
                 logger=None,
                 **kwargs):
        """Evaluate detection results against the dataset annotations.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Only
                'hmean-iou' and 'hmean-ic13' are kept; any other entry is
                dropped without an error.
            score_thr (float): Score threshold, forwarded to ``eval_hmean``.
            rank_list (str): Forwarded to ``eval_hmean``; described by the
                original documentation as the json file used to save the
                eval result of each image after ranking.
            logger (logging.Logger | str | None): Logger forwarded to
                ``eval_hmean``. Default: None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The value returned by ``eval_hmean``.
        """
        requested = metric if isinstance(metric, list) else [metric]
        supported = {'hmean-iou', 'hmean-ic13'}
        selected = set(requested) & supported

        img_infos, ann_infos = [], []
        for idx in range(len(self)):
            record = self.data_infos[idx]
            img_infos.append({'filename': record['file_name']})
            ann_infos.append(self._parse_anno_info(record['annotations']))

        return eval_hmean(
            results,
            img_infos,
            ann_infos,
            metrics=selected,
            score_thr=score_thr,
            logger=logger,
            rank_list=rank_list)