# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset

import mmocr.utils as utils
from mmocr import digit_version
from mmocr.core.evaluation.hmean import eval_hmean


@DATASETS.register_module()
class IcdarDataset(CocoDataset):
"""Dataset for text detection while ann_file in coco format.
Args:
ann_file_backend (str): Storage backend for annotation file,
should be one in ['disk', 'petrel', 'http']. Default to 'disk'.
"""
CLASSES = ('text')
    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True,
                 select_first_k=-1,
                 ann_file_backend='disk'):
        # Select only the first k images for fast debugging;
        # -1 (the default) keeps the whole dataset.
        self.select_first_k = select_first_k
        assert ann_file_backend in ['disk', 'petrel', 'http']
        self.ann_file_backend = ann_file_backend

        super().__init__(ann_file, pipeline, classes, data_root, img_prefix,
                         seg_prefix, proposal_file, test_mode, filter_empty_gt)
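
    # Typical config-style instantiation via mmdet's registry (an
    # illustrative sketch; the paths and pipeline name are hypothetical
    # placeholders):
    #   dict(
    #       type='IcdarDataset',
    #       ann_file='data/icdar2015/instances_training.json',
    #       img_prefix='data/icdar2015/imgs',
    #       pipeline=train_pipeline)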

    def load_annotations(self, ann_file):
        """Load annotations from a COCO style annotation file.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: Annotation info from the COCO api.
        """
        if self.ann_file_backend == 'disk':
            self.coco = COCO(ann_file)
        else:
            mmcv_version = digit_version(mmcv.__version__)
            if mmcv_version < digit_version('1.3.16'):
                raise Exception('Please update mmcv to 1.3.16 or higher '
                                'to enable "get_local_path" of "FileClient".')
            # Fetch the remote annotation file to a local temporary path
            # first, since pycocotools' COCO can only read from disk.
            file_client = mmcv.FileClient(backend=self.ann_file_backend)
            with file_client.get_local_path(ann_file) as local_path:
                self.coco = COCO(local_path)
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []

        count = 0
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
            count += 1
            # Stop early once ``select_first_k`` images have been collected.
            if 0 < self.select_first_k <= count:
                break
        return data_infos
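
    # Each ``data_infos`` entry is the raw COCO image record with an added
    # ``filename`` alias; illustrative (hypothetical) values:
    #   {'id': 1, 'file_name': 'img_1.jpg', 'filename': 'img_1.jpg',
    #    'height': 720, 'width': 1280}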

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotations.

        Args:
            img_info (dict): Image info of an image.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, masks_ignore, seg_map. "masks" and
                "masks_ignore" are represented by polygon boundary
                point sequences.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ignore = []
        gt_masks_ann = []

        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            # Convert from COCO's xywh format to xyxy.
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
                gt_masks_ignore.append(ann.get('segmentation', None))
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        # Convert to float32 arrays for later processing.
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks_ignore=gt_masks_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)

        return ann
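
    # Sketch of the dict returned above (shapes follow from the code):
    #   bboxes:        float32 ndarray of shape (n, 4), xyxy order
    #   labels:        int64 ndarray of shape (n, )
    #   bboxes_ignore: float32 ndarray of shape (m, 4)
    #   masks, masks_ignore: lists of COCO polygon segmentations
    #   seg_map:       str, e.g. 'img_1.png'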

    def evaluate(self,
                 results,
                 metric='hmean-iou',
                 logger=None,
                 score_thr=0.3,
                 rank_list=None,
                 **kwargs):
        """Evaluate the hmean metric.

        Args:
            results (list[dict]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            score_thr (float): Score threshold for predictions.
                Default: 0.3.
            rank_list (str): Path of a json file used to save the evaluation
                result of each image after ranking.

        Returns:
            dict[str: float]: The evaluation results.
        """
        assert utils.is_type_list(results, dict)
        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['hmean-iou', 'hmean-ic13']
        # Silently drop any metric that is not supported.
        metrics = set(metrics) & set(allowed_metrics)

        img_infos = []
        ann_infos = []
        for i in range(len(self)):
            img_info = {'filename': self.data_infos[i]['file_name']}
            img_infos.append(img_info)
            ann_infos.append(self.get_ann_info(i))

        eval_results = eval_hmean(
            results,
            img_infos,
            ann_infos,
            metrics=metrics,
            score_thr=score_thr,
            logger=logger,
            rank_list=rank_list)

        return eval_results
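

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the library). It builds the
# dataset directly and inspects one parsed annotation. The data paths below
# are hypothetical placeholders; point them at a real ICDAR-style COCO
# annotation file to run this.
if __name__ == '__main__':
    dataset = IcdarDataset(
        ann_file='data/icdar2015/instances_training.json',  # hypothetical
        pipeline=[],
        img_prefix='data/icdar2015/imgs',  # hypothetical
        test_mode=True)
    ann = dataset.get_ann_info(0)
    # bboxes come back as a float32 (n, 4) array in xyxy order.
    print(ann['bboxes'].shape)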