|
|
|
import unittest |
|
|
|
from mmengine.dataset import ConcatDataset |
|
|
|
from mmyolo.datasets import YOLOv5VOCDataset |
|
from mmyolo.utils import register_all_modules |
|
|
|
# Called once at import time, before any test runs. The tests below build
# objects from config dicts by string ``type`` name (e.g.
# ``'YOLOv5VOCDataset'`` inside ``ConcatDataset`` and ``'BatchShapePolicy'``),
# which presumably requires mmyolo's modules to be registered first.
register_all_modules()
|
|
|
|
|
class TestYOLOv5VocDataset(unittest.TestCase):
    """Unit tests for ``YOLOv5VOCDataset`` built from the VOC test fixture.

    Every dataset here is built from ``tests/data/VOCdevkit/`` with an empty
    ``pipeline``. Each test checks the dataset length before iterating, so an
    empty dataset fails loudly instead of silently passing the per-item checks.
    """

    # Root of the VOC fixture shared by every dataset built in this test case.
    DATA_ROOT = 'tests/data/VOCdevkit/'

    @classmethod
    def _voc_kwargs(cls, year):
        """Return the path-related dataset kwargs for one VOC release.

        Args:
            year (str): VOC release suffix, e.g. ``'2007'`` or ``'2012'``.

        Returns:
            dict: ``data_root``, ``ann_file`` and ``data_prefix`` entries
            pointing at ``VOC{year}`` under :attr:`DATA_ROOT`.
        """
        return dict(
            data_root=cls.DATA_ROOT,
            ann_file=f'VOC{year}/ImageSets/Main/trainval.txt',
            data_prefix=dict(sub_data_root=f'VOC{year}/'))

    def test_batch_shapes_cfg(self):
        """Each item carries the batch shape computed by BatchShapePolicy."""
        batch_shapes_cfg = dict(
            type='BatchShapePolicy',
            batch_size=2,
            img_size=640,
            size_divisor=32,
            extra_pad_ratio=0.5)

        dataset = YOLOv5VOCDataset(
            **self._voc_kwargs('2007'),
            test_mode=True,
            pipeline=[],
            batch_shapes_cfg=batch_shapes_cfg,
        )

        expected_img_ids = ['000001']
        expected_batch_shapes = [[672, 480]]

        # Check the length up front. Without it, an empty dataset would skip
        # the loop and pass without asserting anything, and an extra item
        # would surface as an IndexError instead of a test failure.
        self.assertEqual(len(dataset), len(expected_img_ids))
        for data, img_id, batch_shape in zip(dataset, expected_img_ids,
                                             expected_batch_shapes):
            self.assertEqual(data['img_id'], img_id)
            self.assertEqual(data['batch_shape'].tolist(), batch_shape)

    def test_prepare_data(self):
        """Non-test-mode items contain 'dataset'; test-mode items do not."""
        # test_mode is not passed here, so the dataset's default applies
        # (presumably training mode). Filtering is disabled so that the
        # fixture image is kept.
        dataset = YOLOv5VOCDataset(
            **self._voc_kwargs('2007'),
            filter_cfg=dict(filter_empty_gt=False, min_size=0),
            pipeline=[],
            serialize_data=True,
            batch_shapes_cfg=None,
        )
        # Guard against a vacuous pass on an empty dataset.
        self.assertGreater(len(dataset), 0)
        for data in dataset:
            self.assertIn('dataset', data)

        dataset = YOLOv5VOCDataset(
            **self._voc_kwargs('2007'),
            filter_cfg=dict(
                filter_empty_gt=True, min_size=32, bbox_min_size=None),
            pipeline=[],
            test_mode=True,
            batch_shapes_cfg=None)
        self.assertGreater(len(dataset), 0)
        for data in dataset:
            self.assertNotIn('dataset', data)

    def test_concat_dataset(self):
        """Concatenating the VOC2007 and VOC2012 fixtures yields two items."""
        dataset = ConcatDataset(
            datasets=[
                dict(
                    type='YOLOv5VOCDataset',
                    **self._voc_kwargs(year),
                    filter_cfg=dict(filter_empty_gt=False, min_size=32),
                    pipeline=[]) for year in ('2007', '2012')
            ],
            # Presumably lets the sub-datasets differ in their
            # 'dataset_type' metainfo without failing ConcatDataset's
            # consistency check -- confirm against the mmengine docs.
            ignore_keys='dataset_type')

        dataset.full_init()
        self.assertEqual(len(dataset), 2)
|
|