import cv2
import gradio as gr
import os

from edit_func import *
from TransUnet import Trans_UNet
import TransUnet_Config as config2
from huggingface_hub import hf_hub_download
from googletrans import Translator

import random
import numpy as np
import torch
import torch.nn as nn
import spaces


@spaces.GPU
class DTM(nn.Module):
    """Ensemble of TransUNet checkpoints that predicts a text mask."""

    def __init__(self):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.detect_text_model = Trans_UNet(
            config2.in_channels,
            config2.adp_channels,
            config2.out_channels,
            config2.trans_num_layers,
            config2.trans_num_attn_heads,
            config2.trans_ffw_channels,
            config2.dropout
        ).to(self.device)

        # Download every checkpoint of the ensemble from the Hub.
        self.repo_name = 'SS3M/detect-text-model'
        files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt', 'detect-text-v3-2.pt', 'detect-text-v3-3.pt',
                 'detect-text-v3-4.pt', 'detect-text-v3-5.pt', 'detect-text-v3-6.pt', 'detect-text-v3-7.pt']
        self.files = []
        for file in files:
            self.files.append(hf_hub_download(repo_id=self.repo_name, filename=file))

    def forward(self, X):
        X = X.to(self.device)
        N, C, H, W = X.shape
        # Average the predictions of all checkpoints.
        result = torch.zeros((N, 1, H, W), device=self.device)
        for file in self.files:
            best_model_state = torch.load(file, weights_only=True, map_location=self.device)
            self.detect_text_model.load_state_dict(best_model_state)
            result += self.detect_text_model(X)
        result /= len(self.files)
        return result


@spaces.GPU
class DWBM(nn.Module):
    """Ensemble of TransUNet checkpoints that predicts a speech-bubble (wordball) mask."""

    def __init__(self):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.detect_wordball_model = Trans_UNet(
            config2.in_channels,
            config2.adp_channels,
            config2.out_channels,
            config2.trans_num_layers,
            config2.trans_num_attn_heads,
            config2.trans_ffw_channels,
            config2.dropout
        ).to(self.device)

        self.repo_name = 'SS3M/detect-wordball-model'
        files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt', 'detect-text-v3-2.pt', 'detect-text-v3-3.pt',
                 'detect-text-v3-4.pt', 'detect-text-v3-5.pt', 'detect-text-v3-6.pt', 'detect-text-v3-7.pt']
        self.files = []
        for file in files:
            self.files.append(hf_hub_download(repo_id=self.repo_name, filename=file))

    def forward(self, X):
        X = X.to(self.device)
        N, C, H, W = X.shape
        result = torch.zeros((N, 1, H, W), device=self.device)
        for file in self.files:
            best_model_state = torch.load(file, weights_only=True, map_location=self.device)
            self.detect_wordball_model.load_state_dict(best_model_state)
            result += self.detect_wordball_model(X)
        result /= len(self.files)
        return result


detect_text_model = DTM()
detect_wordball_model = DWBM()
translator = Translator()


def down1(src_img):
    """Step 1: detect text and speech bubbles, and build the per-box metadata strings."""
    src_img = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
    text_msk = create_text_mask(src_img, detect_text_model)
    wordball_msk = create_wordball_mask(src_img, detect_wordball_model)
    text_positions, areas = get_text_positions(text_msk, text_value=0)

    # One random colour per detected box, used when drawing its rectangle.
    rgbs = []
    for _ in range(len(areas)):
        rgbs.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))

    # All metadata travels between components as '; '-separated strings.
    idx = '; '.join(str(i) for i in range(len(areas)))
    text_positions = '; '.join([', '.join(str(i) for i in pos) for pos in text_positions])
    areas = '; '.join(str(i) for i in areas)
    rgbs = '; '.join([', '.join(str(i) for i in rgb) for rgb in rgbs])

    src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
    return text_msk*255, wordball_msk*255, idx, text_positions, areas, rgbs, 'Done'


def idx_txt_change(src_img, idx_txt, pos_txt, rgb_txt):
    """Redraw the rectangles of the currently selected box IDs on top of the source image."""
    try:
        src_img2 = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
        text_positions = pos_txt.split('; ')
        for idx in range(len(text_positions)):
            text_positions[idx] = tuple(int(i) for i in text_positions[idx].split(', '))
        rgbs = rgb_txt.split('; ')
        for idx in range(len(rgbs)):
            rgbs[idx] = tuple(int(i) for i in rgbs[idx].split(', '))
        idxes = [int(idx) for idx in idx_txt.split('; ')]
        for idx, ((min_x, min_y, max_x, max_y), (r, g, b)) in enumerate(zip(text_positions, rgbs)):
            if idx in idxes:
                cv2.rectangle(src_img2, (min_x, min_y), (max_x, max_y), (b, g, r), thickness=4)
        src_img2 = cv2.cvtColor(src_img2, cv2.COLOR_BGR2RGB)
        return src_img2
    except Exception:
        # If the state strings cannot be parsed yet, show the unmodified image.
        return src_img


