from __future__ import absolute_import from __future__ import division from __future__ import print_function import threading import queue import os import sys import time __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) import numpy as np import cv2 import json from PIL import Image from tools.utils.utility import get_image_file_list, check_and_read from tools.infer_rec import OpenRecognizer from tools.infer_det import OpenDetector from tools.infer_e2e import check_and_download_font, sorted_boxes from tools.engine import Config from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt class OpenOCRParallel: def __init__(self, drop_score=0.5, det_box_type='quad', max_rec_threads=1): cfg_det = Config( './configs/det/dbnet/repvit_db.yml').cfg # mobile model # cfg_rec = Config('./configs/rec/svtrv2/svtrv2_ch.yml').cfg # server model cfg_rec = Config( './configs/rec/svtrv2/repsvtr_ch.yml').cfg # mobile model self.text_detector = OpenDetector(cfg_det, numId=0) self.text_recognizer = OpenRecognizer(cfg_rec, numId=0) self.det_box_type = det_box_type self.drop_score = drop_score self.queue = queue.Queue( ) # Queue to hold detected boxes for recognition self.results = {} self.lock = threading.Lock() # Lock for thread-safe access to results self.max_rec_threads = max_rec_threads self.stop_signal = threading.Event() # Signal to stop threads def start_recognition_threads(self): """Start recognition threads.""" self.rec_threads = [] for _ in range(self.max_rec_threads): t = threading.Thread(target=self.recognize_text) t.start() self.rec_threads.append(t) def detect_text(self, image_list): """Single-threaded text detection for all images.""" for image_id, (img_numpy, ori_img) in enumerate(image_list): dt_boxes = self.text_detector(img_numpy=img_numpy)[0]['boxes'] if dt_boxes is None: self.results[image_id] = [] # If no boxes, set empty results continue dt_boxes = sorted_boxes(dt_boxes) img_crop_list = [] for box in dt_boxes: tmp_box = np.array(box).astype(np.float32) img_crop = (get_rotate_crop_image(ori_img, tmp_box) if self.det_box_type == 'quad' else get_minarea_rect_crop(ori_img, tmp_box)) img_crop_list.append(img_crop) self.queue.put( (image_id, dt_boxes, img_crop_list )) # Put image ID, detected box, and cropped image in queue # Signal that no more items will be added to the queue self.stop_signal.set() def recognize_text(self): """Recognize text in each cropped image.""" while not self.stop_signal.is_set() or not self.queue.empty(): try: image_id, boxs, img_crop_list = self.queue.get(timeout=0.5) rec_results = self.text_recognizer( img_numpy_list=img_crop_list, batch_num=6) for rec_result, box in zip(rec_results, boxs): text, score = rec_result['text'], rec_result['score'] if score >= self.drop_score: with self.lock: # Ensure results dictionary has a list for each image ID if image_id not in self.results: self.results[image_id] = [] self.results[image_id].append({ 'transcription': text, 'points': box.tolist(), 'score': score }) self.queue.task_done() except queue.Empty: continue def process_images(self, image_list): """Process a list of images.""" # Initialize results dictionary self.results = {i: [] for i in range(len(image_list))} # Start recognition threads t_start_1 = time.time() self.start_recognition_threads() # Start detection in the main thread t_start = time.time() self.detect_text(image_list) print('det time:', time.time() - t_start) # Wait for recognition threads to finish for t in self.rec_threads: t.join() self.stop_signal.clear() print('all time:', time.time() - t_start_1) return self.results def main(cfg_det, cfg_rec): img_path = './testA/' image_file_list = get_image_file_list(img_path) drop_score = 0.5 text_sys = OpenOCRParallel( drop_score=drop_score, det_box_type='quad') # det_box_type: 'quad' or 'poly' is_visualize = False if is_visualize: font_path = './simfang.ttf' check_and_download_font(font_path) draw_img_save_dir = img_path + 'e2e_results/' if img_path[ -1] != '/' else img_path[:-1] + 'e2e_results/' os.makedirs(draw_img_save_dir, exist_ok=True) save_results = [] # Prepare images images = [] t_start = time.time() for image_file in image_file_list: img, flag_gif, flag_pdf = check_and_read(image_file) if not flag_gif and not flag_pdf: img = cv2.imread(image_file) if img is not None: images.append((img, img.copy())) results = text_sys.process_images(images) print(f'time cost: {time.time() - t_start}') # Save results and visualize for image_id, res in results.items(): image_file = image_file_list[image_id] save_pred = f'{os.path.basename(image_file)}\t{json.dumps(res, ensure_ascii=False)}\n' # print(save_pred) save_results.append(save_pred) if is_visualize: dt_boxes = [result['points'] for result in res] rec_res = [result['transcription'] for result in res] rec_score = [result['score'] for result in res] image = Image.fromarray( cv2.cvtColor(images[image_id][0], cv2.COLOR_BGR2RGB)) draw_img = draw_ocr_box_txt(image, dt_boxes, rec_res, rec_score, drop_score=drop_score, font_path=font_path) save_file = os.path.join(draw_img_save_dir, os.path.basename(image_file)) cv2.imwrite(save_file, draw_img[:, :, ::-1]) with open(os.path.join(draw_img_save_dir, 'system_results.txt'), 'w', encoding='utf-8') as f: f.writelines(save_results) if __name__ == '__main__': main()