|
import cv2
|
|
import gradio as gr
|
|
import os
|
|
from edit_func import *
|
|
from TransUnet import Trans_UNet
|
|
import TransUnet_Config as config2
|
|
from huggingface_hub import hf_hub_download
|
|
from googletrans import Translator
|
|
import random
|
|
import torch.nn as nn
|
|
import spaces
|
|
|
|
@spaces.GPU
class DTM(nn.Module):
    """Text-detection ensemble: one ``Trans_UNet`` evaluated with eight checkpoints.

    ``__init__`` builds the network from ``TransUnet_Config`` (imported as
    ``config2``) and downloads ``detect-text-v3-0.pt`` ... ``detect-text-v3-7.pt``
    from the ``SS3M/detect-text-model`` Hugging Face Hub repository;
    ``self.files`` holds the local paths returned by ``hf_hub_download``.
    ``forward`` re-loads the single network with each checkpoint in turn and
    averages the outputs.
    """

    def __init__(self):
        super().__init__()

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.detect_text_model = Trans_UNet(
            config2.in_channels, config2.adp_channels, config2.out_channels,
            config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
            config2.dropout
        ).to(self.device)
        # NOTE(review): nothing here calls .eval(), so any dropout layers stay
        # in training mode unless the caller switches this module to eval —
        # confirm that create_text_mask does so.

        self.repo_name = 'SS3M/detect-text-model'
        files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt',
                 'detect-text-v3-2.pt', 'detect-text-v3-3.pt',
                 'detect-text-v3-4.pt', 'detect-text-v3-5.pt',
                 'detect-text-v3-6.pt', 'detect-text-v3-7.pt']
        # Local cache paths of the checkpoints, in ensemble order.
        self.files = [hf_hub_download(repo_id=self.repo_name, filename=file) for file in files]

    def forward(self, X):
        """Return the checkpoint-averaged prediction for the batch ``X``.

        Args:
            X: 4-D input tensor ``(N, C, H, W)``; it is moved to ``self.device``.

        Returns:
            A CPU tensor of shape ``(N, 1, H, W)`` holding the mean of the
            network outputs over every checkpoint in ``self.files``.
        """
        X = X.to(self.device)
        N, C, H, W = X.shape

        # The accumulator must live on the same device as the model output:
        # a CPU buffer makes `result += ...` fail whenever the device is CUDA.
        result = torch.zeros((N, 1, H, W), device=self.device)

        # Inference only — no autograd graph has to be kept across the passes.
        with torch.no_grad():
            for model_path in self.files:
                # Only one set of weights is resident: the same network is
                # re-loaded with each checkpoint before its pass.
                best_model_state = torch.load(
                    model_path,
                    weights_only=True,
                    map_location=self.device
                )
                self.detect_text_model.load_state_dict(best_model_state)
                result += self.detect_text_model(X)

        result /= len(self.files)
        # Return on the CPU, as the original CPU-allocated buffer did.
        return result.cpu()
|
|
|
|
@spaces.GPU
class DWBM(nn.Module):
    """Word-balloon-detection ensemble: one ``Trans_UNet`` evaluated with eight checkpoints.

    ``__init__`` builds the network from ``TransUnet_Config`` (imported as
    ``config2``) and downloads eight checkpoints from the
    ``SS3M/detect-wordball-model`` Hugging Face Hub repository; ``self.files``
    holds the local paths returned by ``hf_hub_download``. ``forward`` re-loads
    the single network with each checkpoint in turn and averages the outputs.
    """

    def __init__(self):
        super().__init__()

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.detect_wordball_model = Trans_UNet(
            config2.in_channels, config2.adp_channels, config2.out_channels,
            config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
            config2.dropout
        ).to(self.device)
        # NOTE(review): nothing here calls .eval(), so any dropout layers stay
        # in training mode unless the caller switches this module to eval —
        # confirm that create_wordball_mask does so.

        self.repo_name = 'SS3M/detect-wordball-model'
        # NOTE(review): these are the same 'detect-text-v3-*' filenames DTM
        # uses, fetched from the wordball repository — confirm intended.
        files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt',
                 'detect-text-v3-2.pt', 'detect-text-v3-3.pt',
                 'detect-text-v3-4.pt', 'detect-text-v3-5.pt',
                 'detect-text-v3-6.pt', 'detect-text-v3-7.pt']
        # Local cache paths of the checkpoints, in ensemble order.
        self.files = [hf_hub_download(repo_id=self.repo_name, filename=file) for file in files]

    def forward(self, X):
        """Return the checkpoint-averaged prediction for the batch ``X``.

        Args:
            X: 4-D input tensor ``(N, C, H, W)``; it is moved to ``self.device``.

        Returns:
            A CPU tensor of shape ``(N, 1, H, W)`` holding the mean of the
            network outputs over every checkpoint in ``self.files``.
        """
        X = X.to(self.device)
        N, C, H, W = X.shape

        # The accumulator must live on the same device as the model output:
        # a CPU buffer makes `result += ...` fail whenever the device is CUDA.
        result = torch.zeros((N, 1, H, W), device=self.device)

        # Inference only — no autograd graph has to be kept across the passes.
        with torch.no_grad():
            for model_path in self.files:
                # Only one set of weights is resident: the same network is
                # re-loaded with each checkpoint before its pass.
                best_model_state = torch.load(
                    model_path,
                    weights_only=True,
                    map_location=self.device
                )
                self.detect_wordball_model.load_state_dict(best_model_state)
                result += self.detect_wordball_model(X)

        result /= len(self.files)
        # Return on the CPU, as the original CPU-allocated buffer did.
        return result.cpu()
