MMOCR / tools /data /textrecog /synthtext_converter.py
tomofi's picture
Add application file
2366e36
raw
history blame
4.75 kB
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
from functools import partial
import mmcv
import numpy as np
from scipy.io import loadmat
def parse_args(argv=None):
    """Parse command-line arguments for the SynthText cropping tool.

    Args:
        argv (list[str], optional): Argument list to parse instead of
            ``sys.argv[1:]``. Defaults to None, meaning the real command
            line is used.

    Returns:
        argparse.Namespace: Holds ``anno_path``, ``img_path``, ``out_dir``
        and ``n_proc`` (int, default 1).
    """
    parser = argparse.ArgumentParser(
        description='Crop images in Synthtext-style dataset in '
        'preparation for MMOCR\'s use')
    parser.add_argument(
        'anno_path', help='Path to gold annotation data (gt.mat)')
    parser.add_argument('img_path', help='Path to images')
    parser.add_argument('out_dir', help='Path of output images and labels')
    parser.add_argument(
        '--n_proc',
        default=1,
        type=int,
        help='Number of processes to run with')
    # parse_args(None) falls back to sys.argv[1:], preserving old behavior.
    return parser.parse_args(argv)
def _flatten_bboxes(bb):
    """Turn a (2, 4, n) or squeezed (2, 4) point array into n flat boxes.

    Each box becomes ``[x1, y1, x2, y2, x3, y3, x4, y4]`` with every
    coordinate rounded and clamped to be non-negative.
    """
    # A single box may have had its trailing axis squeezed away by loadmat.
    if bb.ndim == 2:
        bb = bb[:, :, np.newaxis]
    # From (2, 4, num_boxes) to (num_boxes, 4, 2)
    return [[max(round(coord), 0) for pt in box for coord in pt]
            for box in bb.transpose(2, 1, 0)]


def load_gt_datum(datum):
    """Parse one SynthText annotation entry into words and integer boxes.

    Args:
        datum (tuple): ``(img_path, txt, wordBB, charBB)`` for one image.
            ``txt`` is either a single string or an iterable of text lines.
            ``wordBB``/``charBB`` are point arrays indexed as
            ``(2, 4, num_boxes)``, or ``(2, 4)`` for a single box.

    Returns:
        tuple | None: ``(img_path, words, word_bboxes, char_bbox_grps)``.
        Each word box is a flat list of 8 non-negative ints.
        ``char_bbox_grps[i]`` holds one such box per character of
        ``words[i]``. Returns ``None`` when the entry has no words or its
        box counts do not match the text.
    """
    img_path, txt, wordBB, charBB = datum

    # When there's only one line, scipy loads it as a bare string (which may
    # be np.str_, a str subclass) rather than a sequence of lines.
    lines = [txt] if isinstance(txt, str) else txt
    words = [word for line in lines for word in line.split()]
    # Nothing to crop; also avoids indexing an empty group list below.
    if not words:
        return None

    word_bboxes = _flatten_bboxes(wordBB)
    # Validate word bboxes.
    if len(words) != len(word_bboxes):
        return None

    char_bboxes = _flatten_bboxes(charBB)
    # Split char boxes into consecutive per-word groups.
    char_bbox_grps = []
    start = 0
    for word in words:
        char_bbox_grps.append(char_bboxes[start:start + len(word)])
        start += len(word)
    # Validate char bboxes. Groups are contiguous slices, so if the last one
    # is complete, all the previous groups are complete as well.
    if len(char_bbox_grps[-1]) != len(words[-1]):
        return None

    return img_path, words, word_bboxes, char_bbox_grps
def load_gt_data(filename, n_proc):
    """Read a SynthText ``gt.mat`` file and parse each image's annotation.

    Args:
        filename (str): Path to the ``.mat`` annotation file.
        n_proc (int): Number of worker processes used for parsing.

    Returns:
        list: One ``load_gt_datum`` result per image, in file order. Entries
        that failed validation are ``None``.
    """
    mat = loadmat(filename, simplify_cells=True)
    # One task per image: (image name, text, word boxes, char boxes).
    # track_parallel_progress needs a sized sequence, hence list().
    tasks = list(
        zip(mat['imnames'], mat['txt'], mat['wordBB'], mat['charBB']))
    return mmcv.track_parallel_progress(load_gt_datum, tasks, nproc=n_proc)
