# Spaces: Running
import gradio as gr | |
import numpy as np | |
import math | |
import torch | |
from PIL import Image, ImageDraw | |
from rect_main import docscanner_rec, load_docscanner_model | |
from data_utils.image_utils import unwarp, mask2point, get_corner, _rotate_90_degrees | |
from config import Config | |
# Shared configuration and rectification model, created once when the module is imported.
config = Config()
# NOTE: despite its name, `cuda` holds whichever device is available; it falls back to the CPU.
cuda = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
docscanner = load_docscanner_model(
    cuda,
    path_l=config.get_rec_model_path,  # rec model path from Config
    path_m=config.get_seg_model_path,  # seg model path from Config
)
# Reset the collected click coordinates.
def reset_points(image, state):
    """Discard all collected points and show the untouched input image again.

    Args:
        image: the current input image; passed through unchanged.
        state: the current point list; ignored and replaced by a new empty list.

    Returns:
        ``(image, [])``.
    """
    cleared = []
    return image, cleared
def cutting_image(image, state):
    """Crop ``image`` to the bounding box of the collected points.

    Args:
        image: image array indexed as ``image[y, x]`` (``None`` when nothing is loaded).
        state: list of ``(x, y)`` points in pixel coordinates.

    Returns:
        ``(cropped, cropped, [])``. The crop goes to both the input and the preview
        components, and the point list is cleared. The image is returned uncropped
        in these cases, instead of raising on an empty ``min()`` or yielding an empty
        array:
            * there is no image;
            * there are fewer than two points;
            * the points do not span an area inside the image.
    """
    if image is None or state is None or len(state) < 2:
        return image, image, []

    height, width = image.shape[:2]
    # Coordinates may be floats / numpy scalars (e.g. auto-detected corners).
    xs = [int(round(float(p[0]))) for p in state]
    ys = [int(round(float(p[1]))) for p in state]

    # Clamp to the image so negative coordinates cannot wrap around in the slice,
    # and use +1 so the pixel under the right-/bottom-most point is kept.
    min_x, max_x = max(min(xs), 0), min(max(xs) + 1, width)
    min_y, max_y = max(min(ys), 0), min(max(ys) + 1, height)
    if max_x - min_x <= 1 or max_y - min_y <= 1:
        # Points on one line (or outside the image): nothing sensible to crop.
        return image, image, []

    cutted_image = image[min_y:max_y, min_x:max_x]
    return cutted_image, cutted_image, []
def rotate_image(image):
    """Rotate the input image by 90 degrees and clear any collected points.

    The rotation is delegated to ``_rotate_90_degrees``. The point list is reset
    because its coordinates refer to the unrotated image.

    Returns:
        ``(rotated_image, [])``.
    """
    return _rotate_90_degrees(image), []
def reset_image(image, state):
    """Re-run automatic corner detection, discarding any manually clicked points.

    Args:
        image: input image array.
        state: current point list; ignored and replaced by the detected corners.

    Returns:
        ``(annotated image, detected corners)``, exactly as ``auto_point_detect`` returns.
    """
    # Delegate instead of duplicating the detection/drawing code. This also gives the
    # model call the same ``cuda`` device argument as the other docscanner_rec call
    # sites, and the empty-detection guard that avoids dividing by zero.
    return auto_point_detect(image)