|
|
|
|
# Module-level singletons shared by the Gradio callbacks below. Both detector
# constructors download their checkpoints from the Hugging Face Hub, so this
# runs once at import time; down1 passes them to the mask helpers.
detect_text_model = DTM()
detect_wordball_model = DWBM()

# Translator instance handed to `translate` in down2.
translator = Translator()
|
|
|
|
def down1(src_img):
    """Detect text regions and word balloons in the uploaded image (step 1).

    Args:
        src_img: RGB image array from the ``gr.Image(type="numpy")`` input.

    Returns:
        A 7-tuple for the outputs ``[txt_mask, wordball_mask, idx_txt, pos_txt,
        area_txt, rgb_txt, complete]``:

        * the text mask and the word-balloon mask, each multiplied by 255;
        * ``idx``: box indices ``"0; 1; ..."`` (every box starts selected);
        * ``text_positions``: one ``"min_x, min_y, max_x, max_y"`` group per box,
          joined with ``"; "``;
        * ``areas``: box areas joined with ``"; "``;
        * ``rgbs``: a random ``"r, g, b"`` display colour per box, joined with
          ``"; "``;
        * the status string ``'Xong'`` (Vietnamese for "done").
    """
    # The detection helpers are fed the image in OpenCV's BGR channel order.
    src_img = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
    text_msk = create_text_mask(src_img, detect_text_model)
    wordball_msk = create_wordball_mask(src_img, detect_wordball_model)

    # text_value=0: text pixels are 0 in the mask.
    text_positions, areas = get_text_positions(text_msk, text_value=0)

    # One random colour per box, used to tell the boxes apart in the UI.
    rgbs = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            for _ in range(len(areas))]

    # Serialise everything into '; '-separated strings for the hidden Textboxes.
    idx = '; '.join(str(i) for i in range(len(areas)))
    text_positions = '; '.join([', '.join(str(i) for i in pos) for pos in text_positions])
    areas = '; '.join(str(i) for i in areas)
    rgbs = '; '.join([', '.join(str(i) for i in rgb) for rgb in rgbs])

    # (The original converted src_img back to RGB here, but the result was
    # never used, so that dead conversion has been removed.)
    return text_msk*255, wordball_msk*255, idx, text_positions, areas, rgbs, 'Xong'
