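"""Parallel end-to-end OCR inference.

Text detection runs in the main thread and pushes cropped text regions into a
queue; one or more worker threads consume the queue and run text recognition
concurrently.
"""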
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import queue
import os
import sys
import time
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
import numpy as np
import cv2
import json
from PIL import Image
from tools.utils.utility import get_image_file_list, check_and_read
from tools.infer_rec import OpenRecognizer
from tools.infer_det import OpenDetector
from tools.infer_e2e import check_and_download_font, sorted_boxes
from tools.engine import Config
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt
class OpenOCRParallel:
def __init__(self, drop_score=0.5, det_box_type='quad', max_rec_threads=1):
cfg_det = Config(
'./configs/det/dbnet/repvit_db.yml').cfg # mobile model
# cfg_rec = Config('./configs/rec/svtrv2/svtrv2_ch.yml').cfg # server model
cfg_rec = Config(
'./configs/rec/svtrv2/repsvtr_ch.yml').cfg # mobile model
self.text_detector = OpenDetector(cfg_det, numId=0)
self.text_recognizer = OpenRecognizer(cfg_rec, numId=0)
self.det_box_type = det_box_type
self.drop_score = drop_score
        self.queue = queue.Queue()  # Queue of detected boxes awaiting recognition
self.results = {}
self.lock = threading.Lock() # Lock for thread-safe access to results
self.max_rec_threads = max_rec_threads
self.stop_signal = threading.Event() # Signal to stop threads
def start_recognition_threads(self):
"""Start recognition threads."""
self.rec_threads = []
for _ in range(self.max_rec_threads):
t = threading.Thread(target=self.recognize_text)
t.start()
self.rec_threads.append(t)
def detect_text(self, image_list):
"""Single-threaded text detection for all images."""
for image_id, (img_numpy, ori_img) in enumerate(image_list):
dt_boxes = self.text_detector(img_numpy=img_numpy)[0]['boxes']
if dt_boxes is None:
self.results[image_id] = [] # If no boxes, set empty results
continue
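            # Order boxes top-to-bottom, then left-to-right so results follow reading order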
dt_boxes = sorted_boxes(dt_boxes)
img_crop_list = []
for box in dt_boxes:
tmp_box = np.array(box).astype(np.float32)
img_crop = (get_rotate_crop_image(ori_img, tmp_box)
if self.det_box_type == 'quad' else
get_minarea_rect_crop(ori_img, tmp_box))
img_crop_list.append(img_crop)
            # Enqueue the image ID, detected boxes, and cropped regions for recognition
            self.queue.put((image_id, dt_boxes, img_crop_list))
# Signal that no more items will be added to the queue
self.stop_signal.set()
def recognize_text(self):
"""Recognize text in each cropped image."""
while not self.stop_signal.is_set() or not self.queue.empty():
try:
                image_id, boxes, img_crop_list = self.queue.get(timeout=0.5)
                rec_results = self.text_recognizer(
                    img_numpy_list=img_crop_list, batch_num=6)
                for rec_result, box in zip(rec_results, boxes):
text, score = rec_result['text'], rec_result['score']
if score >= self.drop_score:
with self.lock:
# Ensure results dictionary has a list for each image ID
if image_id not in self.results:
self.results[image_id] = []
                            self.results[image_id].append({
                                'transcription': text,
                                'points': box.tolist(),
                                'score': score
                            })
self.queue.task_done()
except queue.Empty:
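                # Queue temporarily empty; loop back and re-check the stop signal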
continue
def process_images(self, image_list):
"""Process a list of images."""
# Initialize results dictionary
self.results = {i: [] for i in range(len(image_list))}
# Start recognition threads
t_start_1 = time.time()
self.start_recognition_threads()
# Start detection in the main thread
t_start = time.time()
self.detect_text(image_list)
print('det time:', time.time() - t_start)
# Wait for recognition threads to finish
for t in self.rec_threads:
t.join()
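        # Reset the stop signal so the instance can process another batch later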
self.stop_signal.clear()
print('all time:', time.time() - t_start_1)
return self.results
def main():
img_path = './testA/'
image_file_list = get_image_file_list(img_path)
drop_score = 0.5
text_sys = OpenOCRParallel(
drop_score=drop_score,
det_box_type='quad') # det_box_type: 'quad' or 'poly'
    is_visualize = False
    # The output directory is needed for system_results.txt below even when
    # visualization is disabled, so create it unconditionally.
    draw_img_save_dir = os.path.join(img_path, 'e2e_results')
    os.makedirs(draw_img_save_dir, exist_ok=True)
    if is_visualize:
        font_path = './simfang.ttf'
        check_and_download_font(font_path)
save_results = []
# Prepare images
images = []
t_start = time.time()
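    # Note: files that fail to load are skipped, which would shift the
    # image_id -> file mapping used when saving results below.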
for image_file in image_file_list:
img, flag_gif, flag_pdf = check_and_read(image_file)
if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
if img is not None:
images.append((img, img.copy()))
results = text_sys.process_images(images)
print(f'time cost: {time.time() - t_start}')
# Save results and visualize
for image_id, res in results.items():
image_file = image_file_list[image_id]
save_pred = f'{os.path.basename(image_file)}\t{json.dumps(res, ensure_ascii=False)}\n'
# print(save_pred)
save_results.append(save_pred)
if is_visualize:
dt_boxes = [result['points'] for result in res]
rec_res = [result['transcription'] for result in res]
rec_score = [result['score'] for result in res]
image = Image.fromarray(
cv2.cvtColor(images[image_id][0], cv2.COLOR_BGR2RGB))
draw_img = draw_ocr_box_txt(image,
dt_boxes,
rec_res,
rec_score,
drop_score=drop_score,
font_path=font_path)
save_file = os.path.join(draw_img_save_dir,
os.path.basename(image_file))
cv2.imwrite(save_file, draw_img[:, :, ::-1])
with open(os.path.join(draw_img_save_dir, 'system_results.txt'),
'w',
encoding='utf-8') as f:
f.writelines(save_results)
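# Note: the config, font, and image paths above are relative to the working directory.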
if __name__ == '__main__':
main()