# OpenOCR-Demo: tools/infer_e2e_parallel.py
# Author: topdu — commit 29f689c ("openocr demo")
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import queue
import os
import sys
import time
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
import numpy as np
import cv2
import json
from PIL import Image
from tools.utils.utility import get_image_file_list, check_and_read
from tools.infer_rec import OpenRecognizer
from tools.infer_det import OpenDetector
from tools.infer_e2e import check_and_download_font, sorted_boxes
from tools.engine import Config
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt
class OpenOCRParallel:
    """End-to-end OCR with detection and recognition overlapped across threads.

    Detection runs in the calling thread and, for every image that has boxes,
    pushes ``(image_id, sorted_boxes, crops)`` onto ``self.queue``.
    ``max_rec_threads`` worker threads consume that queue and collect the
    recognition results in ``self.results``, keyed by the image's index in the
    list given to :meth:`process_images`.
    """

    # Seconds a worker blocks on an empty queue before re-checking the stop
    # signal; this bounds how long process_images waits after detection ends.
    rec_poll_interval = 0.5
    # Default ``batch_num`` handed to the recognizer.
    rec_batch_num = 6

    def __init__(self,
                 drop_score=0.5,
                 det_box_type='quad',
                 max_rec_threads=1,
                 det_cfg_path='./configs/det/dbnet/repvit_db.yml',
                 rec_cfg_path='./configs/rec/svtrv2/repsvtr_ch.yml',
                 rec_batch_num=6):
        """Load the detector and recognizer and set up the shared state.

        Args:
            drop_score: Recognitions scoring below this value are discarded.
            det_box_type: ``'quad'`` crops each box with
                ``get_rotate_crop_image``; any other value uses
                ``get_minarea_rect_crop``.
            max_rec_threads: Number of recognition worker threads.
            det_cfg_path: Detector config file (default: mobile model).
            rec_cfg_path: Recognizer config file (default: mobile model; the
                server model is ``./configs/rec/svtrv2/svtrv2_ch.yml``).
            rec_batch_num: ``batch_num`` passed to the recognizer.
        """
        cfg_det = Config(det_cfg_path).cfg
        cfg_rec = Config(rec_cfg_path).cfg
        self.text_detector = OpenDetector(cfg_det, numId=0)
        self.text_recognizer = OpenRecognizer(cfg_rec, numId=0)
        self.det_box_type = det_box_type
        self.drop_score = drop_score
        self.max_rec_threads = max_rec_threads
        self.rec_batch_num = rec_batch_num
        # Items: (image_id, sorted boxes, cropped images), one per image.
        self.queue = queue.Queue()
        # image index -> list of {'transcription', 'points', 'score'} dicts.
        self.results = {}
        self.lock = threading.Lock()  # guards self.results across threads
        # Set by detect_text once no more items will be queued.
        self.stop_signal = threading.Event()
        self.rec_threads = []

    def start_recognition_threads(self):
        """Start ``max_rec_threads`` worker threads running recognize_text.

        NOTE(review): all workers share one ``text_recognizer``; confirm the
        model call is safe to run concurrently when ``max_rec_threads > 1``.
        """
        self.rec_threads = []
        for _ in range(self.max_rec_threads):
            t = threading.Thread(target=self.recognize_text)
            t.start()
            self.rec_threads.append(t)

    def detect_text(self, image_list):
        """Detect text in every image and queue the crops for recognition.

        Args:
            image_list: Sequence of ``(img_numpy, ori_img)`` pairs;
                ``img_numpy`` is fed to the detector, ``ori_img`` is the image
                the text regions are cropped from.

        The stop signal is set when this returns *or raises*, so the
        recognition workers always drain the queue and exit instead of
        polling forever.
        """
        crop_fn = (get_rotate_crop_image
                   if self.det_box_type == 'quad' else get_minarea_rect_crop)
        try:
            for image_id, (img_numpy, ori_img) in enumerate(image_list):
                dt_boxes = self.text_detector(img_numpy=img_numpy)[0]['boxes']
                if dt_boxes is None or len(dt_boxes) == 0:
                    # Nothing to recognize: record an empty result directly.
                    with self.lock:
                        self.results[image_id] = []
                    continue
                dt_boxes = sorted_boxes(dt_boxes)
                img_crop_list = [
                    crop_fn(ori_img, np.array(box).astype(np.float32))
                    for box in dt_boxes
                ]
                self.queue.put((image_id, dt_boxes, img_crop_list))
        finally:
            # Signal that no more items will be added to the queue.
            self.stop_signal.set()

    def recognize_text(self):
        """Worker loop that recognizes queued crops and stores the results.

        Runs until the stop signal is set and the queue is empty. Results
        scoring below ``drop_score`` are dropped. Each kept result is appended
        to ``self.results[image_id]`` as ``{'transcription', 'points',
        'score'}``, with ``points`` taken from the matching box via
        ``.tolist()``.
        """
        while not self.stop_signal.is_set() or not self.queue.empty():
            try:
                image_id, boxes, img_crop_list = self.queue.get(
                    timeout=self.rec_poll_interval)
            except queue.Empty:
                continue
            try:
                rec_results = self.text_recognizer(
                    img_numpy_list=img_crop_list,
                    batch_num=self.rec_batch_num)
                kept = []
                for rec_result, box in zip(rec_results, boxes):
                    score = rec_result['score']
                    if score >= self.drop_score:
                        kept.append({
                            'transcription': rec_result['text'],
                            'points': box.tolist(),
                            'score': score,
                        })
                if kept:
                    with self.lock:
                        self.results.setdefault(image_id, []).extend(kept)
            finally:
                # Mark the item done even if recognition raised.
                self.queue.task_done()

    def process_images(self, image_list):
        """Run detection and threaded recognition over ``image_list``.

        Args:
            image_list: Sequence of ``(img_numpy, ori_img)`` pairs (see
                detect_text).

        Returns:
            dict mapping each image index to its list of recognition dicts.

        Workers are always joined and the stop signal is always cleared, even
        when detection raises. This keeps the instance reusable and prevents
        stray threads.
        """
        self.results = {i: [] for i in range(len(image_list))}
        self.stop_signal.clear()
        t_start_1 = time.time()
        self.start_recognition_threads()
        try:
            # Detection runs in the calling thread.
            t_start = time.time()
            self.detect_text(image_list)
            print('det time:', time.time() - t_start)
        finally:
            # detect_text set the stop signal, so the workers will finish.
            for t in self.rec_threads:
                t.join()
            self.stop_signal.clear()
        print('all time:', time.time() - t_start_1)
        return self.results