|
|
|
|
def idx_txt_change(src_img, idx_txt, pos_txt, rgb_txt):
    """Draw the currently selected text boxes onto a copy of the source image.

    Args:
        src_img: RGB image array (the uploaded image).
        idx_txt: selected box indices, ``"0; 2; ..."``; may be empty.
        pos_txt: ``"min_x, min_y, max_x, max_y"`` per box, joined with ``"; "``.
        rgb_txt: ``"r, g, b"`` display colour per box, joined with ``"; "``.

    Returns:
        The RGB image with a 4-px rectangle, in the box's colour, around every
        selected box. This is best-effort: if anything fails (no image yet,
        empty or malformed fields), ``src_img`` is returned unchanged.
    """
    try:
        # cv2 drawing works on BGR, so convert and flip each colour to (b, g, r).
        src_img2 = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)

        text_positions = [tuple(int(i) for i in pos.split(', ')) for pos in pos_txt.split('; ')]
        rgbs = [tuple(int(i) for i in rgb.split(', ')) for rgb in rgb_txt.split('; ')]
        # Empty tokens (e.g. every box unchecked) mean "nothing selected".
        idxes = {int(idx) for idx in idx_txt.split('; ') if idx.strip()}

        for idx, ((min_x, min_y, max_x, max_y), (r, g, b)) in enumerate(zip(text_positions, rgbs)):
            if idx in idxes:
                cv2.rectangle(src_img2, (min_x, min_y), (max_x, max_y), (b, g, r), thickness=4)

        return cv2.cvtColor(src_img2, cv2.COLOR_BGR2RGB)
    except Exception:
        # Deliberate best-effort fallback, but no longer swallows
        # KeyboardInterrupt / SystemExit as the bare `except:` did.
        return src_img
|
|
|
|
def scale_area_change(min_area, max_area, area_txt):
    """Select the boxes whose area lies in ``[min_area, max_area]``.

    Args:
        min_area: inclusive lower bound.
        max_area: inclusive upper bound.
        area_txt: box areas as integers joined with ``"; "``.

    Returns:
        The indices of the matching boxes joined with ``"; "`` (empty string
        when none match).
    """
    parsed_areas = map(int, area_txt.split('; '))
    return '; '.join(
        str(position)
        for position, size in enumerate(parsed_areas)
        if min_area <= size <= max_area
    )
|
|
|
|
def position_block_change(X, Y, W, H, ID, pos_txt_value):
    """Replace the coordinates of box ``ID`` in the serialised position list.

    Args:
        X, Y: new top-left corner of the box.
        W, H: new width and height of the box.
        ID: index of the box to update.
        pos_txt_value: ``"min_x, min_y, max_x, max_y"`` per box, joined with
            ``"; "``.

    Returns:
        The updated position string, in the same format. Box ``ID`` becomes
        ``(X, Y, X + W, Y + H)``; all other boxes are unchanged.
    """
    # gr.Number may deliver floats. Positions are re-parsed with int()
    # elsewhere (idx_txt_change, down2, text_info_change), so store whole
    # numbers — "10", never "10.0".
    x, y, w, h = (int(round(v)) for v in (X, Y, W, H))
    new_box = (x, y, x + w, y + h)

    text_positions2 = []
    for idx, pos in enumerate(pos_txt_value.split('; ')):
        min_x, min_y, max_x, max_y = (int(i) for i in pos.split(', '))
        text_positions2.append(new_box if idx == ID else (min_x, min_y, max_x, max_y))

    return '; '.join([', '.join(str(i) for i in pos) for pos in text_positions2])
|
|
|
|
def ID_block_change(ID_value, checkbox_value, ID_txt_value):
    """Add or remove one box index from the serialised selection.

    Args:
        ID_value: index of the box whose checkbox changed.
        checkbox_value: True to select the box, False to deselect it.
        ID_txt_value: currently selected indices joined with ``"; "``; may be
            empty (e.g. after every box has been unchecked).

    Returns:
        The updated selection, sorted ascending and joined with ``"; "``.
    """
    # Skip empty tokens: once every box is unchecked the field is "", and
    # int('') would crash when the user re-checks a box.
    selected = {int(tok) for tok in ID_txt_value.split('; ') if tok.strip()}
    # gr.Number may deliver a float; store a plain int so later int() parsing
    # of the serialised string ("1", not "1.0") keeps working.
    box_id = int(ID_value)

    if checkbox_value:
        selected.add(box_id)
    else:
        selected.discard(box_id)

    return '; '.join(str(i) for i in sorted(selected))
