image_cut_rect / app.py
HERIUN
.
2bb6556
import gradio as gr
import numpy as np
import math
import torch
from PIL import Image, ImageDraw
from rect_main import docscanner_rec, load_docscanner_model
from data_utils.image_utils import unwarp, mask2point, get_corner, _rotate_90_degrees
from config import Config
# Module-level setup shared by every Gradio handler in this file.
config = Config()
# Torch device used for model inference: the GPU when torch can see one,
# otherwise the CPU. NOTE: despite its name, ``cuda`` may hold a CPU device.
if torch.cuda.is_available():
    cuda = torch.device("cuda")
else:
    cuda = torch.device("cpu")
# Load the docscanner model once, at import time, on that device. Judging by
# the attribute names this is presumably a rectification ("rec") model plus a
# segmentation ("seg") model — TODO confirm in rect_main.
# NOTE(review): the paths are passed as attributes, not called; confirm that
# ``get_rec_model_path`` / ``get_seg_model_path`` are properties on Config.
docscanner = load_docscanner_model(
    cuda,
    path_l=config.get_rec_model_path,
    path_m=config.get_seg_model_path,
)
def reset_points(image, state):
    """Discard all collected corner points.

    The image is handed back unchanged (the Clear button routes it to the
    preview pane, which wipes any drawn markers) together with a brand-new
    empty list that becomes the point state. The incoming ``state`` list
    object itself is left untouched.
    """
    return image, []
def cutting_image(image, state):
    """Crop ``image`` to the axis-aligned bounding box of the selected points.

    The box includes the clicked pixels themselves. Coordinates are rounded
    to integers, because auto-detected corners are not guaranteed to be ints.
    They are also clamped to the image bounds, so a point outside the frame
    cannot wrap around through negative indexing.

    Args:
        image: image as a numpy array indexed ``[row, col, ...]``.
        state: list of ``(x, y)`` points; at least two are needed to span a box.

    Returns:
        ``(cropped, cropped, [])``. The crop is sent to both the input and the
        preview component, and the point list is cleared.

    Raises:
        ValueError: if no image is loaded or fewer than two points are selected.
    """
    if image is None:
        raise ValueError("Load an image before cutting.")
    if len(state) < 2:
        # A single point (or none) spans no area; slicing would yield an
        # empty array that would then replace the input image.
        raise ValueError("Select at least two points to cut the image.")

    h, w = image.shape[:2]
    xs = [int(round(p[0])) for p in state]
    ys = [int(round(p[1])) for p in state]
    # The +1 makes the slice end inclusive so the clicked boundary pixels are kept.
    # Clamping keeps out-of-frame points from producing negative (wrapping) indices.
    min_x, max_x = max(0, min(xs)), min(w, max(xs) + 1)
    min_y, max_y = max(0, min(ys)), min(h, max(ys) + 1)

    cutted_image = image[min_y:max_y, min_x:max_x]
    return cutted_image, cutted_image, []
def rotate_image(image):
    """Rotate the working image by 90 degrees and drop any collected points.

    The rotation itself is delegated to ``_rotate_90_degrees``; its direction
    is whatever that helper implements (the UI button labels it clockwise).
    Previously clicked corners no longer line up with the rotated pixels, so
    the point state is replaced with an empty list.
    """
    turned = _rotate_90_degrees(image)
    return turned, []
def reset_image(image, state):
    """Re-detect document corners on ``image``, discarding the current points.

    This is exactly the auto-detection flow, so it delegates to
    ``auto_point_detect`` instead of duplicating the detection and drawing
    code. Delegating also means the model is invoked with the same ``cuda``
    device argument that the other ``docscanner_rec`` call sites pass.
    ``state`` is accepted only so the signature matches the other
    point-editing handlers; it is ignored.

    NOTE(review): this handler is not wired to any event in the visible UI
    setup; confirm whether it is still needed.

    Returns:
        ``(annotated PIL image, list of detected corner points)``.
    """
    return auto_point_detect(image)
