|
from collections import OrderedDict |
|
|
|
from mmcv.utils import print_log |
|
|
|
from mmdet.core import eval_map, eval_recalls |
|
from .builder import DATASETS |
|
from .xml_style import XMLDataset |
|
|
|
|
|
@DATASETS.register_module()
class VOCDataset(XMLDataset):
    """PASCAL VOC detection dataset (VOC2007 / VOC2012).

    Annotation loading is inherited from ``XMLDataset``. This class supplies
    the VOC class names, infers the dataset year from ``img_prefix`` and
    implements VOC-style evaluation (mAP or proposal recall).
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        """Build the dataset and infer ``self.year`` from ``img_prefix``.

        Raises:
            ValueError: If ``img_prefix`` contains neither 'VOC2007' nor
                'VOC2012'.
        """
        super(VOCDataset, self).__init__(**kwargs)
        # The year decides what ``evaluate`` passes as ``dataset`` to
        # ``eval_map``: the string 'voc07' for 2007, the class names otherwise.
        if 'VOC2007' in self.img_prefix:
            self.year = 2007
        elif 'VOC2012' in self.img_prefix:
            self.year = 2012
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate in VOC protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'mAP', 'recall'. A list must contain exactly one metric.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | Sequence[float]): IoU threshold(s). A single
                float is treated as a one-element list. Default: 0.5.
            scale_ranges (list[tuple], optional): Accepted for interface
                compatibility but currently NOT forwarded to ``eval_map``:
                mAP is always computed over all bounding boxes.
                Default: None.

        Returns:
            dict[str, float]: AP/recall metrics.
                - 'mAP': one ``AP{xx}`` entry per IoU threshold (rounded to
                  3 decimals) plus ``mAP``, the unrounded mean over the
                  thresholds.
                - 'recall': ``recall@{num}@{iou}`` for every proposal number
                  and IoU threshold, plus ``AR@{num}`` (mean over thresholds)
                  when more than one threshold is evaluated.

        Raises:
            KeyError: If ``metric`` is not 'mAP' or 'recall'.
        """
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()
        # Normalise to a list so both branches can iterate thresholds; list()
        # also accepts tuples and other sequences, not only lists.
        iou_thrs = [iou_thr] if isinstance(iou_thr, float) else list(iou_thr)
        if metric == 'mAP':
            assert isinstance(iou_thrs, list)
            if self.year == 2007:
                ds_name = 'voc07'
            else:
                ds_name = self.CLASSES
            mean_aps = []
            for thr in iou_thrs:
                print_log(f'\n{"-" * 15}iou_thr: {thr}{"-" * 15}',
                          logger=logger)
                # NOTE: scale_ranges is deliberately not forwarded; see the
                # docstring. Enabling it would change the shape of mean_ap.
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=None,
                    iou_thr=thr,
                    dataset=ds_name,
                    logger=logger)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(thr * 100):02d}'] = round(mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            recalls = eval_recalls(
                gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
            # recalls is indexed [proposal_num_idx, iou_thr_idx]; iterate the
            # normalised list, since a bare float iou_thr is not iterable.
            for i, num in enumerate(proposal_nums):
                for j, iou in enumerate(iou_thrs):
                    eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
            if recalls.shape[1] > 1:
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results
|
|