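"""FitGen: Fashion Studio & Virtual Try-on.

A Hugging Face ZeroGPU Space with three Gradio tabs: fashion model/garment
generation (FLUX.1-dev plus fashion LoRAs, with Korean prompts translated to
English), virtual try-on, and pose transfer (both backed by Leffa checkpoints).
"""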

import functools
import gc
import os
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import snapshot_download, login
from PIL import Image
from transformers import pipeline

from leffa.transform import LeffaTransform
from leffa.model import LeffaModel
from leffa.inference import LeffaInference
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.densepose_predictor import DensePosePredictor
from utils.utils import resize_and_center

# Constants
MAX_SEED = 2**32 - 1
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"

# Decorator that frees GPU memory before and after a model call
def safe_model_call(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            clear_memory()
            result = func(*args, **kwargs)
            clear_memory()
            return result
        except Exception as e:
            clear_memory()
            print(f"Error in {func.__name__}: {str(e)}")
            raise
    return wrapper
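
# Usage: wrap any function that allocates GPU models, e.g.
#
#   @safe_model_call
#   def get_vt_model(): ...
#
# so the CUDA cache is flushed before and after each call.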

# Memory cleanup helper
def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def setup_environment():
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")
    login(token=hf_token)
    return hf_token
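
# The Space must define an HF_TOKEN secret; FLUX.1-dev is a gated repo, so the
# token needs read access to it.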

# Global model handles (lazily initialized by the getters below)
translator = None
mask_predictor = None
densepose_predictor = None
vt_model = None
pt_model = None
vt_inference = None
pt_inference = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# Run environment setup (HF login) before any model downloads
HF_TOKEN = setup_environment()

def setup():
    # Download the Leffa checkpoints (virtual try-on and pose transfer)
    snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")

@spaces.GPU()
def get_translator():
    global translator
    if translator is None:
        translator = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ko-en",
            device="cuda",
        )
    return translator

@safe_model_call
def get_mask_predictor():
    global mask_predictor
    if mask_predictor is None:
        mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
    return mask_predictor

@safe_model_call
def get_densepose_predictor():
    global densepose_predictor
    if densepose_predictor is None:
        densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
    return densepose_predictor

@safe_model_call
def get_vt_model():
    global vt_model, vt_inference
    if vt_model is None:
        vt_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
            pretrained_model="./ckpts/virtual_tryon.pth",
        )
        vt_model = vt_model.half().to(device)
        vt_inference = LeffaInference(model=vt_model)
    return vt_model, vt_inference

@safe_model_call
def get_pt_model():
    global pt_model, pt_inference
    if pt_model is None:
        pt_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
            pretrained_model="./ckpts/pose_transfer.pth",
        )
        pt_model = pt_model.half().to(device)
        pt_inference = LeffaInference(model=pt_model)
    return pt_model, pt_inference

def load_lora(pipe, lora_path):
    # Drop any previously loaded LoRA before loading the new one
    try:
        pipe.unload_lora_weights()
    except Exception:
        pass
    try:
        pipe.load_lora_weights(lora_path)
    except Exception as e:
        print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
    return pipe

# Utility: detect Hangul syllables in a prompt
def contains_korean(text):
    return any('가' <= char <= '힣' for char in text)
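# e.g. contains_korean("드레스를 입은 모델") -> True
#      contains_korean("fashion model")     -> False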

# Build a fresh FLUX pipeline per request (deleted again after generation)
@spaces.GPU()
def initialize_fashion_pipe():
    try:
        pipe = DiffusionPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False,
        )
        # enable_model_cpu_offload() manages device placement itself, so no
        # explicit .to("cuda") is needed beforehand
        pipe.enable_model_cpu_offload()
        return pipe
    except Exception as e:
        print(f"Error initializing fashion pipe: {e}")
        raise

# Fashion generation (models or clothes) via FLUX + LoRA
@spaces.GPU()
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed,
                     width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    try:
        # Translate Korean prompts to English
        if contains_korean(prompt):
            with torch.inference_mode():
                translator = get_translator()
                actual_prompt = translator(prompt)[0]['translation_text']
        else:
            actual_prompt = prompt

        # Initialize the pipeline
        pipe = initialize_fashion_pipe()

        # Select the LoRA and its trigger words
        if mode == "Generate Model":
            pipe = load_lora(pipe, MODEL_LORA_REPO)
            trigger_word = "fashion photography, professional model"
        else:
            pipe = load_lora(pipe, CLOTHES_LORA_REPO)
            trigger_word = "upper clothing, fashion item"

        # Clamp parameters to the UI limits
        width = min(width, 768)
        height = min(height, 768)
        steps = min(steps, 30)

        # Seed handling
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        generator = torch.Generator("cuda").manual_seed(seed)

        # Generate the image
        with torch.inference_mode():
            output = pipe(
                prompt=f"{actual_prompt} {trigger_word}",
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                # FLUX pipelines take the LoRA scale via joint_attention_kwargs;
                # they have no cross_attention_kwargs argument
                joint_attention_kwargs={"scale": lora_scale},
            )
            image = output.images[0]

        # Free memory
        del pipe
        clear_memory()
        return image, seed
    except Exception as e:
        print(f"Error in generate_fashion: {str(e)}")
        raise gr.Error(f"Generation failed: {str(e)}")

# Shared Leffa inference for virtual try-on and pose transfer; runs inside a
# ZeroGPU context since the models are moved to `device`
@spaces.GPU()
@safe_model_call
def leffa_predict(src_image_path, ref_image_path, control_type):
    try:
        # Pick the model for the requested task
        if control_type == "virtual_tryon":
            model, inference = get_vt_model()
        else:
            model, inference = get_pt_model()
        mask_pred = get_mask_predictor()
        dense_pred = get_densepose_predictor()

        # Load and preprocess both images
        src_image = Image.open(src_image_path)
        ref_image = Image.open(ref_image_path)
        src_image = resize_and_center(src_image, 768, 1024)
        ref_image = resize_and_center(ref_image, 768, 1024)
        src_image_array = np.array(src_image)
        ref_image_array = np.array(ref_image)

        # Build the mask (garment-agnostic for try-on, full mask otherwise)
        if control_type == "virtual_tryon":
            src_image = src_image.convert("RGB")
            mask = mask_pred(src_image, "upper")["mask"]
        else:
            mask = Image.fromarray(np.ones_like(src_image_array) * 255)

        # DensePose prediction (seg for try-on, IUV for pose transfer)
        src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
        src_image_seg_array = dense_pred.predict_seg(src_image_array)
        if control_type == "virtual_tryon":
            densepose = Image.fromarray(src_image_seg_array)
        else:
            densepose = Image.fromarray(src_image_iuv_array)

        # Leffa transform and inference
        transform = LeffaTransform()
        data = {
            "src_image": [src_image],
            "ref_image": [ref_image],
            "mask": [mask],
            "densepose": [densepose],
        }
        data = transform(data)
        output = inference(data)
        return np.array(output["generated_image"][0])
    except Exception as e:
        print(f"Error in leffa_predict: {str(e)}")
        raise

@safe_model_call
def leffa_predict_vt(src_image_path, ref_image_path):
    return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")

@safe_model_call
def leffa_predict_pt(src_image_path, ref_image_path):
    return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
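
# These two wrappers are the click targets for the "Try-on" and "Generate"
# buttons in create_interface() below; e.g., assuming the bundled example
# images are present:
#
#   result_array = leffa_predict_vt("a1.webp", "b1.webp")  # H x W x 3 ndarray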

# Download the Leffa checkpoints at import time
setup()

def create_interface():
    with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
        gr.Markdown("# 🎭 FitGen: Fashion Studio & Virtual Try-on")

        with gr.Tabs():
            # Fashion generation tab
            with gr.Tab("Fashion Generation"):
                with gr.Column():
                    mode = gr.Radio(
                        choices=["Generate Model", "Generate Clothes"],
                        label="Generation Mode",
                        value="Generate Model"
                    )

                    # Example prompts
                    example_model_prompts = [
                        "professional fashion model, full body shot, standing pose, natural lighting, studio background, high fashion, elegant pose",
                        "fashion model portrait, upper body, confident pose, fashion photography, neutral background, professional lighting",
                        "stylish fashion model, three-quarter view, editorial pose, high-end fashion magazine style, minimal background"
                    ]
                    example_clothes_prompts = [
                        "luxury designer sweater, cashmere material, cream color, cable knit pattern, high-end fashion, product photography",
                        "elegant business blazer, tailored fit, charcoal grey, premium wool fabric, professional wear",
                        "modern streetwear hoodie, oversized fit, minimalist design, premium cotton, urban style"
                    ]

                    prompt = gr.TextArea(
                        label="Fashion Description (Korean or English)",
                        placeholder="Describe a fashion model or garment..."
                    )

                    # Examples section
                    gr.Examples(
                        examples=example_model_prompts + example_clothes_prompts,
                        inputs=prompt,
                        label="Example Prompts"
                    )

                    with gr.Row():
                        with gr.Column():
                            result = gr.Image(label="Generated Result")
                            generate_button = gr.Button("Generate Fashion")

                    with gr.Accordion("Advanced Options", open=False):
                        with gr.Group():
                            with gr.Row():
                                with gr.Column():
                                    cfg_scale = gr.Slider(
                                        label="CFG Scale",
                                        minimum=1,
                                        maximum=20,
                                        step=0.5,
                                        value=7.0
                                    )
                                    steps = gr.Slider(
                                        label="Steps",
                                        minimum=1,
                                        maximum=30,
                                        step=1,
                                        value=30
                                    )
                                    lora_scale = gr.Slider(
                                        label="LoRA Scale",
                                        minimum=0,
                                        maximum=1,
                                        step=0.01,
                                        value=0.85
                                    )

                            with gr.Row():
                                width = gr.Slider(
                                    label="Width",
                                    minimum=256,
                                    maximum=768,
                                    step=64,
                                    value=512
                                )
                                height = gr.Slider(
                                    label="Height",
                                    minimum=256,
                                    maximum=768,
                                    step=64,
                                    value=768
                                )

                            with gr.Row():
                                randomize_seed = gr.Checkbox(
                                    True,
                                    label="Randomize seed"
                                )
                                seed = gr.Slider(
                                    label="Seed",
                                    minimum=0,
                                    maximum=MAX_SEED,
                                    step=1,
                                    value=42
                                )

            # Virtual try-on tab
            with gr.Tab("Virtual Try-on"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Person Image")
                        vt_src_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Person Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=vt_src_image,
                            examples_per_page=5,
                            examples=["a1.webp", "a2.webp", "a3.webp",
                                      "a4.webp", "a5.webp"]
                        )
                    with gr.Column():
                        gr.Markdown("#### Garment Image")
                        vt_ref_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Garment Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=vt_ref_image,
                            examples_per_page=5,
                            examples=["b1.webp", "b2.webp", "b3.webp",
                                      "b4.webp", "b5.webp"]
                        )
                    with gr.Column():
                        gr.Markdown("#### Generated Image")
                        vt_gen_image = gr.Image(
                            label="Generated Image",
                            width=512,
                            height=512,
                        )
                        vt_gen_button = gr.Button("Try-on")

            # Pose transfer tab
            with gr.Tab("Pose Transfer"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Person Image")
                        pt_ref_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Person Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=pt_ref_image,
                            examples_per_page=5,
                            examples=["a1.webp", "a2.webp", "a3.webp",
                                      "a4.webp", "a5.webp"]
                        )
                    with gr.Column():
                        gr.Markdown("#### Target Pose Person Image")
                        pt_src_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Target Pose Person Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=pt_src_image,
                            examples_per_page=5,
                            examples=["d1.webp", "d2.webp", "d3.webp",
                                      "d4.webp", "d5.webp"]
                        )
                    with gr.Column():
                        gr.Markdown("#### Generated Image")
                        pt_gen_image = gr.Image(
                            label="Generated Image",
                            width=512,
                            height=512,
                        )
                        pose_transfer_gen_button = gr.Button("Generate")

        # Event handlers; generate_fashion takes the randomize_seed checkbox
        # in addition to the sliders
        generate_button.click(
            fn=generate_fashion,
            inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed,
                    width, height, lora_scale],
            outputs=[result, seed]
        )
        vt_gen_button.click(
            fn=leffa_predict_vt,
            inputs=[vt_src_image, vt_ref_image],
            outputs=[vt_gen_image]
        )
        pose_transfer_gen_button.click(
            fn=leffa_predict_pt,
            inputs=[pt_src_image, pt_ref_image],
            outputs=[pt_gen_image]
        )

    return demo

if __name__ == "__main__":
    # setup_environment() already ran at import time
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )