#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
"""End-to-end OCR demo built on MMOCR.

Chains an optional text detector, an optional text recognizer and an optional
key information extraction (KIE) model, then visualizes, prints and/or
exports the results. Usable both as a CLI script and through the ``MMOCR``
class.
"""
import copy
import os
import warnings
from argparse import ArgumentParser, Namespace
from pathlib import Path

import mmcv
import numpy as np
import torch
from mmcv.image.misc import tensor2imgs
from mmcv.runner import load_checkpoint
from mmcv.utils.config import Config

from mmocr.apis import init_detector
from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.kie_dataset import KIEDataset
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.models import build_detector
from mmocr.utils.box_util import stitch_boxes_into_lines
from mmocr.utils.fileio import list_from_file
from mmocr.utils.model import revert_sync_batchnorm


# Parse CLI arguments
def parse_args():
    """Parse the command-line arguments of the OCR demo.

    Returns:
        argparse.Namespace: The parsed arguments. The literal string
        ``'None'`` given to ``--det`` or ``--recog`` is converted to ``None``
        so that the corresponding stage is skipped.
    """
    parser = ArgumentParser()
    parser.add_argument(
        'img', type=str, help='Input image file or folder path.')
    parser.add_argument(
        '--output',
        type=str,
        default='',
        help='Output file/folder name for visualization')
    parser.add_argument(
        '--det',
        type=str,
        default='PANet_IC15',
        help='Pretrained text detection algorithm')
    parser.add_argument(
        '--det-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected det model. It '
        'overrides the settings in det')
    parser.add_argument(
        '--det-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected det model. '
        'It overrides the settings in det')
    parser.add_argument(
        '--recog',
        type=str,
        default='SEG',
        help='Pretrained text recognition algorithm')
    parser.add_argument(
        '--recog-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected recog model. It '
        'overrides the settings in recog')
    parser.add_argument(
        '--recog-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected recog model. '
        'It overrides the settings in recog')
    parser.add_argument(
        '--kie',
        type=str,
        default='',
        help='Pretrained key information extraction algorithm')
    parser.add_argument(
        '--kie-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected kie model. It '
        'overrides the settings in kie')
    parser.add_argument(
        '--kie-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected kie model. '
        'It overrides the settings in kie')
    parser.add_argument(
        '--config-dir',
        type=str,
        default=os.path.join(str(Path.cwd()), 'configs/'),
        help='Path to the config directory where all the config files '
        'are located. Defaults to "configs/"')
    parser.add_argument(
        '--batch-mode',
        action='store_true',
        help='Whether use batch mode for inference')
    parser.add_argument(
        '--recog-batch-size',
        type=int,
        default=0,
        help='Batch size for text recognition')
    parser.add_argument(
        '--det-batch-size',
        type=int,
        default=0,
        help='Batch size for text detection')
    parser.add_argument(
        '--single-batch-size',
        type=int,
        default=0,
        help='Batch size for separate det/recog inference')
    parser.add_argument(
        '--device', default=None, help='Device used for inference.')
    parser.add_argument(
        '--export',
        type=str,
        default='',
        help='Folder where the results of each image are exported')
    parser.add_argument(
        '--export-format',
        type=str,
        default='json',
        help='Format of the exported result file(s)')
    parser.add_argument(
        '--details',
        action='store_true',
        help='Whether include the text boxes coordinates and confidence values'
    )
    parser.add_argument(
        '--imshow',
        action='store_true',
        help='Whether show image with OpenCV.')
    parser.add_argument(
        '--print-result',
        action='store_true',
        help='Prints the recognised text')
    parser.add_argument(
        '--merge', action='store_true', help='Merge neighboring boxes')
    parser.add_argument(
        '--merge-xdist',
        type=float,
        default=20,
        help='The maximum x-axis distance to merge boxes')
    args = parser.parse_args()
    if args.det == 'None':
        args.det = None
    if args.recog == 'None':
        args.recog = None
    # Warnings
    if args.merge and not (args.det and args.recog):
        warnings.warn(
            'Box merging will not work if the script is not'
            ' running in detection + recognition mode.', UserWarning)
    # Warn only when the user changed --config-dir AND also passed a custom
    # config file. Paths are compared after normalization instead of with
    # os.path.samefile(), which raises FileNotFoundError when either
    # directory does not exist (e.g. running outside the repository root).
    custom_config_given = bool(args.det_config or args.recog_config
                               or args.kie_config)
    default_config_dir = os.path.join(str(Path.cwd()), 'configs/')
    if custom_config_given and (os.path.realpath(args.config_dir) !=
                                os.path.realpath(default_config_dir)):
        warnings.warn(
            'config_dir will be overridden by det-config, recog-config or '
            'kie-config.', UserWarning)
    return args