def auto_point_detect(image):
    """Detect document corners with the DocScanner model and draw them on the image.

    Args:
        image: input image array from the input component (``None`` when nothing is loaded).

    Returns:
        ``(annotated, state)``.
            * ``annotated`` is a PIL image with every corner drawn as a red dot.
              Two corners are joined by a line; three or more are joined by a
              polygon, ordered by angle around their centroid.
            * ``state`` is the list of detected corners, which the UI stores for
              ``convert`` and ``cutting_image``.
    """
    if image is None:
        # Nothing loaded yet: nothing to detect.
        return None, []

    _, msk_np = docscanner_rec(image, docscanner, cuda)
    # NOTE(review): assumes get_corner yields (x, y) points in pixel coordinates; confirm.
    state = list(get_corner(mask2point(mask=msk_np)))

    img = Image.fromarray(image)
    # Marker size scales with the image's diagonal-ish size (sqrt of the area).
    radius = max(5, round((image.shape[0] * image.shape[1]) ** 0.5 / 120))
    line_width = round(radius / 2)
    draw = ImageDraw.Draw(img)
    for pt in state:
        draw.ellipse(
            [(pt[0] - radius, pt[1] - radius), (pt[0] + radius, pt[1] + radius)],
            outline="black",
            fill="red",
        )

    if len(state) == 2:
        draw.line([state[0], state[1]], fill="red", width=line_width)
    elif len(state) >= 3:
        # A polygon needs at least 3 vertices. This guard also avoids a division by
        # zero when detection returns no points.
        center = (
            sum(p[0] for p in state) / len(state),
            sum(p[1] for p in state) / len(state),
        )
        # Sort by angle around the centroid so the outline does not self-intersect.
        sorted_points = sorted(state, key=lambda p: calculate_angle(p, center))
        draw.polygon(sorted_points, outline="red", fill=None, width=line_width)
    return img, state
def calculate_angle(point, center):
    """Return the angle of ``point`` around ``center`` in radians, within ``[-pi, pi]``.

    Computed as ``atan2(dy, dx)``. In image coordinates (y grows downward),
    ascending angles visit points clockwise on screen, which is how the callers
    order polygon vertices.
    """
    dx = point[0] - center[0]
    dy = point[1] - center[1]
    return math.atan2(dy, dx)
# Click handler: record a coordinate and draw the resulting outline.
def draw_polygon_on_image(image, evt: gr.SelectData, state):
    """Append the clicked pixel to ``state`` and render all points on the image.

    Args:
        image: input image array.
        evt: Gradio selection event; ``evt.index`` gives the clicked ``(x, y)``.
        state: list of previously clicked points; mutated in place.

    Returns:
        ``(PIL image with markers, state)``. Each point becomes a red dot. Two points
        are joined by a line; three or more by a polygon whose vertices are ordered
        by angle around their centroid.
    """
    canvas = Image.fromarray(image)
    state.append((evt.index[0], evt.index[1]))  # remember the clicked coordinate

    rows, cols = image.shape[0], image.shape[1]
    r = max(5, round((rows * cols) ** 0.5 / 120))
    stroke = round(r / 2)
    pen = ImageDraw.Draw(canvas)

    for point in state:
        x, y = point[0], point[1]
        pen.ellipse([(x - r, y - r), (x + r, y + r)], outline="black", fill="red")

    count = len(state)
    if count == 2:
        pen.line([state[0], state[1]], fill="red", width=stroke)
    elif count >= 3:
        # A polygon is only drawn once there are at least 3 points.
        cx = sum(p[0] for p in state) / count
        cy = sum(p[1] for p in state) / count
        # Angle ordering around the centroid keeps the outline from crossing itself.
        vertices = sorted(state, key=lambda p: math.atan2(p[1] - cy, p[0] - cx))
        pen.polygon(vertices, outline="red", fill=None, width=stroke)
    return canvas, state
def sort_corners(corners):
    """Order four corner points as (top-left, top-right, bottom-right, bottom-left).

    Args:
        corners: sequence of exactly four ``(x, y)`` points in image coordinates
            (y grows downward), in any order.

    Returns:
        Tuple ``(lt, rt, rb, lb)`` made of the original point objects.

    Raises:
        ValueError: if ``corners`` does not contain exactly four points.

    Method:
        1. Sort the points by angle around their centroid. This gives a consistent
           on-screen clockwise cycle even when the document is tilted.
        2. Rotate the cycle so it starts at the point with the smallest ``x + y``,
           i.e. the top-left-most corner.

    Splitting by y alone mislabels corners as soon as a bottom corner lies higher
    than a top corner, which happens for moderately tilted documents.
    """
    if len(corners) != 4:
        raise ValueError("Input should contain exactly four coordinates.")

    cx = sum(p[0] for p in corners) / 4
    cy = sum(p[1] for p in corners) / 4
    # With y pointing down, ascending atan2 angles walk the corners clockwise on screen.
    ring = sorted(corners, key=lambda p: math.atan2(p[1] - cy, p[0] - cx))
    start = min(range(4), key=lambda i: ring[i][0] + ring[i][1])
    lt, rt, rb, lb = ring[start:] + ring[:start]
    return lt, rt, rb, lb