def process(data, img_path_prefix, out_dir):
    """Crop every word of one image and write a patch and a label per word.

    Outputs go to ``out_dir/<image sub-dir>/<image stem>_<i>.png`` and
    ``..._<i>.txt``. A label file holds the word on its first line, followed
    by one line of 8 ints per character box, relative to the crop's
    top-left corner. Words whose patch and label already exist are skipped,
    so an interrupted run can be resumed.

    Args:
        data (tuple | None): A ``load_gt_datum`` result, i.e.
            ``(img_path, words, word_bboxes, char_bbox_grps)``, or ``None``
            for an entry that failed validation.
        img_path_prefix (str): Directory that ``img_path`` is relative to.
        out_dir (str): Root directory for the generated files.
    """
    # Entries that failed validation arrive as None; skip them here because
    # the whole list is dispatched to worker processes unfiltered.
    if data is None:
        return
    img_path, words, word_bboxes, char_bbox_grps = data
    img_dir, img_name = os.path.split(img_path)
    img_name = os.path.splitext(img_name)[0]
    output_sub_dir = os.path.join(out_dir, img_dir)
    # exist_ok covers several worker processes creating the same directory.
    os.makedirs(output_sub_dir, exist_ok=True)

    # Collect the words whose outputs are still missing. When every output
    # already exists, the image is never decoded.
    pending = []
    for i, word in enumerate(words):
        patch_path = os.path.join(output_sub_dir, f'{img_name}_{i}.png')
        label_path = os.path.join(output_sub_dir, f'{img_name}_{i}.txt')
        if os.path.exists(patch_path) and os.path.exists(label_path):
            continue
        pending.append((i, word, patch_path, label_path))
    if not pending:
        return

    input_img = mmcv.imread(os.path.join(img_path_prefix, img_path))
    for i, word, patch_path, label_path in pending:
        word_bbox = word_bboxes[i]
        xs, ys = word_bbox[::2], word_bbox[1::2]
        min_x, max_x = int(min(xs)), int(max(xs))
        min_y, max_y = int(min(ys)), int(max(ys))
        cropped_img = input_img[min_y:max_y, min_x:max_x]
        # Degenerate or out-of-image boxes yield an empty crop; skip them.
        if cropped_img.shape[0] <= 0 or cropped_img.shape[1] <= 0:
            continue
        # Shift char boxes so they are relative to the crop's top-left corner.
        char_bbox_grp = np.array(char_bbox_grps[i])
        char_bbox_grp[:, ::2] -= min_x
        char_bbox_grp[:, 1::2] -= min_y
        mmcv.imwrite(cropped_img, patch_path)
        with open(label_path, 'w', encoding='utf-8') as output_label_file:
            output_label_file.write(word + '\n')
            for cbox in char_bbox_grp:
                output_label_file.write('%d %d %d %d %d %d %d %d\n' %
                                        tuple(cbox.tolist()))
def main():
    """Command-line entry point: crop SynthText word images and labels.

    Parses the CLI arguments and loads and validates every annotation.
    It then writes the cropped word images and their label files. Both
    stages run with ``--n_proc`` worker processes.
    """
    args = parse_args()
    print('Loading annotation data...')
    data = load_gt_data(args.anno_path, args.n_proc)
    # Bind the fixed arguments so each worker call only receives one datum.
    process_with_outdir = partial(
        process, img_path_prefix=args.img_path, out_dir=args.out_dir)
    print('Creating cropped images and gold labels...')
    # Entries that failed validation are None; process() skips them.
    mmcv.track_parallel_progress(process_with_outdir, data, nproc=args.n_proc)
    print('Done')


if __name__ == '__main__':
    main()