def auto_point_detect(image):
    """Detect document corners with the docscanner model and draw them.

    Args:
        image: the input image as a numpy array (as delivered by ``gr.Image``),
            or ``None`` when no image has been loaded.

    Returns:
        A ``(preview, points)`` pair:
        - ``preview`` is a PIL image with every detected corner marked. Two
          corners are joined by a line; three or more are connected by a
          polygon.
        - ``points`` is the list of detected corners, which becomes the new
          point state.
        Returns ``(None, [])`` when there is no image.
    """
    if image is None:
        # Nothing uploaded yet: nothing to detect, so clear preview and points.
        return None, []

    # Only the mask is needed here; the rectified output is produced by convert().
    _, msk_np = docscanner_rec(image, docscanner, cuda)
    state = list(get_corner(mask2point(mask=msk_np)))

    img = Image.fromarray(image)
    # Scale marker size with the image's linear size so markers stay visible.
    area = image.shape[0] * image.shape[1]
    radius = max(5, round(area**0.5 / 120))
    line_width = round(radius / 2)

    draw = ImageDraw.Draw(img)
    for pt in state:
        left_up_point = (pt[0] - radius, pt[1] - radius)
        right_down_point = (pt[0] + radius, pt[1] + radius)
        draw.ellipse([left_up_point, right_down_point], outline="black", fill="red")

    # The detector may return fewer than three corners (even none): a polygon
    # needs at least three vertices, and the centroid below divides by the count.
    if len(state) == 2:
        draw.line([state[0], state[1]], fill="red", width=line_width)
    elif len(state) >= 3:
        center = (
            sum(p[0] for p in state) / len(state),
            sum(p[1] for p in state) / len(state),
        )
        # Order vertices by angle around the centroid so the outline follows
        # the points around the shape rather than the detector's output order.
        sorted_points = sorted(state, key=lambda p: calculate_angle(p, center))
        draw.polygon(sorted_points, outline="red", fill=None, width=line_width)
    return img, state
def calculate_angle(point, center):
    """Return the polar angle, in radians, of ``point`` as seen from ``center``.

    The value comes from ``math.atan2`` and therefore lies in [-pi, pi]. It is
    used as a sort key to order polygon vertices by angle around a centroid.
    """
    dx = point[0] - center[0]
    dy = point[1] - center[1]
    return math.atan2(dy, dx)
def draw_polygon_on_image(image, evt: gr.SelectData, state):
    """Record a clicked point and render all collected points on ``image``.

    This is wired to the select (click) event of the input image. The clicked
    pixel position ``(evt.index[0], evt.index[1])`` is appended to ``state``
    in place. Every collected point is then drawn as a red dot on a PIL copy
    of the image:
    - with exactly two points, they are joined by a line;
    - with three or more, they are connected as a polygon whose vertices are
      ordered by angle around their centroid rather than by click order.

    Returns:
        ``(annotated PIL image, the same state list including the new point)``.
    """
    canvas = Image.fromarray(image)
    clicked = (evt.index[0], evt.index[1])
    state.append(clicked)

    # Marker radius grows with the image's linear size (sqrt of pixel count).
    radius = max(5, round((image.shape[0] * image.shape[1]) ** 0.5 / 120))
    stroke = round(radius / 2)
    pen = ImageDraw.Draw(canvas)

    for point in state:
        x, y = point[0], point[1]
        pen.ellipse(
            [(x - radius, y - radius), (x + radius, y + radius)],
            outline="black",
            fill="red",
        )

    count = len(state)
    if count == 2:
        pen.line([state[0], state[1]], fill="red", width=stroke)
    elif count >= 3:
        centroid = (
            sum(p[0] for p in state) / count,
            sum(p[1] for p in state) / count,
        )
        ordered = sorted(state, key=lambda p: calculate_angle(p, centroid))
        pen.polygon(ordered, outline="red", fill=None, width=stroke)
    return canvas, state
def sort_corners(corners):
    """Order four ``(x, y)`` corners as top-left, top-right, bottom-right, bottom-left.

    The two points with the smallest y form the top edge and the other two the
    bottom edge; within each pair, the smaller x is the left corner.

    Raises:
        ValueError: if ``corners`` does not hold exactly four points.
    """
    if len(corners) != 4:
        raise ValueError("Input should contain exactly four coordinates.")

    def x_of(p):
        return p[0]

    def y_of(p):
        return p[1]

    by_height = sorted(corners, key=y_of)
    upper, lower = by_height[:2], by_height[2:]
    top_left, top_right = sorted(upper, key=x_of)
    bottom_left, bottom_right = sorted(lower, key=x_of)
    return top_left, top_right, bottom_right, bottom_left
