import tempfile import time from typing import Any from collections.abc import Sequence import gradio as gr import numpy as np import pillow_heif import spaces import torch from gradio_image_annotation import image_annotator from gradio_imageslider import ImageSlider from PIL import Image from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml from refiners.fluxion.utils import no_grad from refiners.solutions import BoxSegmenter BoundingBox = tuple[int, int, int, int] pillow_heif.register_heif_opener() pillow_heif.register_avif_opener() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Initialize segmenter segmenter = BoxSegmenter(device="cpu") segmenter.device = device segmenter.model = segmenter.model.to(device=segmenter.device) def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None: if not bboxes: return None for bbox in bboxes: assert len(bbox) == 4 assert all(isinstance(x, int) for x in bbox) return ( min(bbox[0] for bbox in bboxes), min(bbox[1] for bbox in bboxes), max(bbox[2] for bbox in bboxes), max(bbox[3] for bbox in bboxes), ) def apply_mask( img: Image.Image, mask_img: Image.Image, defringe: bool = True, ) -> Image.Image: assert img.size == mask_img.size img = img.convert("RGB") mask_img = mask_img.convert("L") if defringe: # Mitigate edge halo effects via color decontamination rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0 foreground = estimate_foreground_ml(rgb, alpha) img = Image.fromarray((foreground * 255).astype("uint8")) result = Image.new("RGBA", img.size) result.paste(img, (0, 0), mask_img) return result @spaces.GPU def _gpu_process( img: Image.Image, bbox: BoundingBox | None, ) -> tuple[Image.Image, BoundingBox | None, list[str]]: time_log: list[str] = [] t0 = time.time() mask = segmenter(img, bbox) time_log.append(f"segment: {time.time() - t0}") return mask, bbox, time_log def _process( img: Image.Image, bbox: BoundingBox | None, ) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]: if img.width > 2048 or img.height > 2048: orig_res = max(img.width, img.height) img.thumbnail((2048, 2048)) if isinstance(bbox, tuple): x0, y0, x1, y1 = (int(x * 2048 / orig_res) for x in bbox) bbox = (x0, y0, x1, y1) mask, bbox, time_log = _gpu_process(img, bbox) t0 = time.time() masked_alpha = apply_mask(img, mask, defringe=True) time_log.append(f"crop: {time.time() - t0}") print(", ".join(time_log)) masked_rgb = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha) thresholded = mask.point(lambda p: 255 if p > 10 else 0) bbox = thresholded.getbbox() to_dl = masked_alpha.crop(bbox) temp = tempfile.NamedTemporaryFile(delete=False, suffix=".png") to_dl.save(temp, format="PNG") temp.close() return (img, masked_rgb), gr.DownloadButton(value=temp.name, interactive=True) def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]: assert isinstance(img := prompts["image"], Image.Image) assert isinstance(boxes := prompts["boxes"], list) if len(boxes) == 1: assert isinstance(box := boxes[0], dict) bbox = tuple(box[k] for k in ["xmin", "ymin", "xmax", "ymax"]) else: assert len(boxes) == 0 bbox = None return _process(img, bbox) def on_change_bbox(prompts: dict[str, Any] | None): return gr.update(interactive=prompts is not None) css = ''' .gradio-container { max-width: 1400px !important; margin: auto; } /* 이미지 크기 조정 */ .image-container img { max-height: 600px !important; } /* 이미지 슬라이더 크기 조정 */ .image-slider { height: 600px !important; max-height: 600px !important; } h1 { text-align: center; font-family: 'Pretendard', sans-serif; color: #EA580C; font-size: 2.5rem; font-weight: 700; margin-bottom: 1.5rem; text-shadow: 0 2px 4px rgba(0,0,0,0.1); } .subtitle { text-align: center; color: #4B5563; font-size: 1.1rem; margin-bottom: 2rem; font-family: 'Pretendard', sans-serif; } .gr-button-primary { background-color: #F97316 !important; border: none !important; box-shadow: 0 2px 4px rgba(234, 88, 12, 0.2) !important; } .gr-button-primary:hover { background-color: #EA580C !important; transform: translateY(-1px); box-shadow: 0 4px 6px rgba(234, 88, 12, 0.25) !important; } .footer-content { text-align: center; margin-top: 3rem; padding: 2rem; background: linear-gradient(to bottom, #FFF7ED, white); border-radius: 12px; font-family: 'Pretendard', sans-serif; } .footer-content a { color: #EA580C; text-decoration: none; font-weight: 500; transition: all 0.2s; } .footer-content a:hover { color: #C2410C; } .visit-button { background-color: #EA580C; color: white !important; /* 강제 적용 */ padding: 12px 24px; border-radius: 8px; font-weight: 600; text-decoration: none; display: inline-block; transition: all 0.3s; margin-top: 1rem; box-shadow: 0 2px 4px rgba(234, 88, 12, 0.2); font-size: 1.1rem; } .visit-button:hover { background-color: #C2410C; transform: translateY(-2px); box-shadow: 0 4px 6px rgba(234, 88, 12, 0.25); color: white !important; /* 호버 상태에서도 강제 적용 */ } .container-wrapper { background: white; border-radius: 16px; padding: 2rem; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05); } .image-container { border-radius: 12px; overflow: hidden; border: 2px solid #F3F4F6; } ''' with gr.Blocks( theme=gr.themes.Soft( primary_hue=gr.themes.Color( c50="#FFF7ED", c100="#FFEDD5", c200="#FED7AA", c300="#FDBA74", c400="#FB923C", c500="#F97316", c600="#EA580C", c700="#C2410C", c800="#9A3412", c900="#7C2D12", c950="#431407", ), secondary_hue="zinc", neutral_hue="zinc", font=("Pretendard", "sans-serif") ), css=css ) as demo: gr.HTML( """