import cv2 import gradio as gr import os from edit_func import * from TransUnet import Trans_UNet import TransUnet_Config as config2 from huggingface_hub import hf_hub_download from googletrans import Translator import random import torch.nn as nn import spaces @spaces.GPU class DTM(nn.Module): def __init__(self): super().__init__() self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.detect_text_model = Trans_UNet( config2.in_channels, config2.adp_channels, config2.out_channels, config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels, config2.dropout ).to(self.device) self.repo_name = 'SS3M/detect-text-model' files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt', 'detect-text-v3-2.pt', 'detect-text-v3-3.pt', 'detect-text-v3-4.pt', 'detect-text-v3-5.pt', 'detect-text-v3-6.pt', 'detect-text-v3-7.pt'] self.files = [] for file in files: self.files.append(hf_hub_download(repo_id=self.repo_name, filename=file)) def forward(self, X): X = X.to(self.device) N, C, H, W = X.shape result = torch.zeros((N, 1, H, W)) for file in self.files: model_path = file best_model_state = torch.load( model_path, weights_only=True, map_location=self.device ) self.detect_text_model.load_state_dict(best_model_state) result += self.detect_text_model(X) result /= len(self.files) return result @spaces.GPU class DWBM(nn.Module): def __init__(self): super().__init__() self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.detect_wordball_model = Trans_UNet( config2.in_channels, config2.adp_channels, config2.out_channels, config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels, config2.dropout ).to(self.device) self.repo_name = 'SS3M/detect-wordball-model' files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt', 'detect-text-v3-2.pt', 'detect-text-v3-3.pt', 'detect-text-v3-4.pt', 'detect-text-v3-5.pt', 'detect-text-v3-6.pt', 'detect-text-v3-7.pt'] self.files = [] for file in files: self.files.append(hf_hub_download(repo_id=self.repo_name, filename=file)) def forward(self, X): X = X.to(self.device) N, C, H, W = X.shape result = torch.zeros((N, 1, H, W)) for file in self.files: model_path = file best_model_state = torch.load( model_path, weights_only=True, map_location=self.device ) self.detect_wordball_model.load_state_dict(best_model_state) result += self.detect_wordball_model(X) result /= len(self.files) return result detect_text_model = DTM() detect_wordball_model = DWBM() translator = Translator() def down1(src_img): src_img = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR) text_msk = create_text_mask(src_img, detect_text_model) wordball_msk = create_wordball_mask(src_img, detect_wordball_model) text_positions, areas = get_text_positions(text_msk, text_value=0) rgbs = [] for _ in range(len(areas)): rgbs.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))) idx = '; '.join(str(i) for i in range(len(areas))) text_positions = '; '.join([', '.join(str(i) for i in pos) for pos in text_positions]) areas = '; '.join(str(i) for i in areas) rgbs = '; '.join([', '.join(str(i) for i in rgb) for rgb in rgbs]) src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) return text_msk*255, wordball_msk*255, idx, text_positions, areas, rgbs, 'Xong' def idx_txt_change(src_img, idx_txt, pos_txt, rgb_txt): try: src_img2 = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR) text_positions = pos_txt.split('; ') for idx in range(len(text_positions)): text_positions[idx] = (int(i) for i in text_positions[idx].split(', ')) rgbs = rgb_txt.split('; ') for idx in range(len(rgbs)): rgbs[idx] = (int(i) for i in rgbs[idx].split(', ')) idxes = [int(idx) for idx in idx_txt.split('; ')] for idx, ((min_x, min_y, max_x, max_y), (r, g, b)) in enumerate(zip(text_positions, rgbs)): if idx in idxes: cv2.rectangle(src_img2, (min_x, min_y), (max_x, max_y), (b, g, r), thickness=4) src_img2 = cv2.cvtColor(src_img2, cv2.COLOR_BGR2RGB) return src_img2 except: return src_img def scale_area_change(min_area, max_area, area_txt): areas = [int(area) for area in area_txt.split('; ')] idxes = [] for idx, area in enumerate(areas): if min_area <= area <= max_area: idxes.append(idx) idxes = '; '.join(str(i) for i in idxes) return idxes def position_block_change(X, Y, W, H, ID, pos_txt_value): text_positions = pos_txt_value.split('; ') for idx in range(len(text_positions)): text_positions[idx] = (int(i) for i in text_positions[idx].split(', ')) text_positions2 = [] for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions): if idx == ID: text_positions2.append((X, Y, X+W, Y+H)) else: text_positions2.append((min_x, min_y, max_x, max_y)) text_positions2 = '; '.join([', '.join(str(i) for i in pos) for pos in text_positions2]) return text_positions2 def ID_block_change(ID_value, checkbox_value, ID_txt_value): ID_txt_value = [int(i) for i in ID_txt_value.split('; ')] if checkbox_value and ID_value not in ID_txt_value: ID_txt_value.append(ID_value) if not checkbox_value and ID_value in ID_txt_value: ID_txt_value.remove(ID_value) ID_txt_value = sorted(ID_txt_value) ID_txt_value = '; '.join([str(i) for i in ID_txt_value]) return ID_txt_value def down2(src_img_value, txt_mask_value, wordball_mask_value, idx_txt_value, pos_txt_value): src_img_value = cv2.cvtColor(src_img_value, cv2.COLOR_RGB2BGR) text_positions = pos_txt_value.split('; ') for idx in range(len(text_positions)): text_positions[idx] = (int(i) for i in text_positions[idx].split(', ')) idxes = [int(i) for i in idx_txt_value.split('; ')] for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions): if idx not in idxes: txt_mask_value[min_y:max_y+1, min_x:max_x+1] = 255 txt_mask_value = txt_mask_value[:, :, 0].astype(np.uint8) non_text_src_img = clear_text(src_img_value, txt_mask_value, wordball_mask_value, text_value=0, non_text_value=255, r=5) list_texts = get_list_texts(src_img_value, [tuple(map(int, pos.split(', '))) for idx, pos in enumerate(pos_txt_value.split('; ')) if idx in idxes]) list_translated_texts = translate(list_texts, translator) list_fonts = '; '.join(['MTO Astro City.ttf' for _ in range(len(list_translated_texts))]) list_sizes = '; '.join(['20' for _ in range(len(list_translated_texts))]) list_strokes = '; '.join(['3' for _ in range(len(list_translated_texts))]) list_pads = '; '.join(['5' for _ in range(len(list_translated_texts))]) list_translated_texts = '; '.join(list_translated_texts) switch = str(random.random()) return non_text_src_img, list_translated_texts, list_fonts, list_sizes, list_strokes, list_pads, switch, 'Xong' def text_info_change(non_txt_img_value, translated_txt_value, pos_txt_value, idx_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value): non_txt_img_value = non_txt_img_value.copy() idxes = [int(i) for i in idx_txt_value.split('; ')] translated_text_src_img = insert_text(non_txt_img_value, translated_txt_value.split('; '), [tuple(map(int, pos.split(', '))) for idx, pos in enumerate(pos_txt_value.split('; ')) if idx in idxes], font=font_txt_value.split('; '), font_size=[int(i) for i in size_txt_value.split('; ')], pad=[int(i) for i in pad_txt_value.split('; ')], stroke=[int(i) for i in stroke_txt_value.split('; ')]) return translated_text_src_img def value2_change(value, ID2_value, txt_value): txt_value = txt_value.split('; ') txt_value2 = [] for idx, text in enumerate(txt_value): if idx == ID2_value: txt_value2.append(str(value)) else: txt_value2.append(str(text)) txt_value2 = '; '.join(txt_value2) return txt_value2 # Tạo giao diện Gradio with gr.Blocks() as demo: # Cấu trúc src_img = gr.Image(type="numpy", label="Upload Image") down_bttn_1 = gr.Button("↓", elem_classes="arrow-button") with gr.Row(): txt_mask = gr.Image(type="numpy", label="Upload Image", visible=True) wordball_mask = gr.Image(type="numpy", label="Upload Image", visible=True) complete = gr.Textbox() with gr.Row(): idx_txt = gr.Textbox(label='ID', interactive=False, visible=False) pos_txt = gr.Textbox(label='Pos', interactive=False, visible=False) area_txt = gr.Textbox(label='Area', interactive=False, visible=False) rgb_txt = gr.Textbox(label='rgb', interactive=False, visible=False) with gr.Row(): boxed_txt_img = gr.Image(type="numpy", label="Upload Image") with gr.Column() as down_1_column: @gr.render(inputs=[pos_txt, rgb_txt], triggers=[rgb_txt.change]) def create_box(pos_txt_value, rgb_txt_value): text_positions = pos_txt_value.split('; ') for idx in range(len(text_positions)): text_positions[idx] = (int(i) for i in text_positions[idx].split(', ')) rgbs = rgb_txt_value.split('; ') for idx in range(len(rgbs)): rgbs[idx] = (int(i) for i in rgbs[idx].split(', ')) elements = [] for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions): with gr.Group() as box: r, g, b = rgbs[idx] with gr.Row(): gr.Markdown( f"""