"""FitGen: FLUX-based fashion image generation plus Leffa virtual try-on / pose transfer (Gradio app)."""

import functools
import gc
import os
import random

import spaces  # kept ahead of the torch-based imports below
import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import login, snapshot_download
from PIL import Image
from transformers import pipeline

from leffa.inference import LeffaInference
from leffa.model import LeffaModel
from leffa.transform import LeffaTransform
from utils.densepose_predictor import DensePosePredictor
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.utils import resize_and_center

# Constants
MAX_SEED = 2**32 - 1
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"

# Upper bounds enforced by generate_fashion; the UI sliders use the same limits
# so users cannot pick values that would be silently clamped.
MAX_IMAGE_SIZE = 768
MAX_STEPS = 30


def clear_memory():
    """Release Python garbage and cached CUDA memory.

    Garbage is collected first so tensors freed by the collector are handed
    back to the driver by the following ``empty_cache`` call. CUDA is only
    touched when it is already initialized, so this never creates a CUDA
    context as a side effect.
    """
    gc.collect()
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()


def safe_model_call(func):
    """Decorator: clear memory around ``func`` and log (then re-raise) any exception.

    ``functools.wraps`` keeps the wrapped function's name and signature visible;
    Gradio inspects the signature to inject ``gr.Progress`` arguments.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        clear_memory()
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {str(e)}")
            raise
        finally:
            clear_memory()

    return wrapper


def setup_environment():
    """Configure torch backends, log in to the Hugging Face Hub and choose the device.

    Sets the module globals ``HF_TOKEN`` and ``device``.

    Raises:
        ValueError: if the ``HF_TOKEN`` environment variable is not set.
    """
    global HF_TOKEN, device

    # Allocator options must be in the environment before the CUDA caching
    # allocator is first used, so set them before any torch.cuda call.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
    torch.cuda.empty_cache()
    gc.collect()

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Hugging Face token (needed for the gated FLUX base model)
    HF_TOKEN = os.getenv("HF_TOKEN")
    if HF_TOKEN is None:
        raise ValueError("Please set the HF_TOKEN environment variable")
    login(token=HF_TOKEN)

    device = "cuda" if torch.cuda.is_available() else "cpu"


# Lazily-initialized model singletons (filled in by the getter functions below)
fashion_pipe = None
translator = None
mask_predictor = None
densepose_predictor = None
vt_model = None
pt_model = None
vt_inference = None
pt_inference = None
device = None
HF_TOKEN = None

setup_environment()


def initialize_fashion_pipe():
    """Return the shared FLUX pipeline, loading it on first use."""
    global fashion_pipe
    if fashion_pipe is None:
        clear_memory()
        fashion_pipe = DiffusionPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            token=HF_TOKEN,
        )
        try:
            fashion_pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(f"Warning: Could not enable memory efficient attention: {e}")
        fashion_pipe.enable_sequential_cpu_offload()
    return fashion_pipe


@safe_model_call
def get_translator():
    """Return the shared Korean-to-English translation pipeline, loading it on first use."""
    global translator
    if translator is None:
        translator = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ko-en",
            device=device if device == "cuda" else -1,
        )
    return translator


@safe_model_call
def get_mask_predictor():
    """Return the shared garment-agnostic mask predictor, loading it on first use."""
    global mask_predictor
    if mask_predictor is None:
        mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
    return mask_predictor


@safe_model_call
def get_densepose_predictor():
    """Return the shared DensePose predictor, loading it on first use."""
    global densepose_predictor
    if densepose_predictor is None:
        densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
    return densepose_predictor


@safe_model_call
def get_vt_model():
    """Return ``(model, inference)`` for virtual try-on, loading them (fp16, on ``device``) on first use."""
    global vt_model, vt_inference
    if vt_model is None:
        vt_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
            pretrained_model="./ckpts/virtual_tryon.pth",
        )
        vt_model = vt_model.half().to(device)
        vt_inference = LeffaInference(model=vt_model)
    return vt_model, vt_inference


@safe_model_call
def get_pt_model():
    """Return ``(model, inference)`` for pose transfer, loading them (fp16, on ``device``) on first use."""
    global pt_model, pt_inference
    if pt_model is None:
        pt_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
            pretrained_model="./ckpts/pose_transfer.pth",
        )
        pt_model = pt_model.half().to(device)
        pt_inference = LeffaInference(model=pt_model)
    return pt_model, pt_inference


def load_lora(pipe, lora_path):
    """Replace any LoRA currently on ``pipe`` with the one at ``lora_path``.

    Best effort: if loading fails, a warning is printed and the pipeline is
    returned without the new LoRA.
    """
    try:
        pipe.unload_lora_weights()
    except Exception:
        # Best effort: there may be no LoRA loaded yet.
        pass

    try:
        pipe.load_lora_weights(lora_path)
    except Exception as e:
        print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
    return pipe


def setup():
    """Download the Leffa checkpoints into ./ckpts and preload the FLUX pipeline."""
    snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
    initialize_fashion_pipe()


def contains_korean(text):
    """Return True if ``text`` contains a precomposed Hangul syllable (U+AC00 '가' .. U+D7A3 '힣')."""
    return any('가' <= char <= '힣' for char in text)


@spaces.GPU()
@safe_model_call
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale,
                     progress=gr.Progress(track_tqdm=True)):
    """Generate a fashion-model or clothing image with FLUX plus a mode-specific LoRA.

    Korean prompts are translated to English first. Width/height are capped at
    ``MAX_IMAGE_SIZE`` and steps at ``MAX_STEPS``.

    Returns:
        ``(image, seed)`` -- the generated PIL image and the seed actually used.
    """
    if contains_korean(prompt):
        actual_prompt = get_translator()(prompt)[0]['translation_text']
    else:
        actual_prompt = prompt

    pipe = initialize_fashion_pipe()

    if mode == "Generate Model":
        pipe = load_lora(pipe, MODEL_LORA_REPO)
        trigger_word = "fashion photography, professional model"
    else:
        pipe = load_lora(pipe, CLOTHES_LORA_REPO)
        trigger_word = "upper clothing, fashion item"

    # Slider values may arrive as floats; the pipeline and the RNG need ints.
    width = min(int(width), MAX_IMAGE_SIZE)
    height = min(int(height), MAX_IMAGE_SIZE)
    steps = min(int(steps), MAX_STEPS)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    # Use the configured device rather than hard-coding "cuda" (crashes on CPU).
    generator = torch.Generator(device=device).manual_seed(seed)

    progress(0, desc="Starting fashion generation...")

    image = pipe(
        prompt=f"{actual_prompt} {trigger_word}",
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]

    return image, seed


@safe_model_call
def leffa_predict(src_image_path, ref_image_path, control_type):
    """Run Leffa on two image files and return the generated image as a numpy array.

    Args:
        src_image_path: path of the person image to edit.
        ref_image_path: path of the reference (garment or pose-source person) image.
        control_type: ``"virtual_tryon"``; any other value runs pose transfer.

    Raises:
        gr.Error: if either image path is missing.
    """
    if not src_image_path or not ref_image_path:
        raise gr.Error("Please upload both images before generating.")

    is_tryon = control_type == "virtual_tryon"
    _, inference = get_vt_model() if is_tryon else get_pt_model()
    dense_pred = get_densepose_predictor()

    # Load and normalize both images to 768x1024
    src_image = resize_and_center(Image.open(src_image_path), 768, 1024)
    ref_image = resize_and_center(Image.open(ref_image_path), 768, 1024)
    src_image_array = np.array(src_image)

    if is_tryon:
        # Try-on: mask out the upper-body garment region; conditioned on DensePose segmentation.
        src_image = src_image.convert("RGB")
        mask = get_mask_predictor()(src_image, "upper")["mask"]
        densepose = Image.fromarray(dense_pred.predict_seg(src_image_array))
    else:
        # Pose transfer: regenerate the whole image; conditioned on DensePose IUV.
        mask = Image.fromarray(np.ones_like(src_image_array) * 255)
        densepose = Image.fromarray(dense_pred.predict_iuv(src_image_array))

    transform = LeffaTransform()
    data = transform({
        "src_image": [src_image],
        "ref_image": [ref_image],
        "mask": [mask],
        "densepose": [densepose],
    })
    output = inference(data)
    return np.array(output["generated_image"][0])


@spaces.GPU()
@safe_model_call
def leffa_predict_vt(src_image_path, ref_image_path):
    """Gradio handler: virtual try-on of the garment image onto the person image."""
    return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")


@spaces.GPU()
@safe_model_call
def leffa_predict_pt(src_image_path, ref_image_path):
    """Gradio handler: pose transfer between the two person images."""
    return leffa_predict(src_image_path, ref_image_path, "pose_transfer")


# Download checkpoints and preload models before building the UI
setup()

PERSON_EXAMPLES = [
    "./ckpts/examples/person1/01350_00.jpg",
    "./ckpts/examples/person1/01376_00.jpg",
    "./ckpts/examples/person1/01416_00.jpg",
    "./ckpts/examples/person1/05976_00.jpg",
    "./ckpts/examples/person1/06094_00.jpg",
]

# Gradio interface
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
    gr.Markdown("# 🎭 FitGen:Fashion Studio & Virtual Try-on")

    with gr.Tabs():
        # Fashion generation tab
        with gr.Tab("Fashion Generation"):
            with gr.Column():
                mode = gr.Radio(
                    choices=["Generate Model", "Generate Clothes"],
                    label="Generation Mode",
                    value="Generate Model",
                )

                # Example prompts
                example_model_prompts = [
                    "professional fashion model, full body shot, standing pose, natural lighting, studio background, high fashion, elegant pose",
                    "fashion model portrait, upper body, confident pose, fashion photography, neutral background, professional lighting",
                    "stylish fashion model, three-quarter view, editorial pose, high-end fashion magazine style, minimal background",
                ]
                example_clothes_prompts = [
                    "luxury designer sweater, cashmere material, cream color, cable knit pattern, high-end fashion, product photography",
                    "elegant business blazer, tailored fit, charcoal grey, premium wool fabric, professional wear",
                    "modern streetwear hoodie, oversized fit, minimalist design, premium cotton, urban style",
                ]

                prompt = gr.TextArea(
                    label="Fashion Description (한글 또는 영어)",
                    placeholder="패션 모델이나 의류를 설명하세요...",
                )

                gr.Examples(
                    examples=example_model_prompts + example_clothes_prompts,
                    inputs=prompt,
                    label="Example Prompts",
                )

                with gr.Row():
                    with gr.Column():
                        result = gr.Image(label="Generated Result")
                        generate_button = gr.Button("Generate Fashion")

                with gr.Accordion("Advanced Options", open=False):
                    with gr.Group():
                        with gr.Row():
                            with gr.Column():
                                cfg_scale = gr.Slider(
                                    label="CFG Scale",
                                    minimum=1,
                                    maximum=20,
                                    step=0.5,
                                    value=7.0,
                                )
                                steps = gr.Slider(
                                    label="Steps",
                                    minimum=1,
                                    maximum=MAX_STEPS,  # matches the cap in generate_fashion
                                    step=1,
                                    value=30,
                                )
                                lora_scale = gr.Slider(
                                    label="LoRA Scale",
                                    minimum=0,
                                    maximum=1,
                                    step=0.01,
                                    value=0.85,
                                )

                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=MAX_IMAGE_SIZE,  # matches the cap in generate_fashion
                                step=64,
                                value=512,
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=MAX_IMAGE_SIZE,  # matches the cap in generate_fashion
                                step=64,
                                value=768,
                            )

                        with gr.Row():
                            randomize_seed = gr.Checkbox(True, label="Randomize seed")
                            seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42,
                            )

        # Virtual try-on tab
        with gr.Tab("Virtual Try-on"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Person Image")
                    vt_src_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=vt_src_image,
                        examples_per_page=5,
                        examples=PERSON_EXAMPLES,
                    )

                with gr.Column():
                    gr.Markdown("#### Garment Image")
                    vt_ref_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Garment Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=vt_ref_image,
                        examples_per_page=5,
                        examples=[
                            "./ckpts/examples/garment/01449_00.jpg",
                            "./ckpts/examples/garment/01486_00.jpg",
                            "./ckpts/examples/garment/01853_00.jpg",
                            "./ckpts/examples/garment/02070_00.jpg",
                            "./ckpts/examples/garment/03553_00.jpg",
                        ],
                    )

                with gr.Column():
                    gr.Markdown("#### Generated Image")
                    vt_gen_image = gr.Image(
                        label="Generated Image",
                        width=512,
                        height=512,
                    )
                    vt_gen_button = gr.Button("Try-on")

        # Pose transfer tab
        with gr.Tab("Pose Transfer"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Person Image")
                    pt_ref_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=pt_ref_image,
                        examples_per_page=5,
                        examples=PERSON_EXAMPLES,
                    )

                with gr.Column():
                    gr.Markdown("#### Target Pose Person Image")
                    pt_src_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Target Pose Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=pt_src_image,
                        examples_per_page=5,
                        examples=[
                            "./ckpts/examples/person2/01850_00.jpg",
                            "./ckpts/examples/person2/01875_00.jpg",
                            "./ckpts/examples/person2/02532_00.jpg",
                            "./ckpts/examples/person2/02902_00.jpg",
                            "./ckpts/examples/person2/05346_00.jpg",
                        ],
                    )

                with gr.Column():
                    gr.Markdown("#### Generated Image")
                    pt_gen_image = gr.Image(
                        label="Generated Image",
                        width=512,
                        height=512,
                    )
                    pose_transfer_gen_button = gr.Button("Generate")

    # Event handlers
    generate_button.click(
        generate_fashion,
        inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed],
    )

    vt_gen_button.click(
        fn=leffa_predict_vt,
        inputs=[vt_src_image, vt_ref_image],
        outputs=[vt_gen_image],
    )

    pose_transfer_gen_button.click(
        fn=leffa_predict_pt,
        inputs=[pt_src_image, pt_ref_image],
        outputs=[pt_gen_image],
    )

# Launch the app
demo.launch(share=True, server_port=7860)