esun-choi's picture
Update main.py
a9ccd5e
raw
history blame
13.6 kB
from recognize import recongize
from ner import ner
import os
import time
import argparse
from sr.sr import sr
import torch
from scipy.ndimage import gaussian_filter
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from mosaik import mosaik
from PIL import Image
import cv2
from skimage import io
import numpy as np
import craft_utils
import imgproc
import file_utils
from seg import mask_percentage
from seg2 import dino_seg
from craft import CRAFT
from collections import OrderedDict
import gradio as gr
from refinenet import RefineNet
# Code for loading the CRAFT and RefineNet models
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module`` key segment removed.

    If the first key starts with ``"module"``, the first dot-separated
    segment is dropped from every key; otherwise keys are copied unchanged.
    The decision is made from the first key only and applied to all keys.

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        A new ``OrderedDict`` with the (possibly shortened) keys, in the same
        order. An empty input yields an empty ``OrderedDict``.
    """
    new_state_dict = OrderedDict()
    if not state_dict:
        # Nothing to inspect; avoids indexing into an empty key list.
        return new_state_dict

    first_key = next(iter(state_dict))
    start_idx = 1 if first_key.startswith("module") else 0
    for key, value in state_dict.items():
        new_state_dict[".".join(key.split(".")[start_idx:])] = value
    return new_state_dict
def str2bool(v):
    """Interpret a string as a boolean flag (case-insensitive).

    Returns True for "yes", "y", "true", "t" and "1"; False for anything else.
    """
    truthy_words = ("yes", "y", "true", "t", "1")
    lowered = v.lower()
    return lowered in truthy_words
# Command-line options for CRAFT text detection. They are parsed once at
# import time and read through the module-level ``args`` by the functions below.
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='사전학습 craft 모델')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=False, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
# Parsed with str2bool so that "--refine false" really disables the refiner;
# without a type the raw string "false" would be truthy.
parser.add_argument('--refine', default=True, type=str2bool, help='enable link refiner')
parser.add_argument('--image_path', default="input/2.png", help='input image')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')
# NOTE(review): this reads sys.argv of whatever process imports the module —
# confirm that is intended when the module is imported by another launcher.
args = parser.parse_args()
# The function below is optional
def full_img_masking(full_image,net,refine_net):
    """Blur every detected text box in ``full_image`` that ``ner`` labels as 1.

    The image is passed through ``sr`` and text boxes are detected on the
    result with ``text_detect``. The boxes are then mapped back with
    ``get_box_from_refer`` (coordinates halved). Each box is cropped from
    ``full_image``, read with ``recongize`` and classified with ``ner``;
    boxes with label 1 are Gaussian-blurred in place.

    Args:
        full_image: Image array indexed as [row, column, channel]; modified
            in place.
        net: Text-detection network forwarded to ``text_detect``.
        refine_net: Optional link-refiner network forwarded to ``text_detect``.

    Returns:
        The same ``full_image`` array, with the selected boxes blurred.
    """
    reference_image = sr(full_image)
    reference_boxes = text_detect(reference_image, net=net, refine_net=refine_net)
    height, width = full_image.shape[:2]

    for box in get_box_from_refer(reference_boxes):
        xmin, xmax, ymin, ymax = get_min_max(box)
        # Clamp to the image: a negative start index would otherwise wrap
        # around to the opposite edge when slicing.
        top, bottom = max(int(ymin), 0), min(int(ymax), height)
        left, right = max(int(xmin), 0), min(int(xmax), width)
        if bottom <= top or right <= left:
            continue  # nothing left of the box inside the image

        text_area = full_image[top:bottom, left:right, :]
        text = recongize(text_area)
        if ner(text) == 1:
            # Blur the two spatial axes only; a scalar sigma would also smear
            # values across the channel axis.
            full_image[top:bottom, left:right, :] = gaussian_filter(text_area, sigma=(16, 16, 0))
    return full_image
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
    """Run the text detector on one image and return boxes, polygons and a heatmap.

    The image is resized with ``imgproc.resize_aspect_ratio`` using
    ``args.canvas_size`` and ``args.mag_ratio``, normalised, and passed
    through ``net``. When ``refine_net`` is given, its output replaces the
    link score map. Boxes and polygons come from ``craft_utils.getDetBoxes``
    and are rescaled by the inverse of the resize ratio.

    Args:
        net: Detection network; called as ``net(x)`` and expected to return
            ``(y, feature)``.
        image: Input image array.
        text_threshold: Text confidence threshold passed to ``getDetBoxes``.
        link_threshold: Link confidence threshold passed to ``getDetBoxes``.
        low_text: Text low-bound score passed to ``getDetBoxes``.
        cuda: When true, the input tensor is moved to the GPU.
        poly: Whether polygon output is requested from ``getDetBoxes``.
        refine_net: Optional refiner called as ``refine_net(y, feature)``.

    Returns:
        ``(boxes, polys, ret_score_text)``. Polygon entries that are ``None``
        are replaced by the matching box. ``ret_score_text`` is the heatmap
        image of the text and link score maps stacked side by side.
    """
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio
    )
    ratio_h = ratio_w = 1 / target_ratio

    # Normalise, then reorder HWC -> CHW and add a batch dimension.
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)
    if cuda:
        x = x.cuda()

    with torch.no_grad():
        y, feature = net(x)
        score_text = y[0, :, :, 0].detach().cpu().numpy()
        score_link = y[0, :, :, 1].detach().cpu().numpy()
        if refine_net is not None:
            # The refiner's first channel supersedes the raw link score.
            y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].detach().cpu().numpy()

    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k, polygon in enumerate(polys):
        if polygon is None:
            polys[k] = boxes[k]

    # Render the score maps (optional output for inspection).
    render_img = np.hstack((score_text.copy(), score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)
    return boxes, polys, ret_score_text
def text_detect(image,net,refine_net):
    """Detect text in ``image`` and return only the bounding boxes.

    Thresholds and the cuda/poly switches are taken from the module-level
    ``args``; the polygons and score heatmap from ``test_net`` are discarded.
    """
    detection = test_net(
        net,
        image,
        args.text_threshold,
        args.link_threshold,
        args.low_text,
        args.cuda,
        args.poly,
        refine_net,
    )
    return detection[0]
def get_box_from_refer(reference_boxes):
    """Return a list with every box floor-divided by 2.

    Each element of ``reference_boxes`` only needs to support ``// 2``.
    """
    return [ref_box // 2 for ref_box in reference_boxes]
def get_min_max(box):
    """Return ``(x_min, x_max, y_min, y_max)`` over the points of ``box``.

    Each point is indexed as ``point[0]`` (x) and ``point[1]`` (y).
    """
    xs = [point[0] for point in box]
    ys = [point[1] for point in box]
    return min(xs), max(xs), min(ys), max(ys)
def _load_model(model, weights_path, use_cuda):
    """Load checkpoint weights into ``model``, move it to the GPU if requested, set eval mode."""
    # On CPU the checkpoint must be remapped, otherwise CUDA-saved tensors fail to load.
    map_location = None if use_cuda else "cpu"
    model.load_state_dict(copyStateDict(torch.load(weights_path, map_location=map_location)))
    if use_cuda:
        model = model.cuda()
    model.eval()
    return model


def _contour_bounds(contour):
    """Return ``[y_min, y_max, x_min, x_max]`` over a contour whose items are ``[[x, y]]``."""
    xs = [point[0][0] for point in contour]
    ys = [point[0][1] for point in contour]
    return [min(ys), max(ys), min(xs), max(xs)]


def _clamped_box(box, height, width):
    """Return ``(top, bottom, left, right)`` of ``box`` clipped to the image, or None if empty."""
    xmin, xmax, ymin, ymax = get_min_max(box)
    top, bottom = max(int(ymin), 0), min(int(ymax), height)
    left, right = max(int(xmin), 0), min(int(xmax), width)
    if bottom <= top or right <= left:
        return None
    return top, bottom, left, right


def main(image_path0):
    """Mask personal information on the invoice regions of one image.

    Pipeline (step numbering follows the original author's comments):

    1. Load the CRAFT detector and, if ``args.refine`` is set, the RefineNet
       link refiner.
    2. Load the image, segment it with ``dino_seg`` and save the mask; then
       get contours and their area percentages from ``mask_percentage``.
    3. Bucket regions by percentage: above 5 are processed box by box,
       above 1 and up to 5 are masked whole, the rest are ignored.
    4. Pass the small regions to ``mosaik``.
    5. For each large region, run ``sr`` and detect text boxes on the result;
       boxes are mapped back with ``get_box_from_refer``.
    6. Read each box with ``recongize``, classify with ``ner`` and blur boxes
       labelled 1; then paste the region back into the full image.

    Side effects: writes ``temporal_mask/mask.png``, ``text_area/new_<n>.png``,
    appends to ``output/text_recongnize.txt`` and writes ``output/mosaiked.png``.
    Sets ``args.poly = True`` when the refiner is enabled.

    Args:
        image_path0: Path of the input image, passed to ``imgproc.loadImage``.

    Returns:
        The masked image as a ``PIL.Image.Image``.
    """
    # Step 1: load the models. Weights are loaded on CPU as well as on GPU;
    # running the pipeline with uninitialised networks would silently leave
    # personal data unmasked.
    # NOTE(review): models are rebuilt on every call — consider caching them.
    net = _load_model(CRAFT(), args.trained_model, args.cuda)
    if args.cuda:
        cudnn.benchmark = False
    refine_net = None
    if args.refine:
        refine_net = _load_model(RefineNet(), args.refiner_model, args.cuda)
        args.poly = True

    # Make sure every directory written to below exists.
    for directory in ("temporal_mask", "text_area", "output"):
        os.makedirs(directory, exist_ok=True)

    # Step 2: load the image (keeping at most three channels) and segment it.
    image = imgproc.loadImage(image_path0)
    if image.shape[2] > 3:
        image = image[:, :, 0:3]
    original_image = image

    # Original comments: DINOv2-based segmentation of the invoice area,
    # producing a white-on-black mask — TODO confirm against seg2/seg.
    output = dino_seg(image)
    Image.fromarray(output).save("temporal_mask/mask.png")
    # percentage_list[i] is the share of the image covered by contours_list[i]
    # (per the original comments).
    contours_list, percentage_list = mask_percentage("temporal_mask/mask.png")

    # Step 3: bucket the regions by how much of the image they cover.
    print("상위 5개 값:", sorted(percentage_list, reverse=True)[:5])
    normal_image_list = []
    small_coordinate_list = []
    original_coordinate_list = []
    for contour, percentage in zip(contours_list, percentage_list):
        if percentage <= 1:
            continue  # too small to matter
        bounds = _contour_bounds(contour)
        if percentage > 5:
            # Large enough that individual text boxes can be handled.
            y_min, y_max, x_min, x_max = bounds
            original_coordinate_list.append(bounds)
            normal_image_list.append(original_image[y_min:y_max, x_min:x_max, :])
        else:
            # 1 < percentage <= 5: text is too small to read, mask the whole region.
            small_coordinate_list.append(bounds)

    # Step 4: mask the small regions as a whole.
    if small_coordinate_list:
        original_image = mosaik(original_image, small_coordinate_list)

    # Steps 5 and 6: detect, read and classify text inside each large region.
    for index, normal_image in enumerate(normal_image_list):
        reference_image = sr(normal_image)
        reference_boxes = text_detect(reference_image, net=net, refine_net=refine_net)
        height, width = normal_image.shape[:2]

        for index2, box in enumerate(get_box_from_refer(reference_boxes)):
            clamped = _clamped_box(box, height, width)
            if clamped is None:
                continue  # box lies outside the crop; nothing to read or blur
            top, bottom, left, right = clamped

            region = normal_image[top:bottom, left:right, :]
            text_area = Image.fromarray(region)
            text_area.save(f"text_area/new_{index2+1}.png")

            text = recongize(text_area)
            label = ner(text)
            with open("output/text_recongnize.txt", "a", encoding="utf-8") as recognized:
                recognized.write(f"{index2 + 1} {text} {label}\n")
            print("done")

            if label == 1:
                # Blur the spatial axes only so colour channels are not mixed.
                normal_image[top:bottom, left:right, :] = gaussian_filter(region, sigma=(16, 16, 0))

        # Paste the (partly blurred) region back into the full image.
        a, b, c, d = original_coordinate_list[index]
        original_image[a:b, c:d, :] = normal_image

    # Optional (slower): additionally run the detector over the whole image.
    # original_image = full_img_masking(original_image, net=net, refine_net=refine_net)

    original_image = Image.fromarray(original_image)
    original_image.save("output/mosaiked.png")
    print("masked complete")
    return original_image
# if __name__ == '__main__':
# iface = gr.Interface(
# fn=main,
# inputs=gr.Image(type="filepath", label="Invoice Image"),
# outputs=gr.Image(type="pil", label="Masked Invoice Image"),
# live=True
# )
# iface.launch()