def convert(image, state):
    """Produce the rectified (flattened) document image.

    With exactly four selected corners, the corners are ordered by
    ``sort_corners`` and the quadrilateral is perspective-warped by ``unwarp``
    onto the full frame of the input image. With any other number of points,
    the docscanner model rectifies the image automatically.

    Args:
        image: input image as a numpy array (as delivered by ``gr.Image``), or
            ``None`` when no image has been loaded.
        state: list of ``(x, y)`` corner points collected in the UI.

    Returns:
        The rectified image as a numpy array, or ``None`` when there is no input.
    """
    if image is None:
        return None

    if len(state) == 4:
        h, w = image.shape[:2]
        src = np.array(sort_corners(state)).astype(np.float32)
        # Target corners in the same order sort_corners returns: lt, rt, rb, lb.
        dst = np.float32([
            (0, 0),
            (w - 1, 0),
            (w - 1, h - 1),
            (0, h - 1),
        ])
        out_image, _ = unwarp(image, src, dst)
        return out_image

    # Any other point count (fewer than four, or extra clicks beyond four)
    # falls back to automatic rectification. Handling every count here means
    # there is always an image to return.
    out_image, _ = docscanner_rec(image, docscanner, cuda)
    # Reverse the channel axis of the model output (presumably BGR -> RGB —
    # TODO confirm against docscanner_rec). Making the view contiguous avoids
    # handing a negative-stride array to the output component.
    return np.ascontiguousarray(out_image[:, :, ::-1])
css = """
.image-container {
padding: 20px;
background-color: #f0f0f0;
}
"""
# Gradio Blocks ์ปจํ…์ŠคํŠธ์—์„œ ์ธํ„ฐํŽ˜์ด์Šค ๊ตฌ์„ฑ
with gr.Blocks(css=css) as demo:
state = gr.State([])
with gr.Row():
with gr.Column():
text = gr.Textbox("์ž…๋ ฅ ์ด๋ฏธ์ง€(์ฝ”๋„ˆ๋ฅผ ํด๋ฆญํ•˜์„ธ์š”)", show_label=False)
image_input = gr.Image(show_label=False, interactive=True, elem_classes="image-container")
clear_button = gr.Button("Clear Points")
cutting_button = gr.Button("Cutting Image(need more than 2 points)")
rotating_button = gr.Button("Rotate Image(clock wise 90 degree)")
auto_button = gr.Button("Auto Points detection")
convert_button = gr.Button("Convert Image")
with gr.Column():
text = gr.Textbox("๋ณ€ํ™˜๋  ์˜์—ญ", show_label=False)
image_output = gr.Image(show_label=False)
# state_display = gr.Textbox(label="Current State")
# coordinates_text = gr.Textbox(label="Coordinates", placeholder="Enter coordinates (x, y) for each point")
# update_coords_button = gr.Button("Update Coordinates")
with gr.Column():
text = gr.Textbox("๊ฒฐ๊ณผ ์ด๋ฏธ์ง€", show_label=False)
result_image = gr.Image(show_label=False, format="png")
# # ์ด๋ฏธ์ง€ ์œ„์—์„œ ํด๋ฆญ ์ด๋ฒคํŠธ ์ฒ˜๋ฆฌ
image_input.select(draw_polygon_on_image, inputs=[image_input,state], outputs=[image_output,state])
# ์ขŒํ‘œ ์ดˆ๊ธฐํ™” ๋ฒ„ํŠผ ํด๋ฆญ ์‹œ ์ขŒํ‘œ ๋ฆฌ์…‹
clear_button.click(fn=reset_points, inputs=[image_input,state], outputs=[image_output,state])
# ์ด๋ฏธ์ง€ ์ž๋ฅด๊ธฐ ํŽธ์ง‘
cutting_button.click(fn=cutting_image, inputs=[image_input,state], outputs=[image_input, image_output, state])
# ์ด๋ฏธ์ง€ ํšŒ์ „
rotating_button.click(fn=rotate_image, inputs=[image_input], outputs=[image_input, state])
# ์ž๋™ ๊ฒ€์ถœ ๋ฒ„ํŠผ
auto_button.click(fn=auto_point_detect, inputs=image_input, outputs=[image_output,state])
# ๋ณ€ํ™˜ ๋ฒ„ํŠผ
convert_button.click(fn=convert, inputs=[image_input,state], outputs=result_image)
demo.launch(share=True)