Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import argparse | |
import os | |
from functools import partial | |
import mmcv | |
import numpy as np | |
from scipy.io import loadmat | |
def parse_args(): | |
parser = argparse.ArgumentParser( | |
description='Crop images in Synthtext-style dataset in ' | |
'prepration for MMOCR\'s use') | |
parser.add_argument( | |
'anno_path', help='Path to gold annotation data (gt.mat)') | |
parser.add_argument('img_path', help='Path to images') | |
parser.add_argument('out_dir', help='Path of output images and labels') | |
parser.add_argument( | |
'--n_proc', | |
default=1, | |
type=int, | |
help='Number of processes to run with') | |
args = parser.parse_args() | |
return args | |
def load_gt_datum(datum): | |
img_path, txt, wordBB, charBB = datum | |
words = [] | |
word_bboxes = [] | |
char_bboxes = [] | |
# when there's only one word in txt | |
# scipy will load it as a string | |
if type(txt) is str: | |
words = txt.split() | |
else: | |
for line in txt: | |
words += line.split() | |
# From (2, 4, num_boxes) to (num_boxes, 4, 2) | |
if len(wordBB.shape) == 2: | |
wordBB = wordBB[:, :, np.newaxis] | |
cur_wordBB = wordBB.transpose(2, 1, 0) | |
for box in cur_wordBB: | |
word_bboxes.append( | |
[max(round(coord), 0) for pt in box for coord in pt]) | |
# Validate word bboxes. | |
if len(words) != len(word_bboxes): | |
return | |
# From (2, 4, num_boxes) to (num_boxes, 4, 2) | |
cur_charBB = charBB.transpose(2, 1, 0) | |
for box in cur_charBB: | |
char_bboxes.append( | |
[max(round(coord), 0) for pt in box for coord in pt]) | |
char_bbox_idx = 0 | |
char_bbox_grps = [] | |
for word in words: | |
temp_bbox = char_bboxes[char_bbox_idx:char_bbox_idx + len(word)] | |
char_bbox_idx += len(word) | |
char_bbox_grps.append(temp_bbox) | |
# Validate char bboxes. | |
# If the length of the last char bbox is correct, then | |
# all the previous bboxes are also valid | |
if len(char_bbox_grps[len(words) - 1]) != len(words[-1]): | |
return | |
return img_path, words, word_bboxes, char_bbox_grps | |
def load_gt_data(filename, n_proc): | |
mat_data = loadmat(filename, simplify_cells=True) | |
imnames = mat_data['imnames'] | |
txt = mat_data['txt'] | |
wordBB = mat_data['wordBB'] | |
charBB = mat_data['charBB'] | |
return mmcv.track_parallel_progress( | |
load_gt_datum, list(zip(imnames, txt, wordBB, charBB)), nproc=n_proc) | |
def process(data, img_path_prefix, out_dir): | |
if data is None: | |
return | |
# Dirty hack for multi-processing | |
img_path, words, word_bboxes, char_bbox_grps = data | |
img_dir, img_name = os.path.split(img_path) | |
img_name = os.path.splitext(img_name)[0] | |
input_img = mmcv.imread(os.path.join(img_path_prefix, img_path)) | |
output_sub_dir = os.path.join(out_dir, img_dir) | |
if not os.path.exists(output_sub_dir): | |
try: | |
os.makedirs(output_sub_dir) | |
except FileExistsError: | |
pass # occurs when multi-proessing | |
for i, word in enumerate(words): | |
output_image_patch_name = f'{img_name}_{i}.png' | |
output_label_name = f'{img_name}_{i}.txt' | |
output_image_patch_path = os.path.join(output_sub_dir, | |
output_image_patch_name) | |
output_label_path = os.path.join(output_sub_dir, output_label_name) | |
if os.path.exists(output_image_patch_path) and os.path.exists( | |
output_label_path): | |
continue | |
word_bbox = word_bboxes[i] | |
min_x, max_x = int(min(word_bbox[::2])), int(max(word_bbox[::2])) | |
min_y, max_y = int(min(word_bbox[1::2])), int(max(word_bbox[1::2])) | |
cropped_img = input_img[min_y:max_y, min_x:max_x] | |
if cropped_img.shape[0] <= 0 or cropped_img.shape[1] <= 0: | |
continue | |
char_bbox_grp = np.array(char_bbox_grps[i]) | |
char_bbox_grp[:, ::2] -= min_x | |
char_bbox_grp[:, 1::2] -= min_y | |
mmcv.imwrite(cropped_img, output_image_patch_path) | |
with open(output_label_path, 'w') as output_label_file: | |
output_label_file.write(word + '\n') | |
for cbox in char_bbox_grp: | |
output_label_file.write('%d %d %d %d %d %d %d %d\n' % | |
tuple(cbox.tolist())) | |
def main(): | |
args = parse_args() | |
print('Loading annoataion data...') | |
data = load_gt_data(args.anno_path, args.n_proc) | |
process_with_outdir = partial( | |
process, img_path_prefix=args.img_path, out_dir=args.out_dir) | |
print('Creating cropped images and gold labels...') | |
mmcv.track_parallel_progress(process_with_outdir, data, nproc=args.n_proc) | |
print('Done') | |
if __name__ == '__main__': | |
main() | |