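# TheEditor app.py — Gradio app that detects text regions in an image, erases them,
# translates the text, and re-inserts the translation.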
import cv2
import gradio as gr
import os
from edit_func import *
from TransUnet import Trans_UNet
import TransUnet_Config as config2
from huggingface_hub import hf_hub_download
from googletrans import Translator
import random
# numpy and torch are used directly below, so import them explicitly
# instead of relying on the star import from edit_func.
import numpy as np
import torch
import torch.nn as nn
import spaces
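# Text-detection model: an ensemble of Trans_UNet checkpoints whose outputs are averaged.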
@spaces.GPU
class DTM(nn.Module):
def __init__(self):
super().__init__()
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.detect_text_model = Trans_UNet(
config2.in_channels, config2.adp_channels, config2.out_channels,
config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
config2.dropout
).to(self.device)
self.repo_name = 'SS3M/detect-text-model'
files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt',
'detect-text-v3-2.pt', 'detect-text-v3-3.pt',
'detect-text-v3-4.pt', 'detect-text-v3-5.pt',
'detect-text-v3-6.pt', 'detect-text-v3-7.pt']
# Download each checkpoint once and keep the local file paths.
self.files = [hf_hub_download(repo_id=self.repo_name, filename=file) for file in files]
def forward(self, X):
    X = X.to(self.device)
    N, C, H, W = X.shape
    # Accumulate the ensemble prediction on the CPU.
    result = torch.zeros((N, 1, H, W))
    # Load each checkpoint in turn, run it, and average the outputs.
    for model_path in self.files:
        best_model_state = torch.load(
            model_path,
            weights_only=True,
            map_location=self.device
        )
        self.detect_text_model.load_state_dict(best_model_state)
        # Move the prediction back to the CPU accumulator so devices always match.
        result += self.detect_text_model(X).cpu()
    result /= len(self.files)
    return result
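# Word-balloon (speech bubble) detection model: same ensemble structure as DTM,
# with checkpoints pulled from the detect-wordball-model repository.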
@spaces.GPU
class DWBM(nn.Module):
def __init__(self):
super().__init__()
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.detect_wordball_model = Trans_UNet(
config2.in_channels, config2.adp_channels, config2.out_channels,
config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
config2.dropout
).to(self.device)
self.repo_name = 'SS3M/detect-wordball-model'
files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt',
'detect-text-v3-2.pt', 'detect-text-v3-3.pt',
'detect-text-v3-4.pt', 'detect-text-v3-5.pt',
'detect-text-v3-6.pt', 'detect-text-v3-7.pt']
# Download each checkpoint once and keep the local file paths.
self.files = [hf_hub_download(repo_id=self.repo_name, filename=file) for file in files]
def forward(self, X):
    X = X.to(self.device)
    N, C, H, W = X.shape
    # Accumulate the ensemble prediction on the CPU.
    result = torch.zeros((N, 1, H, W))
    # Load each checkpoint in turn, run it, and average the outputs.
    for model_path in self.files:
        best_model_state = torch.load(
            model_path,
            weights_only=True,
            map_location=self.device
        )
        self.detect_wordball_model.load_state_dict(best_model_state)
        # Move the prediction back to the CPU accumulator so devices always match.
        result += self.detect_wordball_model(X).cpu()
    result /= len(self.files)
    return result
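# Instantiate the detection ensembles and the translator once at startup.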
detect_text_model = DTM()
detect_wordball_model = DWBM()
translator = Translator()
def down1(src_img):
    # Step 1: detect text and word balloons, then report box IDs, positions, areas and colours.
    src_img = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
    text_msk = create_text_mask(src_img, detect_text_model)
    wordball_msk = create_wordball_mask(src_img, detect_wordball_model)
    text_positions, areas = get_text_positions(text_msk, text_value=0)
    # One random outline colour per detected box.
    rgbs = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            for _ in range(len(areas))]
    # Serialise everything as '; '-separated strings so it can be stored in hidden Textboxes.
    idx = '; '.join(str(i) for i in range(len(areas)))
    text_positions = '; '.join([', '.join(str(i) for i in pos) for pos in text_positions])
    areas = '; '.join(str(i) for i in areas)
    rgbs = '; '.join([', '.join(str(i) for i in rgb) for rgb in rgbs])
    return text_msk * 255, wordball_msk * 255, idx, text_positions, areas, rgbs, 'Done'
def idx_txt_change(src_img, idx_txt, pos_txt, rgb_txt):
    # Redraw the source image with a coloured rectangle around every selected box.
    try:
        src_img2 = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
        text_positions = [tuple(int(i) for i in pos.split(', '))
                          for pos in pos_txt.split('; ')]
        rgbs = [tuple(int(i) for i in rgb.split(', '))
                for rgb in rgb_txt.split('; ')]
        idxes = [int(idx) for idx in idx_txt.split('; ')]
        for idx, ((min_x, min_y, max_x, max_y), (r, g, b)) in enumerate(zip(text_positions, rgbs)):
            if idx in idxes:
                cv2.rectangle(src_img2, (min_x, min_y), (max_x, max_y), (b, g, r), thickness=4)
        return cv2.cvtColor(src_img2, cv2.COLOR_BGR2RGB)
    except Exception:
        # Fall back to the unmodified image if the state strings cannot be parsed yet.
        return src_img
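# Return the '; '-joined IDs of the boxes whose area lies within [min_area, max_area].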
def scale_area_change(min_area, max_area, area_txt):
areas = [int(area) for area in area_txt.split('; ')]
idxes = []
for idx, area in enumerate(areas):
if min_area <= area <= max_area:
idxes.append(idx)
idxes = '; '.join(str(i) for i in idxes)
return idxes
def position_block_change(X, Y, W, H, ID, pos_txt_value):
    # Rewrite the ID-th bounding box from the X/Y/W/H number inputs.
    text_positions = [tuple(int(i) for i in pos.split(', '))
                      for pos in pos_txt_value.split('; ')]
    text_positions2 = []
    for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
        if idx == ID:
            # gr.Number may deliver floats, so cast before serialising.
            text_positions2.append((int(X), int(Y), int(X + W), int(Y + H)))
        else:
            text_positions2.append((min_x, min_y, max_x, max_y))
    return '; '.join([', '.join(str(i) for i in pos) for pos in text_positions2])
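# Add or remove a box ID from the '; '-separated selection string when its checkbox is toggled.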
def ID_block_change(ID_value, checkbox_value, ID_txt_value):
ID_txt_value = [int(i) for i in ID_txt_value.split('; ')]
if checkbox_value and ID_value not in ID_txt_value:
ID_txt_value.append(ID_value)
if not checkbox_value and ID_value in ID_txt_value:
ID_txt_value.remove(ID_value)
ID_txt_value = sorted(ID_txt_value)
ID_txt_value = '; '.join([str(i) for i in ID_txt_value])
return ID_txt_value
def down2(src_img_value, txt_mask_value, wordball_mask_value, idx_txt_value, pos_txt_value):
    # Step 2: erase the selected text regions, translate their contents, and set rendering defaults.
    src_img_value = cv2.cvtColor(src_img_value, cv2.COLOR_RGB2BGR)
    text_positions = [tuple(int(i) for i in pos.split(', '))
                      for pos in pos_txt_value.split('; ')]
    idxes = [int(i) for i in idx_txt_value.split('; ')]
    # Mark every unselected box as non-text so clear_text leaves it untouched.
    for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
        if idx not in idxes:
            txt_mask_value[min_y:max_y+1, min_x:max_x+1] = 255
    txt_mask_value = txt_mask_value[:, :, 0].astype(np.uint8)
    non_text_src_img = clear_text(src_img_value, txt_mask_value, wordball_mask_value, text_value=0, non_text_value=255, r=5)
    selected_positions = [pos for idx, pos in enumerate(text_positions) if idx in idxes]
    list_texts = get_list_texts(src_img_value, selected_positions)
    list_translated_texts = translate(list_texts, translator)
    # Default font, size, stroke and padding for every translated block.
    n = len(list_translated_texts)
    list_fonts = '; '.join(['MTO Astro City.ttf'] * n)
    list_sizes = '; '.join(['20'] * n)
    list_strokes = '; '.join(['3'] * n)
    list_pads = '; '.join(['5'] * n)
    list_translated_texts = '; '.join(list_translated_texts)
    # A new random value forces the @gr.render block for the text controls to rebuild.
    switch = str(random.random())
    return non_text_src_img, list_translated_texts, list_fonts, list_sizes, list_strokes, list_pads, switch, 'Done'
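# Re-render the translated text onto the cleaned image using the current per-box settings.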
def text_info_change(non_txt_img_value, translated_txt_value, pos_txt_value, idx_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
non_txt_img_value = non_txt_img_value.copy()
idxes = [int(i) for i in idx_txt_value.split('; ')]
translated_text_src_img = insert_text(non_txt_img_value,
translated_txt_value.split('; '),
[tuple(map(int, pos.split(', '))) for idx, pos in enumerate(pos_txt_value.split('; ')) if idx in idxes],
font=font_txt_value.split('; '),
font_size=[int(i) for i in size_txt_value.split('; ')],
pad=[int(i) for i in pad_txt_value.split('; ')],
stroke=[int(i) for i in stroke_txt_value.split('; ')])
return translated_text_src_img
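# Replace one entry of a '; '-separated settings string (translation, font, size, stroke or pad).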
def value2_change(value, ID2_value, txt_value):
    items = txt_value.split('; ')
    # gr.Number may deliver the ID as a float, so cast it before indexing.
    items[int(ID2_value)] = str(value)
    return '; '.join(items)
# Build the Gradio interface
with gr.Blocks() as demo:
# Layout
src_img = gr.Image(type="numpy", label="Upload Image")
down_bttn_1 = gr.Button("↓", elem_classes="arrow-button")
with gr.Row():
txt_mask = gr.Image(type="numpy", label="Text mask", visible=True)
wordball_mask = gr.Image(type="numpy", label="Word-balloon mask", visible=True)
complete = gr.Textbox()
with gr.Row():
idx_txt = gr.Textbox(label='ID', interactive=False, visible=False)
pos_txt = gr.Textbox(label='Pos', interactive=False, visible=False)
area_txt = gr.Textbox(label='Area', interactive=False, visible=False)
rgb_txt = gr.Textbox(label='rgb', interactive=False, visible=False)
with gr.Row():
boxed_txt_img = gr.Image(type="numpy", label="Detected text boxes")
with gr.Column() as down_1_column:
@gr.render(inputs=[pos_txt, rgb_txt], triggers=[rgb_txt.change])
def create_box(pos_txt_value, rgb_txt_value):
    # Build one group of position controls per detected text box.
    text_positions = [tuple(int(i) for i in pos.split(', '))
                      for pos in pos_txt_value.split('; ')]
    rgbs = [tuple(int(i) for i in rgb.split(', '))
            for rgb in rgb_txt_value.split('; ')]
elements = []
for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
with gr.Group() as box:
r, g, b = rgbs[idx]
with gr.Row():
gr.Markdown(
f"""
<div style="margin-left: 20px; display: flex; align-items: center;">
<div style="width: 10px; height: 10px; background-color: rgb({r}, {g}, {b}); margin-right: 5px;"></div>
<span style="font-size: 20px;">Textbox {idx+1}</span>
</div>
"""
)
checkbox = gr.Checkbox(value=True, label='', min_width=50, interactive=True)
with gr.Row():
X = gr.Number(label="X", value=min_x, interactive=True)
Y = gr.Number(label="Y", value=min_y, interactive=True)
W = gr.Number(label="W", value=max_x-min_x, interactive=True)
H = gr.Number(label="H", value=max_y-min_y, interactive=True)
ID = gr.Number(label="ID", value=idx, interactive=True, visible=False)
elements.append((X, Y, W, H, ID))
checkbox.change(
fn=ID_block_change,
inputs=[ID, checkbox, idx_txt],
outputs=idx_txt,
show_progress=False
).then(
fn=idx_txt_change,
inputs=[src_img, idx_txt, pos_txt, rgb_txt],
outputs=boxed_txt_img,
)
# X, Y, W, H all trigger the same update: rewrite pos_txt, then redraw the boxes.
for comp in (X, Y, W, H):
    comp.change(
        fn=position_block_change,
        inputs=[X, Y, W, H, ID, pos_txt],
        outputs=pos_txt,
        show_progress=False
    ).then(
        fn=idx_txt_change,
        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
        outputs=boxed_txt_img,
        show_progress=False
    )
down_bttn_2 = gr.Button("↓", elem_classes="arrow-button")
non_txt_img = gr.Image(type="numpy", label="Text-free image", visible=False)
complete2 = gr.Textbox()
with gr.Row():
translated_txt = gr.Textbox(label='translated', interactive=False, visible=False)
font_txt = gr.Textbox(label='font', interactive=False, visible=False)
size_txt = gr.Textbox(label='size', interactive=False, visible=False)
stroke_txt = gr.Textbox(label='stroke', interactive=False, visible=False)
pad_txt = gr.Textbox(label='pad', interactive=False, visible=False)
switch_txt = gr.Textbox(label='switch', value='1', interactive=False, visible=False)
with gr.Row():
boxed_inserted_non_txt_img = gr.Image(type="numpy", label="Translated result")
with gr.Column():
@gr.render(inputs=[translated_txt, font_txt, size_txt, stroke_txt, pad_txt], triggers=[switch_txt.change])
def create_box2(translated_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
translated_txt_value = translated_txt_value.split('; ')
font_txt_value = font_txt_value.split('; ')
size_txt_value = size_txt_value.split('; ')
stroke_txt_value = stroke_txt_value.split('; ')
pad_txt_value = pad_txt_value.split('; ')
elements = []
for idx in range(len(font_txt_value)):
with gr.Group():
gr.Markdown(
f"""
<div style="margin-left: 20px; display: flex; align-items: center;">
<div style="width: 10px; height: 10px; background-color: rgb(255, 255, 255); margin-right: 5px;"></div>
<span style="font-size: 20px;">Textbox {idx+1}</span>
</div>
"""
)
translated_text_box = gr.Textbox(label="Translate", value=translated_txt_value[idx], interactive=True)
with gr.Row():
font = gr.Dropdown(choices=os.listdir('MTO Font'), label="Font", value=font_txt_value[idx], interactive=True, scale=7)
size = gr.Number(label="Size", value=int(size_txt_value[idx]), interactive=True, minimum=1)
stroke = gr.Number(label="Stroke", value=int(stroke_txt_value[idx]), interactive=True, minimum=0, maximum=5)
pad = gr.Number(label="Pad", value=int(pad_txt_value[idx]), interactive=True, minimum=1, maximum=10)
ID2 = gr.Number(label="ID", value=int(idx), interactive=True, visible=False)
translated_text_box.submit(
fn=value2_change,
inputs=[translated_text_box, ID2, translated_txt],
outputs=translated_txt,
show_progress=False
).then(
fn=text_info_change,
inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
outputs=boxed_inserted_non_txt_img,
)
# font, size, stroke and pad each update their own settings string,
# then the translated image is re-rendered.
for comp, state_txt in ((font, font_txt), (size, size_txt),
                        (stroke, stroke_txt), (pad, pad_txt)):
    comp.change(
        fn=value2_change,
        inputs=[comp, ID2, state_txt],
        outputs=state_txt,
        show_progress=False
    ).then(
        fn=text_info_change,
        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
        outputs=boxed_inserted_non_txt_img,
    )
# CSS
demo.css = """
.arrow-button {
font-size: 40px; /* Font size */
}
.group-elem {
height: 70px;
}
"""
# Event wiring
down_bttn_1.click(
fn=down1,
inputs=src_img,
outputs=[txt_mask, wordball_mask, idx_txt, pos_txt, area_txt, rgb_txt, complete],
)
down_bttn_2.click(
fn=down2,
inputs=[src_img, txt_mask, wordball_mask, idx_txt, pos_txt],
outputs=[non_txt_img, translated_txt, font_txt, size_txt, stroke_txt, pad_txt, switch_txt, complete2],
).then(
fn=text_info_change,
inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
outputs=boxed_inserted_non_txt_img,
)
demo.launch()