|
|
|
|
def down2(src_img_value, txt_mask_value, wordball_mask_value, idx_txt_value, pos_txt_value):
    """Erase the selected text from the image and translate it (step 2).

    Boxes whose index is NOT in ``idx_txt_value`` are painted 255 in
    ``txt_mask_value`` (modified in place), presumably so ``clear_text``
    leaves them untouched. The first channel of that mask is handed to
    ``clear_text``. The text inside the selected boxes is extracted with
    ``get_list_texts`` and passed to ``translate``.

    Returns:
        An 8-tuple for ``[non_txt_img, translated_txt, font_txt, size_txt,
        stroke_txt, pad_txt, switch_txt, complete2]``:

        * the text-free image;
        * the translations joined with ``"; "``;
        * one default entry per translation for font ('MTO Astro City.ttf'),
          size ('20'), stroke ('3') and pad ('5');
        * a fresh random string, so ``switch_txt`` changes;
        * the status string ``'Xong'`` ("done").
    """
    bgr_img = cv2.cvtColor(src_img_value, cv2.COLOR_RGB2BGR)
    raw_boxes = pos_txt_value.split('; ')
    chosen_ids = [int(tok) for tok in idx_txt_value.split('; ')]

    chosen_boxes = []
    for box_id, raw in enumerate(raw_boxes):
        min_x, min_y, max_x, max_y = (int(v) for v in raw.split(', '))
        if box_id in chosen_ids:
            chosen_boxes.append((min_x, min_y, max_x, max_y))
        else:
            txt_mask_value[min_y:max_y+1, min_x:max_x+1] = 255

    single_channel_mask = txt_mask_value[:, :, 0].astype(np.uint8)
    non_text_src_img = clear_text(bgr_img, single_channel_mask, wordball_mask_value,
                                  text_value=0, non_text_value=255, r=5)

    translated = translate(get_list_texts(bgr_img, chosen_boxes), translator)

    n_texts = len(translated)
    list_fonts = '; '.join(['MTO Astro City.ttf'] * n_texts)
    list_sizes = '; '.join(['20'] * n_texts)
    list_strokes = '; '.join(['3'] * n_texts)
    list_pads = '; '.join(['5'] * n_texts)
    joined_translations = '; '.join(translated)

    switch = str(random.random())

    return non_text_src_img, joined_translations, list_fonts, list_sizes, list_strokes, list_pads, switch, 'Xong'
