Spaces:
Sleeping
Sleeping
from recognize import recongize | |
from ner import ner | |
import os | |
import time | |
import argparse | |
from sr.sr import sr | |
import torch | |
from scipy.ndimage import gaussian_filter | |
from PIL import Image | |
import numpy as np | |
import torch.nn as nn | |
import torch.backends.cudnn as cudnn | |
from torch.autograd import Variable | |
from mosaik import mosaik | |
from PIL import Image | |
import cv2 | |
from skimage import io | |
import numpy as np | |
import craft_utils | |
import imgproc | |
import file_utils | |
from seg import mask_percentage | |
from seg2 import dino_seg | |
from craft import CRAFT | |
from collections import OrderedDict | |
import gradio as gr | |
from refinenet import RefineNet | |
# Helpers for loading the CRAFT and RefineNet model weights.
def copyStateDict(state_dict):
    """Return a copy of *state_dict* with a leading ``module`` key component removed.

    Only the first key is inspected: when it starts with ``"module"`` (the prefix
    typically added by ``torch.nn.DataParallel``), the first dotted component is
    dropped from every key; otherwise keys are copied unchanged. An empty mapping
    yields an empty ``OrderedDict`` instead of raising ``IndexError``.

    Args:
        state_dict: mapping of parameter names to tensors.

    Returns:
        An ``OrderedDict`` with the (possibly stripped) keys, in the original order.
    """
    first_key = next(iter(state_dict), "")
    start_idx = 1 if first_key.startswith("module") else 0
    return OrderedDict(
        (".".join(key.split(".")[start_idx:]), value)
        for key, value in state_dict.items()
    )
def str2bool(v):
    """Interpret a command-line string as a boolean (used as an argparse ``type``).

    Matching is case-insensitive: "yes", "y", "true", "t" and "1" are true;
    every other string is false.
    """
    truthy_words = ("yes", "y", "true", "t", "1")
    normalized = v.lower()
    return normalized in truthy_words
# Command-line options: CRAFT thresholds, model paths and device selection.
# Parsed once at import time; the functions below read the module-level `args`.
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
# (help text of --trained_model is Korean for "pretrained craft model")
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='์ฌ์ ํ์ต craft ๋ชจ๋ธ')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=False, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
# type=str2bool is required: without it "--refine False" is stored as the
# non-empty (hence truthy) string "False", so the refiner could never be disabled.
parser.add_argument('--refine', default=True, type=str2bool, help='enable link refiner')
parser.add_argument('--image_path', default="input/2.png", help='input image')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')
args = parser.parse_args()
# The function below (full_img_masking) is an optional extra pass.
def full_img_masking(full_image, net, refine_net):
    """Optional whole-image pass: detect text anywhere and blur personal information.

    Runs the same super-resolve -> CRAFT -> OCR -> NER chain as the per-invoice
    pass, but over the entire image (invoice and background). It is more thorough
    but slower.

    Args:
        full_image: image array indexed as ``[y, x, channel]``; blurred regions are
            written into it in place.
        net: CRAFT model forwarded to ``text_detect``.
        refine_net: RefineNet model (or ``None``) forwarded to ``text_detect``.

    Returns:
        ``full_image`` (the same array) with regions labelled 1 by ``ner`` blurred.
    """
    reference_image = sr(full_image)
    reference_boxes = text_detect(reference_image, net=net, refine_net=refine_net)
    boxes = get_box_from_refer(reference_boxes)
    height, width = full_image.shape[:2]
    for box in boxes:
        xmin, xmax, ymin, ymax = get_min_max(box)
        # Clamp to the image: a negative start index would wrap around when slicing.
        y0, y1 = max(int(ymin), 0), min(int(ymax), height)
        x0, x1 = max(int(xmin), 0), min(int(xmax), width)
        if y1 <= y0 or x1 <= x0:
            continue  # empty box: nothing to recognise or blur
        region = full_image[y0:y1, x0:x1, :]
        # Give recongize a PIL image, matching what main() passes it.
        text = recongize(Image.fromarray(region))
        label = ner(text)
        if label == 1:
            full_image[y0:y1, x0:x1, :] = gaussian_filter(region, sigma=16)
    return full_image
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
    """Run CRAFT (optionally followed by RefineNet) on one image and extract text boxes.

    Args:
        net: CRAFT model; called as ``net(x)`` and returning ``(y, feature)``.
        image: input image passed to ``imgproc.resize_aspect_ratio``.
        text_threshold: text score threshold forwarded to ``craft_utils.getDetBoxes``.
        link_threshold: link score threshold forwarded to ``craft_utils.getDetBoxes``.
        low_text: low-bound text score forwarded to ``craft_utils.getDetBoxes``.
        cuda: move the input tensor to the GPU when true.
        poly: request polygon output from ``getDetBoxes``.
        refine_net: optional RefineNet; when given, its output replaces the link score map.

    Returns:
        ``(boxes, polys, ret_score_text)``: boxes and polygons rescaled by
        ``adjustResultCoordinates`` (a missing polygon falls back to its box), plus a
        heatmap rendering of the text and link score maps placed side by side.
    """
    img_resized, target_ratio, _size_heatmap = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio
    )
    ratio_h = ratio_w = 1 / target_ratio

    # HWC numpy image -> 1xCxHxW float tensor (Variable wrapping is unnecessary).
    x = torch.from_numpy(imgproc.normalizeMeanVariance(img_resized)).permute(2, 0, 1).unsqueeze(0)
    if cuda:
        x = x.cuda()

    with torch.no_grad():
        y, feature = net(x)
        # Last axis of y: channel 0 -> text score map, channel 1 -> link score map.
        score_text = y[0, :, :, 0].cpu().numpy()
        score_link = y[0, :, :, 1].cpu().numpy()
        if refine_net is not None:
            y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().numpy()

    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    # A polygon may be None; fall back to the corresponding box.
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    # np.hstack returns a new array, so the score maps themselves are untouched.
    ret_score_text = imgproc.cvt2HeatmapImg(np.hstack((score_text, score_link)))
    return boxes, polys, ret_score_text
