# MMOCR / tools / data / common / curvedsyntext_converter.py
# (file-viewer page residue: tomofi's picture; "Add application file";
# commit 2366e36; raw; history blame; 4.38 kB)
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
from mmocr.utils import bezier_to_polygon, sort_points
# The default dictionary used by CurvedSynthText: the 95 printable ASCII
# characters from ' ' (0x20) through '~' (0x7E), in code-point order, so that
# label index i maps to dict95[i].
dict95 = [chr(code_point) for code_point in range(ord(' '), ord('~') + 1)]
# The two indices just past the dictionary are reserved labels: UNK for an
# out-of-dictionary character and EOS for the end of a label sequence.
UNK = len(dict95)
EOS = UNK + 1
def digit2text(rec):
    """Decode a sequence of CurvedSynthText label indices into a string.

    Args:
        rec (Iterable[int]): Label indices. Values below ``UNK`` index into
            ``dict95``; ``UNK`` stands for an out-of-dictionary character
            and ``EOS`` terminates the sequence (anything after it is
            ignored).

    Returns:
        str: The decoded text. Every ``UNK`` is rendered as '口'.

    Raises:
        AssertionError: If an index is greater than ``EOS``.
    """
    res = []
    for d in rec:
        assert d <= EOS
        if d == EOS:
            break
        if d == UNK:
            print('Warning: Has a UNK character')
            # Or any special character not in the target dict
            res.append('口')
        else:
            # UNK equals len(dict95), so the dictionary lookup must be
            # skipped for it; falling through would raise IndexError.
            res.append(dict95[d])
    return ''.join(res)
def modify_annotation(ann, num_sample, start_img_id=0, start_ann_id=0):
    """Rewrite one annotation dict in place and return it.

    Args:
        ann (dict): Annotation with at least the keys 'rec', 'bezier_pts',
            'image_id' and 'id'. 'rec' is removed and replaced by 'text'.
        num_sample (int): Forwarded to ``bezier_to_polygon`` as its
            ``num_sample`` argument.
        start_img_id (int): Offset added to ``ann['image_id']``.
        start_ann_id (int): Offset added to ``ann['id']``.

    Returns:
        dict: The same ``ann`` object, now carrying 'text' and
        'segmentation'.
    """
    label_indices = ann.pop('rec')
    ann['text'] = digit2text(label_indices)

    # Get the segmentation points from ann['bezier_pts']
    sampled_pts = bezier_to_polygon(ann['bezier_pts'], num_sample=num_sample)
    ordered_pts = sort_points(sampled_pts)
    # Flatten everything into a single row: a list holding one flat list
    flat_row = np.asarray(ordered_pts).reshape(1, -1)
    ann['segmentation'] = flat_row.tolist()

    ann['image_id'] += start_img_id
    ann['id'] += start_ann_id
    return ann
def modify_image_info(image_info, path_prefix, start_img_id=0):
    """Prefix the file name and shift the id of one image record in place.

    Args:
        image_info (dict): Image record with the keys 'file_name' and 'id'.
        path_prefix (str): Path joined in front of ``file_name``.
        start_img_id (int): Offset added to ``image_info['id']``.

    Returns:
        dict: The same ``image_info`` object after modification.
    """
    prefixed_name = osp.join(path_prefix, image_info['file_name'])
    image_info['file_name'] = prefixed_name
    image_info['id'] += start_img_id
    return image_info
def parse_args():
    """Parse the command-line arguments of the converter.

    Returns:
        argparse.Namespace: Namespace with the attributes ``root_path``,
        ``out_dir``, ``num_sample`` (default 4) and ``nproc`` (default 1).
    """
    cli = argparse.ArgumentParser(
        description='Convert CurvedSynText150k to COCO format')
    cli.add_argument('root_path', help='CurvedSynText150k root path')
    cli.add_argument('-o', '--out-dir', help='Output path')
    cli.add_argument(
        '-n', '--num-sample',
        default=4,
        type=int,
        help='Number of sample points at each Bezier curve.')
    cli.add_argument(
        '--nproc', type=int, default=1, help='Number of processes')
    return cli.parse_args()
def convert_annotations(data,
                        path_prefix,
                        num_sample,
                        nproc,
                        start_img_id=0,
                        start_ann_id=0):
    """Convert every annotation and image record of one loaded json dict.

    Args:
        data (dict): Loaded annotation file with the keys 'annotations' and
            'images'. It is updated in place.
        path_prefix (str): Prefix handed to ``modify_image_info``.
        num_sample (int): ``num_sample`` handed to ``modify_annotation``.
        nproc (int): Number of processes; values above 1 select
            ``mmcv.track_parallel_progress``, otherwise
            ``mmcv.track_progress`` is used.
        start_img_id (int): Offset applied to image ids.
        start_ann_id (int): Offset applied to annotation ids.

    Returns:
        dict: ``data`` with converted 'annotations' and 'images' and with
        'categories' set to the single category 'text'.
    """
    fix_image = partial(
        modify_image_info, path_prefix=path_prefix, start_img_id=start_img_id)
    fix_annotation = partial(
        modify_annotation,
        num_sample=num_sample,
        start_img_id=start_img_id,
        start_ann_id=start_ann_id)

    # Choose the progress runner once; both take (func, tasks).
    if nproc > 1:
        run = partial(mmcv.track_parallel_progress, nproc=nproc)
    else:
        run = mmcv.track_progress

    # Annotations are processed before images, as in either original branch.
    data['annotations'] = run(fix_annotation, data['annotations'])
    data['images'] = run(fix_image, data['images'])
    data['categories'] = [{'id': 1, 'name': 'text'}]
    return data
def main():
    """Convert both CurvedSynText150k parts into one COCO-style file.

    Loads ``train1.json`` and ``train2.json`` from the root path, converts
    each of them, shifts the image and annotation ids of the second part
    past the largest ids of the first part, and dumps the merged result to
    ``instances_training.json`` in the output directory (the root path when
    no output directory is given).
    """
    args = parse_args()
    root_path = args.root_path
    out_dir = args.out_dir if args.out_dir else root_path
    mmcv.mkdir_or_exist(out_dir)

    anns = mmcv.load(osp.join(root_path, 'train1.json'))
    data1 = convert_annotations(anns, 'syntext_word_eng', args.num_sample,
                                args.nproc)

    # Continue numbering right after the largest ids of the first part.
    # ``default=-1`` yields an offset of 0 when the first part has no images
    # or annotations, instead of max() failing on an empty sequence.
    start_img_id = max(
        (img['id'] for img in data1['images']), default=-1) + 1
    start_ann_id = max(
        (ann['id'] for ann in data1['annotations']), default=-1) + 1

    anns = mmcv.load(osp.join(root_path, 'train2.json'))
    data2 = convert_annotations(
        anns,
        'emcs_imgs',
        args.num_sample,
        args.nproc,
        start_img_id=start_img_id,
        start_ann_id=start_ann_id)

    # Merge the second part into the first; data1 keeps its 'categories'.
    data1['images'] += data2['images']
    data1['annotations'] += data2['annotations']
    mmcv.dump(data1, osp.join(out_dir, 'instances_training.json'))
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()