class MMOCR:
    """End-to-end OCR pipeline: text detection, recognition and KIE.

    Args:
        det (str | None): Key of a built-in text detection model (see
            ``textdet_models`` in ``__init__``). Falsy disables detection.
        det_config (str): Custom detection config path; overrides the
            built-in config of ``det`` when non-empty.
        det_ckpt (str): Custom detection checkpoint path/URL; overrides the
            built-in checkpoint of ``det`` when non-empty.
        recog (str | None): Key of a built-in text recognition model.
            Falsy disables recognition.
        recog_config (str): Custom recognition config path.
        recog_ckpt (str): Custom recognition checkpoint path/URL.
        kie (str): Key of a built-in KIE model. Falsy disables KIE. KIE
            requires both ``det`` and ``recog``.
        kie_config (str): Custom KIE config path.
        kie_ckpt (str): Custom KIE checkpoint path/URL.
        config_dir (str): Root directory of the built-in config files.
        device (str | torch.device | None): Inference device. Defaults to
            ``'cuda'`` when available, otherwise ``'cpu'``.
        **kwargs: Ignored; allows passing the full CLI namespace.

    Raises:
        ValueError: If a requested model key is not supported.
        NotImplementedError: If KIE is requested without both det and recog.
    """

    def __init__(self,
                 det='PANet_IC15',
                 det_config='',
                 det_ckpt='',
                 recog='SEG',
                 recog_config='',
                 recog_ckpt='',
                 kie='',
                 kie_config='',
                 kie_ckpt='',
                 config_dir=os.path.join(str(Path.cwd()), 'configs/'),
                 device=None,
                 **kwargs):

        textdet_models = {
            'DB_r18': {
                'config':
                'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
                'ckpt':
                'dbnet/'
                'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
            },
            'DB_r50': {
                'config':
                'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
                'ckpt':
                'dbnet/'
                'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth'
            },
            'DRRG': {
                'config':
                'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
                'ckpt':
                'drrg/drrg_r50_fpn_unet_1200e_ctw1500_20211022-fb30b001.pth'
            },
            'FCE_IC15': {
                'config':
                'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
                'ckpt':
                'fcenet/fcenet_r50_fpn_1500e_icdar2015_20211022-daefb6ed.pth'
            },
            'FCE_CTW_DCNv2': {
                'config':
                'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
                'ckpt':
                'fcenet/' +
                'fcenet_r50dcnv2_fpn_1500e_ctw1500_20211022-e326d7ec.pth'
            },
            'MaskRCNN_CTW': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
            },
            'MaskRCNN_IC15': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
            },
            'MaskRCNN_IC17': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
            },
            'PANet_CTW': {
                'config':
                'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
                'ckpt':
                'panet/'
                'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
            },
            'PANet_IC15': {
                'config':
                'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
                'ckpt':
                'panet/'
                'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
            },
            'PS_CTW': {
                'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
                'ckpt':
                'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
            },
            'PS_IC15': {
                'config':
                'psenet/psenet_r50_fpnf_600e_icdar2015.py',
                'ckpt':
                'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
            },
            'TextSnake': {
                'config':
                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
                'ckpt':
                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
            }
        }

        textrecog_models = {
            'CRNN': {
                'config': 'crnn/crnn_academic_dataset.py',
                'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
            },
            'SAR': {
                'config': 'sar/sar_r31_parallel_decoder_academic.py',
                'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
            },
            'SAR_CN': {
                'config':
                'sar/sar_r31_parallel_decoder_chinese.py',
                'ckpt':
                'sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth'
            },
            'NRTR_1/16-1/8': {
                'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
                'ckpt':
                'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth'
            },
            'NRTR_1/8-1/4': {
                'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
                'ckpt':
                'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth'
            },
            'RobustScanner': {
                'config': 'robust_scanner/robustscanner_r31_academic.py',
                'ckpt': 'robustscanner/robustscanner_r31_academic-5f05874f.pth'
            },
            'SATRN': {
                'config': 'satrn/satrn_academic.py',
                'ckpt': 'satrn/satrn_academic_20211009-cb8b1580.pth'
            },
            'SATRN_sm': {
                'config': 'satrn/satrn_small.py',
                'ckpt': 'satrn/satrn_small_20211009-2cf13355.pth'
            },
            'ABINet': {
                'config': 'abinet/abinet_academic.py',
                'ckpt': 'abinet/abinet_academic-f718abf6.pth'
            },
            'SEG': {
                'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
                'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
            },
            'CRNN_TPS': {
                'config': 'tps/crnn_tps_academic_dataset.py',
                'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
            }
        }

        kie_models = {
            'SDMGR': {
                'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py',
                'ckpt':
                'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
            }
        }

        self.td = det
        self.tr = recog
        self.kie = kie
        self.device = device
        if self.device is None:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

        # Check if the det/recog/kie model choice is valid
        if self.td and self.td not in textdet_models:
            raise ValueError(self.td,
                             'is not a supported text detection algorithm')
        elif self.tr and self.tr not in textrecog_models:
            raise ValueError(self.tr,
                             'is not a supported text recognition algorithm')
        elif self.kie:
            if self.kie not in kie_models:
                raise ValueError(
                    self.kie, 'is not a supported key information extraction'
                    ' algorithm')
            elif not (self.td and self.tr):
                raise NotImplementedError(
                    self.kie, 'has to run together'
                    ' with text detection and recognition algorithms.')

        self.detect_model = None
        if self.td:
            # Build detection model
            if not det_config:
                det_config = os.path.join(config_dir, 'textdet/',
                                          textdet_models[self.td]['config'])
            if not det_ckpt:
                det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + \
                    textdet_models[self.td]['ckpt']

            self.detect_model = init_detector(
                det_config, det_ckpt, device=self.device)
            self.detect_model = revert_sync_batchnorm(self.detect_model)

        self.recog_model = None
        if self.tr:
            # Build recognition model
            if not recog_config:
                recog_config = os.path.join(
                    config_dir, 'textrecog/',
                    textrecog_models[self.tr]['config'])
            if not recog_ckpt:
                recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
                    'textrecog/' + textrecog_models[self.tr]['ckpt']

            self.recog_model = init_detector(
                recog_config, recog_ckpt, device=self.device)
            self.recog_model = revert_sync_batchnorm(self.recog_model)

        self.kie_model = None
        if self.kie:
            # Build key information extraction model
            if not kie_config:
                kie_config = os.path.join(config_dir, 'kie/',
                                          kie_models[self.kie]['config'])
            if not kie_ckpt:
                kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
                    'kie/' + kie_models[self.kie]['ckpt']

            kie_cfg = Config.fromfile(kie_config)
            self.kie_model = build_detector(
                kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
            self.kie_model = revert_sync_batchnorm(self.kie_model)
            self.kie_model.cfg = kie_cfg
            load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)
            # Unlike the det/recog models (built via init_detector), this
            # model is constructed by hand, so it has to be moved to the
            # target device and switched to inference mode explicitly.
            self.kie_model.to(self.device)
            self.kie_model.eval()

        # Strip (Distributed)DataParallel wrappers so that attributes such as
        # ``cfg`` and ``show_result`` are reached on the model itself.
        self.detect_model = self._unwrap(self.detect_model)
        self.recog_model = self._unwrap(self.recog_model)

    @staticmethod
    def _unwrap(model):
        """Return ``model.module`` for (Distributed)DataParallel, else model."""
        parallel_types = (torch.nn.DataParallel,
                          torch.nn.parallel.DistributedDataParallel)
        if isinstance(model, parallel_types):
            return model.module
        return model

    def readtext(self,
                 img,
                 output=None,
                 details=False,
                 export=None,
                 export_format='json',
                 batch_mode=False,
                 recog_batch_size=0,
                 det_batch_size=0,
                 single_batch_size=0,
                 imshow=False,
                 print_result=False,
                 merge=False,
                 merge_xdist=20,
                 **kwargs):
        """Run the configured pipeline on one or more images.

        Args:
            img (str | os.PathLike | np.ndarray | list | tuple): An image
                path, a folder of images, an image array, or a list/tuple of
                paths and/or arrays.
            output (str | None): Visualization output file, or folder when
                several images are processed.
            details (bool): In det+recog mode, return the full per-box
                results instead of only the recognized text.
            export (str | None): Folder where per-image results are dumped.
            export_format (str): File extension used for exported results.
            batch_mode (bool): Whether to run models in batch mode.
            recog_batch_size (int): Recognition batch size (0 = one batch).
            det_batch_size (int): Detection batch size (0 = one batch).
            single_batch_size (int): Batch size when only det or only recog
                is enabled (0 = one batch).
            imshow (bool): Show visualizations with OpenCV.
            print_result (bool): Print each result.
            merge (bool): Merge neighboring boxes into lines (det+recog).
            merge_xdist (float): Max x-distance for merging boxes.
            **kwargs: Ignored.

        Returns:
            list | None: The post-processed results, one per image, or
            ``None`` when neither a detection nor a recognition model is set.
        """
        # Must be the first statement so that only the call arguments are
        # captured.
        params = locals().copy()
        for unwanted in ('kwargs', 'self'):
            params.pop(unwanted, None)
        args = Namespace(**params)

        # Input and output arguments processing
        self._args_processing(args)
        self.args = args

        pp_result = None

        # Send args and models to the MMOCR model inference API
        # and call post-processing functions for the output
        if self.detect_model is not None and self.recog_model is not None:
            det_recog_result = self.det_recog_kie_inference(
                self.detect_model, self.recog_model, kie_model=self.kie_model)
            pp_result = self.det_recog_pp(det_recog_result)
        else:
            # At most one of the two models is set on this branch.
            model = (self.recog_model
                     if self.recog_model is not None else self.detect_model)
            if model is not None:
                result = self.single_inference(model, args.arrays,
                                               args.batch_mode,
                                               args.single_batch_size)
                pp_result = self.single_pp(result, model)

        return pp_result

    # Post processing function for end2end ocr
    def det_recog_pp(self, result):
        """Visualize, export and print end-to-end results.

        Args:
            result (list[dict]): Output of ``det_recog_kie_inference``.

        Returns:
            list[dict]: Per-image results; with ``details`` disabled each one
            only holds ``'filename'`` and the list of recognized ``'text'``.
        """
        final_results = []
        args = self.args
        for arr, output, export, det_recog_result in zip(
                args.arrays, args.output, args.export, result):
            # With a KIE model the visualization is written by
            # visualize_kie_output during inference, so the plain det+recog
            # rendering is only produced when KIE is disabled.
            if (output or args.imshow) and self.kie_model is None:
                res_img = det_recog_show_result(
                    arr, det_recog_result, out_file=output)
                if args.imshow:
                    mmcv.imshow(res_img, 'inference results')
            if args.details:
                final_result = det_recog_result
            else:
                final_result = {
                    'filename': det_recog_result['filename'],
                    'text': [x['text'] for x in det_recog_result['result']],
                }
            if export:
                mmcv.dump(final_result, export, indent=4)
            if args.print_result:
                print(final_result, end='\n\n')
            final_results.append(final_result)
        return final_results

    # Post processing function for separate det/recog inference
    def single_pp(self, result, model):
        """Export, visualize and print results of a single det/recog model.

        Returns:
            list: ``result`` unchanged.
        """
        args = self.args
        for arr, output, export, res in zip(args.arrays, args.output,
                                            args.export, result):
            if export:
                mmcv.dump(res, export, indent=4)
            if output or args.imshow:
                res_img = model.show_result(arr, res, out_file=output)
                if args.imshow:
                    mmcv.imshow(res_img, 'inference results')
            if args.print_result:
                print(res, end='\n\n')
        return result

    def generate_kie_labels(self, result, boxes, class_list):
        """Turn KIE node predictions into ``(label, score)`` pairs.

        Args:
            result (dict): KIE output; ``result['nodes']`` holds per-node
                class scores (the arg-max over the last dim is taken).
            boxes (list): Boxes; one label is produced per box.
            class_list (str | None): Optional file whose lines are
                ``<class_idx> <class_label>``, used to map indices to names.

        Returns:
            list[tuple[str, float]]: One ``(label, score)`` pair per box.
        """
        idx_to_cls = {}
        if class_list is not None:
            for line in list_from_file(class_list):
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                # maxsplit keeps labels that contain spaces intact.
                class_idx, class_label = line.split(maxsplit=1)
                idx_to_cls[class_idx] = class_label

        max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
        node_pred_label = max_idx.numpy().tolist()
        node_pred_score = max_value.numpy().tolist()

        labels = []
        for i in range(len(boxes)):
            pred_label = str(node_pred_label[i])
            pred_label = idx_to_cls.get(pred_label, pred_label)
            labels.append((pred_label, node_pred_score[i]))

        return labels

    def visualize_kie_output(self,
                             model,
                             data,
                             result,
                             out_file=None,
                             show=False):
        """Visualizes KIE output.

        Recovers the displayable image from the preprocessed ``data`` (raw
        uint8 tensors are used as-is, otherwise the tensor is de-normalized
        with ``img_norm_cfg``), crops it to ``img_shape`` and delegates to
        ``model.show_result``.
        """
        img_tensor = data['img'].data
        img_meta = data['img_metas'].data
        gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
        if img_tensor.dtype == torch.uint8:
            # The img tensor is the raw input not being normalized
            # (For SDMGR non-visual)
            img = img_tensor.cpu().numpy().transpose(1, 2, 0)
        else:
            img = tensor2imgs(
                img_tensor.unsqueeze(0), **img_meta.get('img_norm_cfg', {}))[0]
        h, w, _ = img_meta.get('img_shape', img.shape)
        img_show = img[:h, :w, :]
        model.show_result(
            img_show, result, gt_bboxes, show=show, out_file=out_file)

    @staticmethod
    def _crop_quad(bbox):
        """Return the 8-value quadrilateral used to crop a detected box.

        ``bbox`` is a flat coordinate list followed by a score. Polygons with
        more than four points are replaced by their axis-aligned bounding
        rectangle.
        """
        if len(bbox) <= 9:
            return bbox[:8]
        xs = bbox[0:-1:2]
        ys = bbox[1:-1:2]
        min_x, max_x = min(xs), max(xs)
        min_y, max_y = min(ys), max(ys)
        return [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]

    @staticmethod
    def _text_score(text, score):
        """Average per-character scores over ``len(text)``; pass scalars."""
        if isinstance(score, (list, tuple)):
            return sum(score) / max(1, len(text))
        return score

    def _kie_inference(self, kie_model, kie_dataset, arr, boxes_res,
                       out_file):
        """Run KIE on one image and write labels into ``boxes_res``."""
        annotations = copy.deepcopy(boxes_res)
        # Customized for kie_dataset, which
        # assumes that boxes are represented by only 4 points
        for ann in annotations:
            xs = ann['box'][::2]
            ys = ann['box'][1::2]
            min_x, max_x = min(xs), max(xs)
            min_y, max_y = min(ys), max(ys)
            ann['box'] = [
                min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
            ]
        ann_info = kie_dataset._parse_anno_info(annotations)
        ann_info['ori_bboxes'] = ann_info.get('ori_bboxes', ann_info['bboxes'])
        ann_info['gt_bboxes'] = ann_info.get('gt_bboxes', ann_info['bboxes'])
        kie_result, data = model_inference(
            kie_model,
            arr,
            ann=ann_info,
            return_data=True,
            batch_mode=self.args.batch_mode)
        # visualize KIE results
        self.visualize_kie_output(
            kie_model,
            data,
            kie_result,
            out_file=out_file,
            show=self.args.imshow)
        gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
        labels = self.generate_kie_labels(kie_result, gt_bboxes,
                                          kie_model.class_list)
        for box_res, (label, label_score) in zip(boxes_res, labels):
            box_res['label'] = label
            box_res['label_score'] = label_score

    # End2end ocr inference pipeline
    def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
        """Detect text, recognize every box and optionally run KIE.

        Returns:
            list[dict]: One dict per image with ``'filename'`` and
            ``'result'``. The latter is a list of per-box dicts with
            ``'box'``, ``'box_score'``, ``'text'``, ``'text_score'`` and, with
            KIE, ``'label'``/``'label_score'``.
        """
        end2end_res = []
        # Find bounding boxes in the images (text detection)
        det_result = self.single_inference(det_model, self.args.arrays,
                                           self.args.batch_mode,
                                           self.args.det_batch_size)
        bboxes_list = [res['boundary_result'] for res in det_result]

        if kie_model is not None:
            kie_dataset = KIEDataset(
                dict_file=kie_model.cfg.data.test.dict_file)

        # For each bounding box, the image is cropped and
        # sent to the recognition model either one by one
        # or all together depending on the batch_mode
        for filename, arr, bboxes, out_file in zip(self.args.filenames,
                                                   self.args.arrays,
                                                   bboxes_list,
                                                   self.args.output):
            img_e2e_res = {'filename': filename, 'result': []}
            box_imgs = []
            for bbox in bboxes:
                box_res = {
                    'box': [round(x) for x in bbox[:-1]],
                    'box_score': float(bbox[-1]),
                }
                box_img = crop_img(arr, self._crop_quad(bbox))
                if self.args.batch_mode:
                    box_imgs.append(box_img)
                else:
                    recog_result = model_inference(recog_model, box_img)
                    text = recog_result['text']
                    box_res['text'] = text
                    box_res['text_score'] = self._text_score(
                        text, recog_result['score'])
                img_e2e_res['result'].append(box_res)

            if self.args.batch_mode:
                # single_inference returns [] when no box was detected.
                recog_results = self.single_inference(
                    recog_model, box_imgs, True, self.args.recog_batch_size)
                for box_res, recog_result in zip(img_e2e_res['result'],
                                                 recog_results):
                    text = recog_result['text']
                    box_res['text'] = text
                    box_res['text_score'] = self._text_score(
                        text, recog_result['score'])

            if self.args.merge:
                img_e2e_res['result'] = stitch_boxes_into_lines(
                    img_e2e_res['result'], self.args.merge_xdist, 0.5)

            # KIE needs at least one annotated box to build its inputs.
            if kie_model is not None and img_e2e_res['result']:
                self._kie_inference(kie_model, kie_dataset, arr,
                                    img_e2e_res['result'], out_file)

            end2end_res.append(img_e2e_res)
        return end2end_res

    # Separate det/recog inference pipeline
    def single_inference(self, model, arrays, batch_mode, batch_size=0):
        """Run ``model`` over ``arrays``, optionally in (chunked) batches.

        Args:
            model: The model passed to ``model_inference``.
            arrays (list): Input images.
            batch_mode (bool): Run in batch mode instead of one by one.
            batch_size (int): Chunk size in batch mode; ``<= 0`` means all
                inputs in a single batch.

        Returns:
            list: One result per input.
        """
        if not batch_mode:
            return [
                model_inference(model, arr, batch_mode=False)
                for arr in arrays
            ]
        if len(arrays) == 0:
            # Nothing to run (e.g. no text detected in an image); avoid
            # handing an empty batch to model_inference.
            return []
        if batch_size <= 0:
            return model_inference(model, arrays, batch_mode=True)
        result = []
        for start in range(0, len(arrays), batch_size):
            result.extend(
                model_inference(
                    model, arrays[start:start + batch_size],
                    batch_mode=True))
        return result

    # Arguments pre-processing function
    def _args_processing(self, args):
        """Normalize the input/output arguments in place.

        Adds ``args.arrays`` (decoded images) and ``args.filenames``, and
        expands ``args.output``/``args.export`` into one entry per image.

        Raises:
            AssertionError: On unsupported input types, when no input image
                is found, or when a single output file is given for multiple
                images / batch mode.
        """
        img = args.img
        # Check if the input is a list/tuple that
        # contains only np arrays or strings
        if isinstance(img, (list, tuple)):
            if not all(isinstance(x, (np.ndarray, str)) for x in img):
                raise AssertionError('Images must be strings or numpy arrays')
            img_list = img
        elif isinstance(img, (str, os.PathLike)):
            img_path = Path(img)
            if img_path.is_dir():
                # Only regular files, in a deterministic order.
                img_list = sorted(
                    str(x) for x in img_path.glob('*') if x.is_file())
            else:
                img_list = [str(img_path)]
        elif isinstance(img, np.ndarray):
            img_list = [img]
        else:
            raise AssertionError('Images must be strings or numpy arrays')

        if len(img_list) == 0:
            raise AssertionError(f'No input image found in {img!r}')

        # Read all image(s) in advance to reduce wasted time
        # re-reading the images for visualization output
        args.arrays = [mmcv.imread(x) for x in img_list]

        # Create a list of filenames (used for output images and result
        # files): file stem for paths, list index for in-memory arrays.
        args.filenames = [
            Path(x).stem if isinstance(x, str) else str(i)
            for i, x in enumerate(img_list)
        ]

        # If given an output argument, create a list of output image filenames
        num_res = len(img_list)
        if args.output:
            output_path = Path(args.output)
            if output_path.is_dir():
                args.output = [
                    str(output_path / f'out_{x}.png') for x in args.filenames
                ]
            else:
                # A single output file cannot hold several images; without
                # this check zip() would silently drop all but the first.
                if args.batch_mode or num_res > 1:
                    raise AssertionError('Output of multiple images inference'
                                         ' must be a directory')
                args.output = [str(args.output)]
        else:
            args.output = [None] * num_res

        # If given an export argument, create a list of
        # result filenames for each image
        if args.export:
            export_path = Path(args.export)
            export_path.mkdir(parents=True, exist_ok=True)
            args.export = [
                str(export_path / f'out_{x}.{args.export_format}')
                for x in args.filenames
            ]
        else:
            args.export = [None] * num_res

        return args


# Create an inference pipeline with parsed arguments
def main():
    """CLI entry point: build the pipeline and run it on the given input."""
    args = parse_args()
    ocr = MMOCR(**vars(args))
    ocr.readtext(**vars(args))


if __name__ == '__main__':
    main()