def text_detect(image, net, refine_net):
    """Detect text in *image* using the thresholds from the command line.

    Thin wrapper over ``test_net`` that reads every option from the module-level
    ``args`` and discards the polygons and the score heatmap.

    Returns:
        The detected boxes as returned by ``test_net``.
    """
    detected_boxes, _polys, _heatmap = test_net(
        net,
        image,
        args.text_threshold,
        args.link_threshold,
        args.low_text,
        args.cuda,
        args.poly,
        refine_net=refine_net,
    )
    return detected_boxes
def get_box_from_refer(reference_boxes, scale=2):
    """Map boxes found on the super-resolved image back to the original image.

    Each box is floor-divided by *scale*. The default of 2 is the factor this file
    assumes ``sr()`` upscales by — NOTE(review): confirm against ``sr.sr``.

    Args:
        reference_boxes: iterable of box arrays (e.g. numpy arrays of points).
        scale: upscaling factor to undo; defaults to 2.

    Returns:
        A new list with each box floor-divided by *scale*.
    """
    return [box // scale for box in reference_boxes]
def get_min_max(box):
    """Return the axis-aligned extent of *box* as ``(xmin, xmax, ymin, ymax)``.

    *box* is an iterable of points where ``point[0]`` is x and ``point[1]`` is y.
    """
    points = list(box)
    xs = [pt[0] for pt in points]
    ys = [pt[1] for pt in points]
    return min(xs), max(xs), min(ys), max(ys)
# Models are kept here after the first main() call so that every Gradio request
# does not reload the CRAFT / RefineNet weights from disk.
_MODEL_CACHE = {}


def _load_models():
    """Load CRAFT and, when ``args.refine`` is truthy, RefineNet in eval mode.

    Checkpoints are always deserialised with ``map_location="cpu"`` and only then
    moved to the GPU when ``args.cuda`` is set. Weight loading must not depend on
    ``args.cuda``; otherwise CPU runs would use randomly initialised networks.

    Returns:
        ``(net, refine_net)``; ``refine_net`` is ``None`` when refinement is off.
    """
    if "models" in _MODEL_CACHE:
        return _MODEL_CACHE["models"]

    net = CRAFT()
    net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location="cpu")))
    if args.cuda:
        net = net.cuda()
        cudnn.benchmark = False
    net.eval()

    refine_net = None
    if args.refine:
        refine_net = RefineNet()
        refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location="cpu")))
        if args.cuda:
            refine_net = refine_net.cuda()
        refine_net.eval()
        # Refinement switches detection to polygon mode; text_detect() forwards args.poly.
        args.poly = True

    _MODEL_CACHE["models"] = (net, refine_net)
    return net, refine_net


def _contour_bbox(contour):
    """Return ``[y_min, y_max, x_min, x_max]`` spanned by a contour's points.

    Each element is read as ``point[0][0]`` (x) and ``point[0][1]`` (y) —
    presumably the ``(N, 1, 2)`` layout of OpenCV contours produced by
    ``mask_percentage``; TODO confirm.
    """
    xs = [point[0][0] for point in contour]
    ys = [point[0][1] for point in contour]
    return [min(ys), max(ys), min(xs), max(xs)]


