"""Extracting subsets from coco2017 dataset. |
|
|
|
This script is mainly used to debug and verify the correctness of the |
|
program quickly. |
|
The root folder format must be in the following format: |
|
|
|
βββ root |
|
β βββ annotations |
|
β βββ train2017 |
|
β βββ val2017 |
|
β βββ test2017 |
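
After extraction, the subset is written to the output directory with a
mirrored layout (these folders are what _make_dirs below creates):

├── out_dir
│   ├── annotations
│   ├── train2017
│   ├── val2017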

Currently, only COCO2017 is supported. User-defined datasets in the
standard COCO JSON format will be supported in the future.

Example:
    python tools/misc/extract_subcoco.py ${ROOT} ${OUT_DIR} --num-img ${NUM_IMG}
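
    A sketch of a run that also filters by class name and box area, using
    the --classes and --area-size flags defined in parse_args below (the
    class names and size choice here are only illustrative):

    python tools/misc/extract_subcoco.py ${ROOT} ${OUT_DIR} \
        --num-img 20 --classes person dog --area-size small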
"""

import argparse
import os.path as osp
import shutil

import mmengine
import numpy as np
from pycocotools.coco import COCO


def _process_data(args,
                  in_dataset_type: str,
                  out_dataset_type: str,
                  year: str = '2017'):
    assert in_dataset_type in ('train', 'val')
    assert out_dataset_type in ('train', 'val')

    in_ann_file_name = f'annotations/instances_{in_dataset_type}{year}.json'
    out_ann_file_name = f'annotations/instances_{out_dataset_type}{year}.json'

    ann_path = osp.join(args.root, in_ann_file_name)
    json_data = mmengine.load(ann_path)

    # Reuse the global metadata; 'images' and 'annotations' are filled below
    new_json_data = {
        'info': json_data['info'],
        'licenses': json_data['licenses'],
        'categories': json_data['categories'],
        'images': [],
        'annotations': []
    }

    # Area ranges (in squared pixels) following COCO's small/medium/large split
    area_dict = {
        'small': [0., 32 * 32],
        'medium': [32 * 32, 96 * 96],
        'large': [96 * 96, float('inf')]
    }

    coco = COCO(ann_path)

    # An empty list means "no filtering" in the pycocotools API
    areaRng = area_dict[args.area_size] if args.area_size else []
    catIds = coco.getCatIds(catNms=args.classes) if args.classes else []
    ann_ids = coco.getAnnIds(catIds=catIds, areaRng=areaRng)
    ann_info = coco.loadAnns(ann_ids)

    # Keep only images that contain at least one matching annotation
    filter_img_ids = {ann['image_id'] for ann in ann_info}
    filter_img = coco.loadImgs(list(filter_img_ids))

    # Shuffle so the subset is a random sample rather than the first images
    np.random.shuffle(filter_img)

    num_img = args.num_img if args.num_img > 0 else len(filter_img)
    if num_img > len(filter_img):
        print(f'num_img is too big and will be set to {len(filter_img)}, '
              'because there are not enough images left after filtering '
              'by classes and area_size')
        num_img = len(filter_img)

    progress_bar = mmengine.ProgressBar(num_img)

    for i in range(num_img):
        file_name = filter_img[i]['file_name']
        image_path = osp.join(args.root, in_dataset_type + year, file_name)

        # Keep only this image's annotations that pass the same filters
        ann_ids = coco.getAnnIds(
            imgIds=[filter_img[i]['id']], catIds=catIds, areaRng=areaRng)
        img_ann_info = coco.loadAnns(ann_ids)

        new_json_data['images'].append(filter_img[i])
        new_json_data['annotations'].extend(img_ann_info)

        shutil.copy(image_path,
                    osp.join(args.out_dir, out_dataset_type + year))

        progress_bar.update()

    mmengine.dump(new_json_data, osp.join(args.out_dir, out_ann_file_name))


def _make_dirs(out_dir):
    mmengine.mkdir_or_exist(out_dir)
    mmengine.mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mmengine.mkdir_or_exist(osp.join(out_dir, 'train2017'))
    mmengine.mkdir_or_exist(osp.join(out_dir, 'val2017'))


def parse_args():
    parser = argparse.ArgumentParser(description='Extract coco subset')
    parser.add_argument('root', help='root path')
    parser.add_argument(
        'out_dir', type=str, help='directory where the coco subset is saved')
    parser.add_argument(
        '--num-img',
        default=50,
        type=int,
        help='number of images to extract, -1 means all images')
    parser.add_argument(
        '--area-size',
        choices=['small', 'medium', 'large'],
        help='filter ground-truth info by area size')
    parser.add_argument(
        '--classes', nargs='+', help='filter ground-truth by class name')
    parser.add_argument(
        '--use-training-set',
        action='store_true',
        help='Whether to extract the training subset from the training set. '
        'By default it is extracted from the validation set, which is '
        'faster.')
    parser.add_argument('--seed', default=-1, type=int, help='seed')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    assert args.out_dir != args.root, \
        'The file will be overwritten in place, ' \
        'so the same folder is not allowed!'

    seed = int(args.seed)
    if seed != -1:
        print(f'Set the global seed: {seed}')
        np.random.seed(seed)

    _make_dirs(args.out_dir)

    print('====Start processing train dataset====')
    if args.use_training_set:
        _process_data(args, 'train', 'train')
    else:
        _process_data(args, 'val', 'train')
    print('\n====Start processing val dataset====')
    _process_data(args, 'val', 'val')
    print(f'\nResults saved to {args.out_dir}')


if __name__ == '__main__':
    main()