import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, login

from leffa.transform import LeffaTransform
from leffa.model import LeffaModel
from leffa.inference import LeffaInference
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.densepose_predictor import DensePosePredictor
from utils.utils import resize_and_center

import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import pipeline
import gradio as gr
import os
import random
import gc

# Constants
MAX_SEED = 2**32 - 1
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"


# Decorator that clears GPU memory before and after each model call
def safe_model_call(func):
    def wrapper(*args, **kwargs):
        try:
            clear_memory()
            result = func(*args, **kwargs)
            clear_memory()
            return result
        except Exception as e:
            clear_memory()
            print(f"Error in {func.__name__}: {str(e)}")
            raise
    return wrapper


# Memory management helper
def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def setup_environment():
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
    HF_TOKEN = os.getenv("HF_TOKEN")
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN not found in environment variables")
    login(token=HF_TOKEN)
    return HF_TOKEN


# Global state (models are loaded lazily on first use)
fashion_pipe = None
translator = None
mask_predictor = None
densepose_predictor = None
vt_model = None
pt_model = None
vt_inference = None
pt_inference = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# Run environment setup at import time
HF_TOKEN = setup_environment()


def setup():
    # Download the Leffa checkpoints
    snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")


@spaces.GPU()
def get_translator():
    global translator
    if translator is None:
        translator = pipeline(
            "translation", model="Helsinki-NLP/opus-mt-ko-en", device="cuda"
        )
    return translator


@spaces.GPU()
def get_mask_predictor():
    global mask_predictor
    if mask_predictor is None:
        mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
    return mask_predictor


@safe_model_call
def get_densepose_predictor():
    global densepose_predictor
    if densepose_predictor is None:
        densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
    return densepose_predictor
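
# Usage sketch for the lazy getters above (illustrative only; assumes setup()
# has already downloaded the Leffa checkpoints into ./ckpts):
#
#   masker = get_mask_predictor()
#   assert masker is get_mask_predictor()  # the cached instance is reused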

@safe_model_call
def get_vt_model():
    global vt_model, vt_inference
    if vt_model is None:
        vt_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
            pretrained_model="./ckpts/virtual_tryon.pth",
        )
        vt_model = vt_model.half().to(device)
        vt_inference = LeffaInference(model=vt_model)
    return vt_model, vt_inference


@safe_model_call
def get_pt_model():
    global pt_model, pt_inference
    if pt_model is None:
        pt_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
            pretrained_model="./ckpts/pose_transfer.pth",
        )
        pt_model = pt_model.half().to(device)
        pt_inference = LeffaInference(model=pt_model)
    return pt_model, pt_inference


def load_lora(pipe, lora_path):
    try:
        pipe.unload_lora_weights()
    except Exception:
        pass
    try:
        pipe.load_lora_weights(lora_path)
    except Exception as e:
        print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
    return pipe


# Utility functions
def contains_korean(text):
    # True if the text contains any Hangul syllable (U+AC00 to U+D7A3)
    return any('가' <= char <= '힣' for char in text)


# Model initialization
@spaces.GPU()
def initialize_fashion_pipe():
    try:
        pipe = DiffusionPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False,
        ).to("cuda")
        pipe.enable_model_cpu_offload()
        return pipe
    except Exception as e:
        print(f"Error initializing fashion pipe: {e}")
        raise


# Fashion generation
@spaces.GPU()
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed,
                     width, height, lora_scale,
                     progress=gr.Progress(track_tqdm=True)):
    try:
        # Translate Korean prompts to English
        if contains_korean(prompt):
            with torch.inference_mode():
                translator = get_translator()
                actual_prompt = translator(prompt)[0]['translation_text']
        else:
            actual_prompt = prompt

        # Initialize the pipeline
        pipe = initialize_fashion_pipe()

        # Select LoRA weights for the chosen mode
        if mode == "Generate Model":
            pipe.load_lora_weights(MODEL_LORA_REPO)
            trigger_word = "fashion photography, professional model"
        else:
            pipe.load_lora_weights(CLOTHES_LORA_REPO)
            trigger_word = "upper clothing, fashion item"

        # Clamp parameters to limits the Space can handle
        width = min(width, 768)
        height = min(height, 768)
        steps = min(steps, 30)

        # Seed handling
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        generator = torch.Generator("cuda").manual_seed(seed)

        # Generate the image
        with torch.inference_mode():
            output = pipe(
                prompt=f"{actual_prompt} {trigger_word}",
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                cross_attention_kwargs={"scale": lora_scale},
            )
            image = output.images[0]

        # Free memory
        del pipe
        clear_memory()

        return image, seed

    except Exception as e:
        print(f"Error in generate_fashion: {str(e)}")
        raise gr.Error(f"Generation failed: {str(e)}")
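
# Example of calling generate_fashion programmatically (a hypothetical sketch
# for local testing outside the Gradio UI; all parameter values are illustrative):
#
#   image, used_seed = generate_fashion(
#       "minimalist white shirt", "Generate Clothes",
#       cfg_scale=7.0, steps=30, randomize_seed=True, seed=0,
#       width=512, height=768, lora_scale=0.85,
#   )
#   image.save("preview.png")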

@safe_model_call
def leffa_predict(src_image_path, ref_image_path, control_type):
    try:
        # Initialize the models for the requested task
        if control_type == "virtual_tryon":
            model, inference = get_vt_model()
        else:
            model, inference = get_pt_model()
        mask_pred = get_mask_predictor()
        dense_pred = get_densepose_predictor()

        # Load and preprocess images
        src_image = Image.open(src_image_path)
        ref_image = Image.open(ref_image_path)
        src_image = resize_and_center(src_image, 768, 1024)
        ref_image = resize_and_center(ref_image, 768, 1024)
        src_image_array = np.array(src_image)
        ref_image_array = np.array(ref_image)

        # Build the inpainting mask
        if control_type == "virtual_tryon":
            src_image = src_image.convert("RGB")
            mask = mask_pred(src_image, "upper")["mask"]
        else:
            mask = Image.fromarray(np.ones_like(src_image_array) * 255)

        # DensePose prediction
        src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
        src_image_seg_array = dense_pred.predict_seg(src_image_array)
        if control_type == "virtual_tryon":
            densepose = Image.fromarray(src_image_seg_array)
        else:
            densepose = Image.fromarray(src_image_iuv_array)

        # Leffa transform and inference
        transform = LeffaTransform()
        data = {
            "src_image": [src_image],
            "ref_image": [ref_image],
            "mask": [mask],
            "densepose": [densepose],
        }
        data = transform(data)
        output = inference(data)
        return np.array(output["generated_image"][0])

    except Exception as e:
        print(f"Error in leffa_predict: {str(e)}")
        raise


@safe_model_call
def leffa_predict_vt(src_image_path, ref_image_path):
    return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")


@safe_model_call
def leffa_predict_pt(src_image_path, ref_image_path):
    return leffa_predict(src_image_path, ref_image_path, "pose_transfer")


# Initial setup: download the Leffa checkpoints
setup()
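
# Example of invoking the predictors directly (hypothetical file paths; the
# functions take paths because the Gradio image inputs use type="filepath"):
#
#   result_array = leffa_predict_vt("person.jpg", "garment.jpg")
#   Image.fromarray(result_array).save("tryon.png")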

def create_interface():
    with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
        gr.Markdown("# 🎭 FitGen: Fashion Studio & Virtual Try-on")

        with gr.Tabs():
            # Fashion generation tab
            with gr.Tab("Fashion Generation"):
                with gr.Column():
                    mode = gr.Radio(
                        choices=["Generate Model", "Generate Clothes"],
                        label="Generation Mode",
                        value="Generate Model",
                    )

                    # Example prompts
                    example_model_prompts = [
                        "professional fashion model, full body shot, standing pose, natural lighting, studio background, high fashion, elegant pose",
                        "fashion model portrait, upper body, confident pose, fashion photography, neutral background, professional lighting",
                        "stylish fashion model, three-quarter view, editorial pose, high-end fashion magazine style, minimal background",
                    ]
                    example_clothes_prompts = [
                        "luxury designer sweater, cashmere material, cream color, cable knit pattern, high-end fashion, product photography",
                        "elegant business blazer, tailored fit, charcoal grey, premium wool fabric, professional wear",
                        "modern streetwear hoodie, oversized fit, minimalist design, premium cotton, urban style",
                    ]

                    prompt = gr.TextArea(
                        label="Fashion Description (Korean or English)",
                        placeholder="Describe a fashion model or clothing item...",
                    )

                    # Examples section
                    gr.Examples(
                        examples=example_model_prompts + example_clothes_prompts,
                        inputs=prompt,
                        label="Example Prompts",
                    )

                    with gr.Row():
                        with gr.Column():
                            result = gr.Image(label="Generated Result")
                            generate_button = gr.Button("Generate Fashion")

                    with gr.Accordion("Advanced Options", open=False):
                        with gr.Group():
                            with gr.Row():
                                with gr.Column():
                                    cfg_scale = gr.Slider(
                                        label="CFG Scale",
                                        minimum=1, maximum=20, step=0.5, value=7.0,
                                    )
                                    steps = gr.Slider(
                                        label="Steps",
                                        minimum=1, maximum=30, step=1, value=30,
                                    )
                                    lora_scale = gr.Slider(
                                        label="LoRA Scale",
                                        minimum=0, maximum=1, step=0.01, value=0.85,
                                    )
                            with gr.Row():
                                width = gr.Slider(
                                    label="Width",
                                    minimum=256, maximum=768, step=64, value=512,
                                )
                                height = gr.Slider(
                                    label="Height",
                                    minimum=256, maximum=768, step=64, value=768,
                                )
                            with gr.Row():
                                randomize_seed = gr.Checkbox(
                                    True, label="Randomize seed"
                                )
                                seed = gr.Slider(
                                    label="Seed",
                                    minimum=0, maximum=MAX_SEED, step=1, value=42,
                                )

            # Virtual try-on tab
            with gr.Tab("Virtual Try-on"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Person Image")
                        vt_src_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Person Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=vt_src_image,
                            examples_per_page=5,
                            examples=["a1.webp", "a2.webp", "a3.webp", "a4.webp", "a5.webp"],
                        )
                    with gr.Column():
                        gr.Markdown("#### Garment Image")
                        vt_ref_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Garment Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=vt_ref_image,
                            examples_per_page=5,
                            examples=["b1.webp", "b2.webp", "b3.webp", "b4.webp", "b5.webp"],
                        )
                    with gr.Column():
                        gr.Markdown("#### Generated Image")
                        vt_gen_image = gr.Image(
                            label="Generated Image",
                            width=512,
                            height=512,
                        )
                        vt_gen_button = gr.Button("Try-on")

            # Pose transfer tab
            with gr.Tab("Pose Transfer"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Person Image")
                        pt_ref_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Person Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=pt_ref_image,
                            examples_per_page=5,
                            examples=["a1.webp", "a2.webp", "a3.webp", "a4.webp", "a5.webp"],
                        )
                    with gr.Column():
                        gr.Markdown("#### Target Pose Person Image")
                        pt_src_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Target Pose Person Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=pt_src_image,
                            examples_per_page=5,
                            examples=["d1.webp", "d2.webp", "d3.webp", "d4.webp", "d5.webp"],
                        )
                    with gr.Column():
                        gr.Markdown("#### Generated Image")
                        pt_gen_image = gr.Image(
                            label="Generated Image",
                            width=512,
                            height=512,
                        )
                        pose_transfer_gen_button = gr.Button("Generate")

        # Event handlers (randomize_seed is wired through so the checkbox works)
        generate_button.click(
            fn=generate_fashion,
            inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed,
                    width, height, lora_scale],
            outputs=[result, seed],
        )

        vt_gen_button.click(
            fn=leffa_predict_vt,
            inputs=[vt_src_image, vt_ref_image],
            outputs=[vt_gen_image],
        )

        pose_transfer_gen_button.click(
            fn=leffa_predict_pt,
            inputs=[pt_src_image, pt_ref_image],
            outputs=[pt_gen_image],
        )

    return demo


if __name__ == "__main__":
    # Create and launch the interface (environment setup already ran at import)
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
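
# Note (optional, untested sketch): on a shared GPU the demo can serialize
# concurrent requests by enabling Gradio's request queue before launching, e.g.
#
#   demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860)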