def main(image_path0):
    """Blur personal information on the invoice(s) shown in an image.

    Pipeline:
      1. Load CRAFT / RefineNet.
      2. Segment invoice regions with ``dino_seg`` and measure, with
         ``mask_percentage``, what percentage of the image each blob covers.
      3. Bucket blobs: > 5 % are "normal" invoices, 1-5 % "small" ones,
         < 1 % are ignored.
      4. Mosaic small invoices wholesale (their text is illegible even zoomed in).
      5. For each normal invoice: upscale with ``sr``, detect text boxes with
         CRAFT and map them back to the invoice crop.
      6. OCR each box (``recongize``), classify the text with ``ner`` and blur the
         box when the label is 1 (personal information); paste the invoice back.

    Side effects: writes ``temporal_mask/mask.png`` and ``text_area/new_*.png``,
    appends to ``output/text_recongnize.txt`` and writes ``output/mosaiked.png``.

    Args:
        image_path0: file path of the input image (from the Gradio input box).

    Returns:
        The masked image as a ``PIL.Image.Image``.
    """
    # Step 1: load the CRAFT and RefineNet models.
    net, refine_net = _load_models()

    # Step 2: read the image supplied through the Gradio input box.
    image = imgproc.loadImage(image_path0)
    if image.shape[2] > 3:
        image = image[:, :, 0:3]
    original_image = image

    # Segment only the invoice regions with the dinov2 model and save the mask
    # (white pixels on a black background).
    os.makedirs("temporal_mask", exist_ok=True)
    Image.fromarray(dino_seg(image)).save("temporal_mask/mask.png")
    # contours_list[i] is one blob (presumed invoice) of the mask and
    # percentage_list[i] is the percentage of the whole image that blob covers.
    contours_list, percentage_list = mask_percentage("temporal_mask/mask.png")

    # Step 3: bucket blobs by size.
    top_5 = sorted(percentage_list, reverse=True)[:5]
    print("์์ 5๊ฐ ๊ฐ:", top_5)
    normal_image_list = []         # crops (views) of invoices covering > 5 %
    original_coordinate_list = []  # [y_min, y_max, x_min, x_max] of each normal crop
    small_coordinate_list = []     # [y_min, y_max, x_min, x_max] of invoices covering 1-5 %
    for contour, percentage in zip(contours_list, percentage_list):
        if percentage > 5:
            y_min, y_max, x_min, x_max = _contour_bbox(contour)
            original_coordinate_list.append([y_min, y_max, x_min, x_max])
            normal_image_list.append(original_image[y_min:y_max, x_min:x_max, :])
        elif percentage >= 1:
            # Inclusive bounds: exactly 1 % or 5 % used to fall through every branch.
            small_coordinate_list.append(_contour_bbox(contour))
        # Blobs under 1 % are too small to process and are skipped.

    # Step 4: very small invoices - mosaic each whole region.
    if small_coordinate_list:
        original_image = mosaik(original_image, small_coordinate_list)

    # Steps 5-6: upscale each normal invoice with ESRGAN (sr) so CRAFT can locate
    # the text, scale the boxes back to the crop, then OCR + NER each box and
    # blur the ones labelled as personal information.
    os.makedirs("text_area", exist_ok=True)
    os.makedirs("output", exist_ok=True)
    for index, normal_image in enumerate(normal_image_list):
        reference_image = sr(normal_image)
        reference_boxes = text_detect(reference_image, net=net, refine_net=refine_net)
        boxes = get_box_from_refer(reference_boxes)
        height, width = normal_image.shape[:2]
        for index2, box in enumerate(boxes):
            xmin, xmax, ymin, ymax = get_min_max(box)
            # Clamp to the crop: a negative start index would wrap around when slicing.
            y0, y1 = max(int(ymin), 0), min(int(ymax), height)
            x0, x1 = max(int(xmin), 0), min(int(xmax), width)
            if y1 <= y0 or x1 <= x0:
                continue  # empty box: nothing to recognise or blur
            region = normal_image[y0:y1, x0:x1, :]
            text_area = Image.fromarray(region)
            text_area.save(f"text_area/new_{index2+1}.png")
            text = recongize(text_area)
            label = ner(text)
            # Explicit UTF-8 so non-ASCII (e.g. Korean) OCR output can always be written.
            with open("output/text_recongnize.txt", "a", encoding="utf-8") as recognized:
                recognized.write(f"{index2+1} {text} {label}\n")
            print("done")
            if label == 1:
                normal_image[y0:y1, x0:x1, :] = gaussian_filter(region, sigma=16)
        # Paste the (partially blurred) invoice back into the full image.
        a, b, c, d = original_coordinate_list[index]
        original_image[a:b, c:d, :] = normal_image

    # Optional, for higher accuracy: feed the whole image (invoice and background)
    # to CRAFT as well. Disabled because it slows down inference.
    # print("full mask start")
    # original_image = full_img_masking(original_image, net=net, refine_net=refine_net)
    # print("full mask done")

    result = Image.fromarray(original_image)
    result.save("output/mosaiked.png")
    print("masked complete")
    return result
# if __name__ == '__main__': | |
# iface = gr.Interface( | |
# fn=main, | |
# inputs=gr.Image(type="filepath", label="Invoice Image"), | |
# outputs=gr.Image(type="pil", label="Masked Invoice Image"), | |
# live=True | |
# ) | |
# iface.launch() | |