def scale_area_change(min_area, max_area, area_txt):
    """Select only the boxes whose area falls inside [min_area, max_area]."""
    areas = [int(area) for area in area_txt.split('; ')]
    idxes = []
    for idx, area in enumerate(areas):
        if min_area <= area <= max_area:
            idxes.append(idx)
    idxes = '; '.join(str(i) for i in idxes)
    return idxes


def position_block_change(X, Y, W, H, ID, pos_txt_value):
    """Write the edited (X, Y, W, H) of box `ID` back into the position string."""
    text_positions = pos_txt_value.split('; ')
    for idx in range(len(text_positions)):
        text_positions[idx] = tuple(int(i) for i in text_positions[idx].split(', '))
    text_positions2 = []
    for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
        if idx == ID:
            text_positions2.append((X, Y, X+W, Y+H))
        else:
            text_positions2.append((min_x, min_y, max_x, max_y))
    text_positions2 = '; '.join([', '.join(str(i) for i in pos) for pos in text_positions2])
    return text_positions2


def ID_block_change(ID_value, checkbox_value, ID_txt_value):
    """Add or remove a box ID from the list of selected IDs."""
    ID_txt_value = [int(i) for i in ID_txt_value.split('; ')]
    if checkbox_value and ID_value not in ID_txt_value:
        ID_txt_value.append(ID_value)
    if not checkbox_value and ID_value in ID_txt_value:
        ID_txt_value.remove(ID_value)
    ID_txt_value = sorted(ID_txt_value)
    ID_txt_value = '; '.join([str(i) for i in ID_txt_value])
    return ID_txt_value


def down2(src_img_value, txt_mask_value, wordball_mask_value, idx_txt_value, pos_txt_value):
    """Step 2: erase the selected text, translate it, and prepare default styling for re-insertion."""
    src_img_value = cv2.cvtColor(src_img_value, cv2.COLOR_RGB2BGR)
    text_positions = pos_txt_value.split('; ')
    for idx in range(len(text_positions)):
        text_positions[idx] = tuple(int(i) for i in text_positions[idx].split(', '))
    idxes = [int(i) for i in idx_txt_value.split('; ')]

    # Blank out the unselected boxes in the text mask before cleaning.
    for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
        if idx not in idxes:
            txt_mask_value[min_y:max_y+1, min_x:max_x+1] = 255
    txt_mask_value = txt_mask_value[:, :, 0].astype(np.uint8)

    non_text_src_img = clear_text(src_img_value, txt_mask_value, wordball_mask_value,
                                  text_value=0, non_text_value=255, r=5)
    list_texts = get_list_texts(src_img_value,
                                [tuple(map(int, pos.split(', ')))
                                 for idx, pos in enumerate(pos_txt_value.split('; ')) if idx in idxes])
    list_translated_texts = translate(list_texts, translator)

    # Default font / size / stroke / padding for every translated box.
    list_fonts = '; '.join(['MTO Astro City.ttf' for _ in range(len(list_translated_texts))])
    list_sizes = '; '.join(['20' for _ in range(len(list_translated_texts))])
    list_strokes = '; '.join(['3' for _ in range(len(list_translated_texts))])
    list_pads = '; '.join(['5' for _ in range(len(list_translated_texts))])
    list_translated_texts = '; '.join(list_translated_texts)

    # A new random value re-triggers the rendering of the per-box text editors.
    switch = str(random.random())
    return non_text_src_img, list_translated_texts, list_fonts, list_sizes, list_strokes, list_pads, switch, 'Done'


