import os.path as osp
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable
from torch.utils.data import Dataset

from mmdet.core import eval_map, eval_recalls
from .builder import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class CustomDataset(Dataset):
"""Custom dataset for detection. |
|
|
|
The annotation format is shown as follows. The `ann` field is optional for |
|
testing. |
|
|
|
.. code-block:: none |
|
|
|
[ |
|
{ |
|
'filename': 'a.jpg', |
|
'width': 1280, |
|
'height': 720, |
|
'ann': { |
|
'bboxes': <np.ndarray> (n, 4) in (x1, y1, x2, y2) order. |
|
'labels': <np.ndarray> (n, ), |
|
'bboxes_ignore': <np.ndarray> (k, 4), (optional field) |
|
'labels_ignore': <np.ndarray> (k, 4) (optional field) |
|
} |
|
}, |
|
... |
|
] |
|
|
|
Args: |
|
ann_file (str): Annotation file path. |
|
pipeline (list[dict]): Processing pipeline. |
|
classes (str | Sequence[str], optional): Specify classes to load. |
|
If is None, ``cls.CLASSES`` will be used. Default: None. |
|
data_root (str, optional): Data root for ``ann_file``, |
|
``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified. |
|
test_mode (bool, optional): If set True, annotation will not be loaded. |
|
filter_empty_gt (bool, optional): If set true, images without bounding |
|
boxes of the dataset's classes will be filtered out. This option |
|
only works when `test_mode=False`, i.e., we never filter images |
|
during tests. |
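
    Example:
        A minimal, illustrative instantiation sketch; the annotation file
        name and the pipeline step below are hypothetical placeholders::

            dataset = CustomDataset(
                ann_file='annotations.pkl',
                pipeline=[dict(type='LoadImageFromFile')])
            num_images = len(dataset)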
    """

    CLASSES = None

    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = self.get_classes(classes)

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.ann_file):
                self.ann_file = osp.join(self.data_root, self.ann_file)
            if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
                self.img_prefix = osp.join(self.data_root, self.img_prefix)
            if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
                self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
            if not (self.proposal_file is None
                    or osp.isabs(self.proposal_file)):
                self.proposal_file = osp.join(self.data_root,
                                              self.proposal_file)

        # load annotations (and proposals)
        self.data_infos = self.load_annotations(self.ann_file)

        if self.proposal_file is not None:
            self.proposals = self.load_proposals(self.proposal_file)
        else:
            self.proposals = None

        # filter images that are too small
        if not test_mode:
            valid_inds = self._filter_imgs()
            self.data_infos = [self.data_infos[i] for i in valid_inds]
            if self.proposals is not None:
                self.proposals = [self.proposals[i] for i in valid_inds]
            # set group flag for the sampler
            self._set_group_flag()

        # processing pipeline
        self.pipeline = Compose(pipeline)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.data_infos)

    def load_annotations(self, ann_file):
        """Load annotations from an annotation file."""
        return mmcv.load(ann_file)

    def load_proposals(self, proposal_file):
        """Load proposals from a proposal file."""
        return mmcv.load(proposal_file)

    def get_ann_info(self, idx):
        """Get annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of the specified index.
        """
        return self.data_infos[idx]['ann']

    def get_cat_ids(self, idx):
        """Get category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of the specified index.
        """
        return self.data_infos[idx]['ann']['labels'].astype(
            np.int64).tolist()

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline."""
        results['img_prefix'] = self.img_prefix
        results['seg_prefix'] = self.seg_prefix
        results['proposal_file'] = self.proposal_file
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []

    def _filter_imgs(self, min_size=32):
        """Filter images that are too small."""
        if self.filter_empty_gt:
            warnings.warn(
                'CustomDataset does not support filtering empty gt images.')
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0. Group-aware samplers use this flag to batch
        images with similar aspect ratios together.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i in range(len(self)):
            img_info = self.data_infos[i]
            if img_info['width'] / img_info['height'] > 1:
                self.flag[i] = 1

    def _rand_another(self, idx):
        """Get another random index from the same group as the given index."""
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Get training/test data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training/test data (with annotation if `test_mode` is set \
                False).
        """
        if self.test_mode:
            return self.prepare_test_img(idx)
        while True:
            data = self.prepare_train_img(idx)
            if data is None:
                # the pipeline may return None (e.g. no valid gt is left
                # after augmentation), so resample from the same group
                idx = self._rand_another(idx)
                continue
            return data

    def prepare_train_img(self, idx):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys \
                introduced by pipeline.
        """
        img_info = self.data_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Get testing data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Testing data after pipeline with new keys introduced by \
                pipeline.
        """
        img_info = self.data_infos[idx]
        results = dict(img_info=img_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of the current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                the default CLASSES defined by the builtin dataset. If
                classes is a string, take it as a file name; the file
                contains the names of classes, one class name per line.
                If classes is a tuple or list, override the CLASSES
                defined by the dataset.

        Returns:
            tuple[str] or list[str]: Names of categories of the dataset.
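
        Example:
            An illustrative sketch of the three accepted inputs; the file
            name is a hypothetical placeholder::

                CustomDataset.get_classes()                # -> cls.CLASSES
                CustomDataset.get_classes(['cat', 'dog'])  # -> ['cat', 'dog']
                CustomDataset.get_classes('classes.txt')   # one name per line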
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

    def format_results(self, results, **kwargs):
        """Placeholder to format results to dataset-specific output."""

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal numbers used for
                evaluating recalls, such as recall@100 and recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | list[float]): IoU threshold. Default: 0.5.
            scale_ranges (list[tuple] | None): Scale ranges for evaluating
                mAP. Default: None.

        Returns:
            OrderedDict: Evaluation results, e.g. ``AP50`` and ``mAP`` for
                the mAP metric, or ``recall@num@iou`` and ``AR@num`` for
                the recall metric.
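
        Example:
            A minimal, illustrative call; ``model_results`` stands for a
            hypothetical list of per-image detection results::

                metrics = dataset.evaluate(
                    model_results, metric='mAP', iou_thr=[0.5, 0.75])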
        """

        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()
        iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
        if metric == 'mAP':
            assert isinstance(iou_thrs, list)
            mean_aps = []
            for iou_thr in iou_thrs:
                print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}')
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=scale_ranges,
                    iou_thr=iou_thr,
                    dataset=self.CLASSES,
                    logger=logger)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(iou_thr * 100):02d}'] = round(
                    mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            recalls = eval_recalls(
                gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
            for i, num in enumerate(proposal_nums):
                for j, iou in enumerate(iou_thrs):
                    eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
            if recalls.shape[1] > 1:
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results

    def __repr__(self):
        """Print the number of instances per category."""
        dataset_type = 'Test' if self.test_mode else 'Train'
        result = (f'\n{self.__class__.__name__} {dataset_type} dataset '
                  f'with number of images {len(self)}, '
                  f'and instance counts: \n')
        if self.CLASSES is None:
            result += 'Category names are not provided. \n'
            return result
        instance_count = np.zeros(len(self.CLASSES) + 1).astype(int)
        # count instances of each category over the whole dataset
        for idx in range(len(self)):
            label = self.get_ann_info(idx)['labels']
            unique, counts = np.unique(label, return_counts=True)
            if len(unique) > 0:
                # add the occurrence number to each class
                instance_count[unique] += counts
            else:
                # images without gt boxes count towards the background,
                # which is stored at the last index
                instance_count[-1] += 1

        # lay the counts out with 5 (category, count) column pairs per row
        table_data = [['category', 'count'] * 5]
        row_data = []
        for cls, count in enumerate(instance_count):
            if cls < len(self.CLASSES):
                row_data += [f'{cls} [{self.CLASSES[cls]}]', f'{count}']
            else:
                # the last entry is the background count
                row_data += ['-1 background', f'{count}']
            if len(row_data) == 10:
                table_data.append(row_data)
                row_data = []
        # flush the last, possibly partial row so no category is dropped
        if row_data:
            table_data.append(row_data)

        table = AsciiTable(table_data)
        result += table.table
        return result