def main(cfg_det=None, cfg_rec=None, img_path='./testA/', is_visualize=False):
    """Run OpenOCRParallel over every image found under ``img_path``.

    Writes ``system_results.txt`` into ``<img_path without trailing '/'>``
    followed by ``e2e_results/``. The file has one line per readable image:
    the image basename, a TAB, then the JSON list of its recognitions. With
    ``is_visualize`` the annotated images are saved in the same directory.

    Args:
        cfg_det: Unused (OpenOCRParallel loads its own detector config).
            Accepted so existing ``main(cfg_det, cfg_rec)`` calls still work.
        cfg_rec: Unused; see ``cfg_det``.
        img_path: Path handed to ``get_image_file_list`` (presumably an image
            file or a directory of images).
        is_visualize: Also draw boxes and text on each image and save it.
    """
    image_file_list = get_image_file_list(img_path)
    drop_score = 0.5
    text_sys = OpenOCRParallel(
        drop_score=drop_score,
        det_box_type='quad')  # det_box_type: 'quad' or 'poly'

    # The output directory holds system_results.txt even when visualization
    # is off, so it must always be created.
    base_path = img_path[:-1] if img_path.endswith('/') else img_path
    draw_img_save_dir = base_path + 'e2e_results/'
    os.makedirs(draw_img_save_dir, exist_ok=True)
    if is_visualize:
        font_path = './simfang.ttf'
        check_and_download_font(font_path)

    # Load the images. Remember which file each *loaded* image came from:
    # result indices are positions in `images`, not in `image_file_list`.
    images = []
    loaded_files = []
    t_start = time.time()
    for image_file in image_file_list:
        img, flag_gif, flag_pdf = check_and_read(image_file)
        if not flag_gif and not flag_pdf:
            img = cv2.imread(image_file)
        if img is None:
            continue  # unreadable file: skip it
        # NOTE(review): for PDFs check_and_read may return several pages;
        # confirm `img` is a single image here.
        images.append((img, img.copy()))
        loaded_files.append(image_file)
    results = text_sys.process_images(images)
    print(f'time cost: {time.time() - t_start}')

    # Save results and optionally visualize them.
    save_results = []
    for image_id, res in results.items():
        image_file = loaded_files[image_id]
        save_results.append(
            f'{os.path.basename(image_file)}\t{json.dumps(res, ensure_ascii=False)}\n'
        )
        if not is_visualize:
            continue
        dt_boxes = [result['points'] for result in res]
        rec_res = [result['transcription'] for result in res]
        rec_score = [result['score'] for result in res]
        image = Image.fromarray(
            cv2.cvtColor(images[image_id][0], cv2.COLOR_BGR2RGB))
        draw_img = draw_ocr_box_txt(image,
                                    dt_boxes,
                                    rec_res,
                                    rec_score,
                                    drop_score=drop_score,
                                    font_path=font_path)
        save_file = os.path.join(draw_img_save_dir,
                                 os.path.basename(image_file))
        # Reverse channel order for cv2.imwrite; assumes draw_ocr_box_txt
        # returns an RGB array (it was given an RGB image).
        cv2.imwrite(save_file, draw_img[:, :, ::-1])

    with open(os.path.join(draw_img_save_dir, 'system_results.txt'),
              'w',
              encoding='utf-8') as f:
        f.writelines(save_results)
if __name__ == '__main__':
    # Script entry point: run the demo pipeline.
    main()