def convert(image, state):
    """Rectify the document in ``image`` and return the result image.

    Args:
        image: input image array (``None`` when nothing is loaded).
        state: list of user-selected ``(x, y)`` corner points.

    Returns:
        The rectified image, or ``None`` when there is no input image.
        * With exactly four points, the corners are ordered by ``sort_corners`` and
          the quadrilateral is perspective-warped onto the full image frame via
          ``unwarp``.
        * With any other point count (none, too few, or more than four), the
          DocScanner model rectifies the image automatically. Before this change,
          more than four points left the result unbound.
    """
    if image is None:
        return None

    if not state or len(state) != 4:
        out_image, _ = docscanner_rec(image, docscanner, cuda)
        # Reverse the channel axis of the model output.
        # NOTE(review): presumably BGR -> RGB; confirm against docscanner_rec.
        return out_image[:, :, ::-1]

    h, w = image.shape[:2]
    src = np.array(sort_corners(state), dtype=np.float32)
    dst = np.float32([
        (0, 0),
        (w - 1, 0),
        (w - 1, h - 1),
        (0, h - 1),
    ])
    out_image, _ = unwarp(image, src, dst)
    return out_image
# Extra CSS passed to gr.Blocks(css=...); it styles the components created with
# elem_classes="image-container" (the input image).
css = """
.image-container {
padding: 20px;
background-color: #f0f0f0;
}
"""
# Build the interface inside a Gradio Blocks context.
with gr.Blocks(css=css) as demo:
    # Per-session list of (x, y) points shared by all handlers below.
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            # Read-only textboxes act as column captions (users should not edit them).
            input_caption = gr.Textbox(
                "입력 이미지(코너를 클릭하세요)", show_label=False, interactive=False
            )
            image_input = gr.Image(show_label=False, interactive=True, elem_classes="image-container")
            clear_button = gr.Button("Clear Points")
            cutting_button = gr.Button("Cutting Image(need more than 2 points)")
            rotating_button = gr.Button("Rotate Image(clock wise 90 degree)")
            auto_button = gr.Button("Auto Points detection")
            convert_button = gr.Button("Convert Image")
        with gr.Column():
            preview_caption = gr.Textbox("변환된 영역", show_label=False, interactive=False)
            image_output = gr.Image(show_label=False)
        with gr.Column():
            result_caption = gr.Textbox("결과 이미지", show_label=False, interactive=False)
            result_image = gr.Image(show_label=False, format="png")

    # Clicking on the input image adds a point and redraws the outline in the preview.
    image_input.select(draw_polygon_on_image, inputs=[image_input, state], outputs=[image_output, state])
    # Clear button: reset the collected points.
    clear_button.click(fn=reset_points, inputs=[image_input, state], outputs=[image_output, state])
    # Crop the input image to the points' bounding box.
    cutting_button.click(fn=cutting_image, inputs=[image_input, state], outputs=[image_input, image_output, state])
    # Rotate the input image.
    rotating_button.click(fn=rotate_image, inputs=[image_input], outputs=[image_input, state])
    # Automatic corner detection.
    auto_button.click(fn=auto_point_detect, inputs=image_input, outputs=[image_output, state])
    # Rectify the document into the result panel.
    convert_button.click(fn=convert, inputs=[image_input, state], outputs=result_image)

demo.launch(share=True)