|
|
|
|
def text_info_change(non_txt_img_value, translated_txt_value, pos_txt_value, idx_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
    """Render the translated texts onto a copy of the text-free image.

    Every ``*_txt_value`` argument is a ``"; "``-separated list with one entry
    per selected box. ``pos_txt_value`` holds ALL boxes, and only those whose
    index appears in ``idx_txt_value`` are passed on. Sizes, pads and strokes
    are parsed as ints.

    Returns:
        Whatever ``insert_text`` returns for the copied image.
    """
    canvas = non_txt_img_value.copy()
    wanted_ids = [int(tok) for tok in idx_txt_value.split('; ')]

    texts = translated_txt_value.split('; ')
    boxes = [
        tuple(map(int, raw.split(', ')))
        for box_id, raw in enumerate(pos_txt_value.split('; '))
        if box_id in wanted_ids
    ]
    fonts = font_txt_value.split('; ')
    sizes = [int(tok) for tok in size_txt_value.split('; ')]
    pads = [int(tok) for tok in pad_txt_value.split('; ')]
    strokes = [int(tok) for tok in stroke_txt_value.split('; ')]

    return insert_text(canvas, texts, boxes,
                       font=fonts,
                       font_size=sizes,
                       pad=pads,
                       stroke=strokes)
|
|
|
|
def value2_change(value, ID2_value, txt_value):
    """Replace entry ``ID2_value`` of a ``"; "``-separated string with ``value``.

    Args:
        value: the new entry — a translation or font name (str), or a size,
            stroke or pad number from a ``gr.Number``.
        ID2_value: index of the entry to replace (compared with ``==``, so an
            integral float also matches).
        txt_value: the current ``"; "``-separated string.

    Returns:
        The updated ``"; "``-separated string.
    """
    # gr.Number may deliver floats. Every numeric field edited through this
    # helper (size / stroke / pad) is re-parsed with int() in
    # text_info_change, so store a whole number ("20", not "20.0").
    if isinstance(value, float):
        value = int(round(value))

    return '; '.join(
        str(value) if idx == ID2_value else str(text)
        for idx, text in enumerate(txt_value.split('; '))
    )
|
|
|
|
|
|
# Gradio UI. Data flows between steps through hidden Textboxes that hold
# '; '-separated strings (one entry per detected text box):
#   step 1 (down_bttn_1 -> down1): detect text boxes and word balloons;
#   step 2 (down_bttn_2 -> down2): erase selected text and translate it;
#   per-box editors, built by the two @gr.render functions, rewrite those
#   strings and re-render the preview images.
with gr.Blocks() as demo:

    src_img = gr.Image(type="numpy", label="Upload Image")

    down_bttn_1 = gr.Button("↓", elem_classes="arrow-button")

    # Step 1 outputs: text mask and word-balloon mask, plus a status box.
    with gr.Row():
        txt_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
        wordball_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
    complete = gr.Textbox()

    # Hidden state filled by down1: selected indices, box positions
    # ("min_x, min_y, max_x, max_y"), areas and display colours ("r, g, b").
    with gr.Row():
        idx_txt = gr.Textbox(label='ID', interactive=False, visible=False)
        pos_txt = gr.Textbox(label='Pos', interactive=False, visible=False)
        area_txt = gr.Textbox(label='Area', interactive=False, visible=False)
        rgb_txt = gr.Textbox(label='rgb', interactive=False, visible=False)

    with gr.Row():
        # Source image with the selected boxes drawn on it (idx_txt_change).
        boxed_txt_img = gr.Image(type="numpy", label="Upload Image")

        with gr.Column() as down_1_column:
            # Rebuilt every time rgb_txt changes, i.e. after each step-1 run.
            @gr.render(inputs=[pos_txt, rgb_txt], triggers=[rgb_txt.change])
            def create_box(pos_txt_value, rgb_txt_value):
                """Build one editor group (checkbox + X/Y/W/H) per detected box."""
                # Parse positions and colours into one-shot generators; each
                # is unpacked exactly once below.
                text_positions = pos_txt_value.split('; ')
                for idx in range(len(text_positions)):
                    text_positions[idx] = (int(i) for i in text_positions[idx].split(', '))
                rgbs = rgb_txt_value.split('; ')
                for idx in range(len(rgbs)):
                    rgbs[idx] = (int(i) for i in rgbs[idx].split(', '))

                elements = []
                for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
                    with gr.Group() as box:
                        r, g, b = rgbs[idx]
                        with gr.Row():
                            # Colour swatch + 1-based label for this box.
                            gr.Markdown(
                                f"""


                                <div style="margin-left: 20px; display: flex; align-items: center;">


                                <div style="width: 10px; height: 10px; background-color: rgb({r}, {g}, {b}); margin-right: 5px;"></div>


                                <span style="font-size: 20px;">Textbox {idx+1}</span>


                                </div>


                                """
                            )
                            checkbox = gr.Checkbox(value=True, label='', min_width=50, interactive=True)
                        with gr.Row():
                            # Editable box geometry; W/H are derived from the
                            # stored min/max corners.
                            X = gr.Number(label="X", value=min_x, interactive=True)
                            Y = gr.Number(label="Y", value=min_y, interactive=True)
                            W = gr.Number(label="W", value=max_x-min_x, interactive=True)
                            H = gr.Number(label="H", value=max_y-min_y, interactive=True)
                            # Hidden carrier of this box's index for the callbacks.
                            ID = gr.Number(label="ID", value=idx, interactive=True, visible=False)
                    elements.append((X, Y, W, H, ID))

                    # Toggling the checkbox updates the selection in idx_txt,
                    # then redraws the boxed preview.
                    checkbox.change(
                        fn=ID_block_change,
                        inputs=[ID, checkbox, idx_txt],
                        outputs=idx_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                    )
                    # Any geometry edit rewrites this box's entry in pos_txt,
                    # then redraws the boxed preview.
                    X.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
                    Y.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
                    W.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
                    H.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )

    down_bttn_2 = gr.Button("↓", elem_classes="arrow-button")

    # Text-free image produced by down2 (hidden; used as the base for rendering).
    non_txt_img = gr.Image(type="numpy", label="Upload Image", visible=False)
    complete2 = gr.Textbox()

    # Hidden state filled by down2: per-box translation, font, size, stroke and
    # pad. switch_txt receives a new random value on each down2 run, so its
    # change event re-triggers create_box2.
    with gr.Row():
        translated_txt = gr.Textbox(label='translated', interactive=False, visible=False)
        font_txt = gr.Textbox(label='font', interactive=False, visible=False)
        size_txt = gr.Textbox(label='size', interactive=False, visible=False)
        stroke_txt = gr.Textbox(label='stroke', interactive=False, visible=False)
        pad_txt = gr.Textbox(label='pad', interactive=False, visible=False)
        switch_txt = gr.Textbox(label='switch', value='1', interactive=False, visible=False)

    with gr.Row():
        # Text-free image with the translated text inserted (text_info_change).
        boxed_inserted_non_txt_img = gr.Image(type="numpy", label="Upload Image")

        with gr.Column():
            @gr.render(inputs=[translated_txt, font_txt, size_txt, stroke_txt, pad_txt], triggers=[switch_txt.change])
            def create_box2(translated_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
                """Build one editor group (text, font, size, stroke, pad) per translated box."""
                translated_txt_value = translated_txt_value.split('; ')
                font_txt_value = font_txt_value.split('; ')
                size_txt_value = size_txt_value.split('; ')
                stroke_txt_value = stroke_txt_value.split('; ')
                pad_txt_value = pad_txt_value.split('; ')

                elements = []
                for idx in range(len(font_txt_value)):
                    with gr.Group():
                        # NOTE(review): labels here are 0-based ("Text box 0"),
                        # while create_box labels boxes from 1.
                        gr.Markdown(
                            f"""


                            <div style="margin-left: 20px; display: flex; align-items: center;">


                            <div style="width: 10px; height: 10px; background-color: rgb(255, 255, 255); margin-right: 5px;"></div>


                            <span style="font-size: 20px;">Text box {idx}</span>


                            </div>


                            """
                        )
                        translated_text_box = gr.Textbox(label="Translate", value=translated_txt_value[idx], interactive=True)
                        with gr.Row():
                            # Font choices are the files in the 'MTO Font'
                            # directory (relative to the working directory).
                            font = gr.Dropdown(choices=os.listdir('MTO Font'), label="Phông chữ", value=font_txt_value[idx], interactive=True, scale=7)
                            size = gr.Number(label="Size", value=int(size_txt_value[idx]), interactive=True, minimum=1)
                            stroke = gr.Number(label="Stroke", value=int(stroke_txt_value[idx]), interactive=True, minimum=0, maximum=5)
                            pad = gr.Number(label="Pad", value=int(pad_txt_value[idx]), interactive=True, minimum=1, maximum=10)
                            # Hidden carrier of this entry's index for value2_change.
                            ID2 = gr.Number(label="ID", value=int(idx), interactive=True, visible=False)

                    # Each editor rewrites its own entry in the matching hidden
                    # string, then re-renders the translated preview. The text
                    # box fires on submit (Enter); the others fire on change.
                    translated_text_box.submit(
                        fn=value2_change,
                        inputs=[translated_text_box, ID2, translated_txt],
                        outputs=translated_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    font.change(
                        fn=value2_change,
                        inputs=[font, ID2, font_txt],
                        outputs=font_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    size.change(
                        fn=value2_change,
                        inputs=[size, ID2, size_txt],
                        outputs=size_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    stroke.change(
                        fn=value2_change,
                        inputs=[stroke, ID2, stroke_txt],
                        outputs=stroke_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    pad.change(
                        fn=value2_change,
                        inputs=[pad, ID2, pad_txt],
                        outputs=pad_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )

    # Page CSS: enlarge the arrow buttons (the Vietnamese CSS comment reads
    # "font size").
    demo.css = """


.arrow-button {


font-size: 40px; /* Kích thước font */


}


.group-elem {


height: 70px;


}


"""

    # Step 1: detection. Writing rgb_txt re-triggers create_box.
    down_bttn_1.click(
        fn=down1,
        inputs=src_img,
        outputs=[txt_mask, wordball_mask, idx_txt, pos_txt, area_txt, rgb_txt, complete],
    )

    # Step 2: erase + translate (writing switch_txt re-triggers create_box2),
    # then render the translated preview.
    down_bttn_2.click(
        fn=down2,
        inputs=[src_img, txt_mask, wordball_mask, idx_txt, pos_txt],
        outputs=[non_txt_img, translated_txt, font_txt, size_txt, stroke_txt, pad_txt, switch_txt, complete2],
    ).then(
        fn=text_info_change,
        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
        outputs=boxed_inserted_non_txt_img,
    )

demo.launch()
|
|
|