def text_info_change(non_txt_img_value, translated_txt_value, pos_txt_value, idx_txt_value,
                     font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
    """Draw the translated texts onto the cleaned image with the current styling."""
    non_txt_img_value = non_txt_img_value.copy()
    idxes = [int(i) for i in idx_txt_value.split('; ')]
    translated_text_src_img = insert_text(
        non_txt_img_value,
        translated_txt_value.split('; '),
        [tuple(map(int, pos.split(', ')))
         for idx, pos in enumerate(pos_txt_value.split('; ')) if idx in idxes],
        font=font_txt_value.split('; '),
        font_size=[int(i) for i in size_txt_value.split('; ')],
        pad=[int(i) for i in pad_txt_value.split('; ')],
        stroke=[int(i) for i in stroke_txt_value.split('; ')]
    )
    return translated_text_src_img


def value2_change(value, ID2_value, txt_value):
    """Replace element `ID2_value` of a '; '-separated string with `value`."""
    txt_value = txt_value.split('; ')
    txt_value2 = []
    for idx, text in enumerate(txt_value):
        if idx == ID2_value:
            txt_value2.append(str(value))
        else:
            txt_value2.append(str(text))
    txt_value2 = '; '.join(txt_value2)
    return txt_value2


# Build the Gradio interface
with gr.Blocks() as demo:
    # Layout
    src_img = gr.Image(type="numpy", label="Upload Image")
    down_bttn_1 = gr.Button("↓", elem_classes="arrow-button")

    with gr.Row():
        txt_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
        wordball_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
    complete = gr.Textbox()

    with gr.Row():
        idx_txt = gr.Textbox(label='ID', interactive=False, visible=False)
        pos_txt = gr.Textbox(label='Pos', interactive=False, visible=False)
        area_txt = gr.Textbox(label='Area', interactive=False, visible=False)
        rgb_txt = gr.Textbox(label='rgb', interactive=False, visible=False)

    with gr.Row():
        boxed_txt_img = gr.Image(type="numpy", label="Upload Image")

        with gr.Column() as down_1_column:
            @gr.render(inputs=[pos_txt, rgb_txt], triggers=[rgb_txt.change])
            def create_box(pos_txt_value, rgb_txt_value):
                # One editor group (checkbox + position fields) per detected box.
                text_positions = pos_txt_value.split('; ')
                for idx in range(len(text_positions)):
                    text_positions[idx] = tuple(int(i) for i in text_positions[idx].split(', '))
                rgbs = rgb_txt_value.split('; ')
                for idx in range(len(rgbs)):
                    rgbs[idx] = tuple(int(i) for i in rgbs[idx].split(', '))

                elements = []
                for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
                    with gr.Group() as box:
                        r, g, b = rgbs[idx]
                        with gr.Row():
                            gr.Markdown(
                                f"""
Textbox {idx+1}
                                """
                            )
                            checkbox = gr.Checkbox(value=True, label='', min_width=50, interactive=True)

                        with gr.Row():
                            X = gr.Number(label="X", value=min_x, interactive=True)
                            Y = gr.Number(label="Y", value=min_y, interactive=True)
                            W = gr.Number(label="W", value=max_x-min_x, interactive=True)
                            H = gr.Number(label="H", value=max_y-min_y, interactive=True)
                            ID = gr.Number(label="ID", value=idx, interactive=True, visible=False)
                        elements.append((X, Y, W, H, ID))

                        # Every edit first updates the shared state string, then redraws the boxed preview.
                        checkbox.change(
                            fn=ID_block_change,
                            inputs=[ID, checkbox, idx_txt],
                            outputs=idx_txt,
                            show_progress=False
                        ).then(
                            fn=idx_txt_change,
                            inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                            outputs=boxed_txt_img,
                        )
                        X.change(
                            fn=position_block_change,
                            inputs=[X, Y, W, H, ID, pos_txt],
                            outputs=pos_txt,
                            show_progress=False
                        ).then(
                            fn=idx_txt_change,
                            inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                            outputs=boxed_txt_img,
                            show_progress=False
                        )
                        Y.change(
                            fn=position_block_change,
                            inputs=[X, Y, W, H, ID, pos_txt],
                            outputs=pos_txt,
                            show_progress=False
                        ).then(
                            fn=idx_txt_change,
                            inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                            outputs=boxed_txt_img,
                            show_progress=False
                        )
                        W.change(
                            fn=position_block_change,
                            inputs=[X, Y, W, H, ID, pos_txt],
                            outputs=pos_txt,
                            show_progress=False
                        ).then(
                            fn=idx_txt_change,
                            inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                            outputs=boxed_txt_img,
                            show_progress=False
                        )
                        H.change(
                            fn=position_block_change,
                            inputs=[X, Y, W, H, ID, pos_txt],
                            outputs=pos_txt,
                            show_progress=False
                        ).then(
                            fn=idx_txt_change,
                            inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                            outputs=boxed_txt_img,
                            show_progress=False
                        )

    down_bttn_2 = gr.Button("↓", elem_classes="arrow-button")
    non_txt_img = gr.Image(type="numpy", label="Upload Image", visible=False)
    complete2 = gr.Textbox()

    with gr.Row():
        translated_txt = gr.Textbox(label='translated', interactive=False, visible=False)
        font_txt = gr.Textbox(label='font', interactive=False, visible=False)
        size_txt = gr.Textbox(label='size', interactive=False, visible=False)
        stroke_txt = gr.Textbox(label='stroke', interactive=False, visible=False)
        pad_txt = gr.Textbox(label='pad', interactive=False, visible=False)
        switch_txt = gr.Textbox(label='switch', value='1', interactive=False, visible=False)

    with gr.Row():
        boxed_inserted_non_txt_img = gr.Image(type="numpy", label="Upload Image")

        with gr.Column():
            @gr.render(inputs=[translated_txt, font_txt, size_txt, stroke_txt, pad_txt],
                       triggers=[switch_txt.change])
            def create_box2(translated_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
                # One editor group (translation, font, size, stroke, pad) per translated box.
                translated_txt_value = translated_txt_value.split('; ')
                font_txt_value = font_txt_value.split('; ')
                size_txt_value = size_txt_value.split('; ')
                stroke_txt_value = stroke_txt_value.split('; ')
                pad_txt_value = pad_txt_value.split('; ')

                elements = []
                for idx in range(len(font_txt_value)):
                    with gr.Group():
                        gr.Markdown(
                            f"""
Text box {idx}
                            """
                        )
                        translated_text_box = gr.Textbox(label="Translate", value=translated_txt_value[idx], interactive=True)
                        with gr.Row():
                            font = gr.Dropdown(choices=os.listdir('MTO Font'), label="Font",
                                               value=font_txt_value[idx], interactive=True, scale=7)
                            size = gr.Number(label="Size", value=int(size_txt_value[idx]), interactive=True, minimum=1)
                            stroke = gr.Number(label="Stroke", value=int(stroke_txt_value[idx]), interactive=True, minimum=0, maximum=5)
                            pad = gr.Number(label="Pad", value=int(pad_txt_value[idx]), interactive=True, minimum=1, maximum=10)
                            ID2 = gr.Number(label="ID", value=int(idx), interactive=True, visible=False)

                        # Every edit first writes the new value into its state string, then re-renders the text.
                        translated_text_box.submit(
                            fn=value2_change,
                            inputs=[translated_text_box, ID2, translated_txt],
                            outputs=translated_txt,
                            show_progress=False
                        ).then(
                            fn=text_info_change,
                            inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                            outputs=boxed_inserted_non_txt_img,
                        )
                        font.change(
                            fn=value2_change,
                            inputs=[font, ID2, font_txt],
                            outputs=font_txt,
                            show_progress=False
                        ).then(
                            fn=text_info_change,
                            inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                            outputs=boxed_inserted_non_txt_img,
                        )
                        size.change(
                            fn=value2_change,
                            inputs=[size, ID2, size_txt],
                            outputs=size_txt,
                            show_progress=False
                        ).then(
                            fn=text_info_change,
                            inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                            outputs=boxed_inserted_non_txt_img,
                        )
                        stroke.change(
                            fn=value2_change,
                            inputs=[stroke, ID2, stroke_txt],
                            outputs=stroke_txt,
                            show_progress=False
                        ).then(
                            fn=text_info_change,
                            inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                            outputs=boxed_inserted_non_txt_img,
                        )
                        pad.change(
                            fn=value2_change,
                            inputs=[pad, ID2, pad_txt],
                            outputs=pad_txt,
                            show_progress=False
                        ).then(
                            fn=text_info_change,
                            inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                            outputs=boxed_inserted_non_txt_img,
                        )

    # CSS
    demo.css = """
    .arrow-button {
        font-size: 40px;  /* font size */
    }
    .group-elem {
        height: 70px;
    }
    """

    # Controls
    down_bttn_1.click(
        fn=down1,
        inputs=src_img,
        outputs=[txt_mask, wordball_mask, idx_txt, pos_txt, area_txt, rgb_txt, complete],
    )
    down_bttn_2.click(
        fn=down2,
        inputs=[src_img, txt_mask, wordball_mask, idx_txt, pos_txt],
        outputs=[non_txt_img, translated_txt, font_txt, size_txt, stroke_txt, pad_txt, switch_txt, complete2],
    ).then(
        fn=text_info_change,
        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
        outputs=boxed_inserted_non_txt_img,
    )

demo.launch()