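"""FitGen — Fashion Studio & Virtual Try-on (app.py).

A Gradio app that combines FLUX.1-dev fashion image generation (with
fashion-photography and clothing LoRAs) and Leffa-based virtual try-on
and pose transfer, running on Hugging Face ZeroGPU.
"""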
import gc
import os
import random
from functools import wraps

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import snapshot_download, login
from PIL import Image
from transformers import pipeline

from leffa.inference import LeffaInference
from leffa.model import LeffaModel
from leffa.transform import LeffaTransform
from utils.densepose_predictor import DensePosePredictor
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.utils import resize_and_center

# ์ƒ์ˆ˜ ์ •์˜
MAX_SEED = 2**32 - 1
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"

# Decorator that clears GPU memory before and after a model call
def safe_model_call(func):
    @wraps(func)  # preserve the signature so Gradio can detect gr.Progress
    def wrapper(*args, **kwargs):
        try:
            clear_memory()
            result = func(*args, **kwargs)
            clear_memory()
            return result
        except Exception as e:
            clear_memory()
            print(f"Error in {func.__name__}: {str(e)}")
            raise
    return wrapper
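
# Note: the GPU entry points below stack this decorator under ZeroGPU's
# @spaces.GPU(), e.g.
#
#     @spaces.GPU()
#     @safe_model_call
#     def generate_fashion(...): ...
#
# so CUDA memory is cleared before and after each GPU-allocated call.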

# Memory management helper
def clear_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()

def setup_environment():
    # Memory management settings; allocator split size is configured via
    # PYTORCH_CUDA_ALLOC_CONF (there is no torch.backends knob for it)
    clear_memory()
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Hugging Face token
    global HF_TOKEN
    HF_TOKEN = os.getenv("HF_TOKEN")
    if HF_TOKEN is None:
        raise ValueError("Please set the HF_TOKEN environment variable")
    login(token=HF_TOKEN)

    # CUDA device
    global device
    device = "cuda" if torch.cuda.is_available() else "cpu"

# Global state
fashion_pipe = None
translator = None
mask_predictor = None
densepose_predictor = None
vt_model = None
pt_model = None
vt_inference = None
pt_inference = None
device = None
HF_TOKEN = None

# Run environment setup
setup_environment()

# ๋ชจ๋ธ ๊ด€๋ฆฌ ํ•จ์ˆ˜๋“ค
def initialize_fashion_pipe():
global fashion_pipe
if fashion_pipe is None:
clear_memory()
fashion_pipe = DiffusionPipeline.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16,
use_auth_token=HF_TOKEN
)
try:
fashion_pipe.enable_xformers_memory_efficient_attention()
except Exception as e:
print(f"Warning: Could not enable memory efficient attention: {e}")
fashion_pipe.enable_sequential_cpu_offload()
return fashion_pipe
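
# Design note: enable_sequential_cpu_offload() keeps only the active submodule
# on the GPU, trading speed for a much smaller peak VRAM footprint; the
# pipeline is therefore loaded once and cached in the `fashion_pipe` global.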

@safe_model_call
def get_translator():
    global translator
    if translator is None:
        translator = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ko-en",
            device=0 if device == "cuda" else -1,
        )
    return translator
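
# The translation pipeline returns a list of dicts; generate_fashion() below
# reads translator(prompt)[0]['translation_text'] to get the English prompt.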

@safe_model_call
def get_mask_predictor():
    global mask_predictor
    if mask_predictor is None:
        mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
    return mask_predictor

@safe_model_call
def get_densepose_predictor():
    global densepose_predictor
    if densepose_predictor is None:
        densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
    return densepose_predictor

@safe_model_call
def get_vt_model():
    global vt_model, vt_inference
    if vt_model is None:
        vt_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
            pretrained_model="./ckpts/virtual_tryon.pth",
        )
        vt_model = vt_model.half().to(device)
        vt_inference = LeffaInference(model=vt_model)
    return vt_model, vt_inference

@safe_model_call
def get_pt_model():
    global pt_model, pt_inference
    if pt_model is None:
        pt_model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
            pretrained_model="./ckpts/pose_transfer.pth",
        )
        pt_model = pt_model.half().to(device)
        pt_inference = LeffaInference(model=pt_model)
    return pt_model, pt_inference

def load_lora(pipe, lora_path):
    try:
        pipe.unload_lora_weights()
    except Exception:
        pass
    try:
        pipe.load_lora_weights(lora_path)
        return pipe
    except Exception as e:
        print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
        return pipe
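
# Unloading before loading keeps MODEL_LORA_REPO and CLOTHES_LORA_REPO from
# stacking on the shared pipeline when the user switches generation modes.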

# Initial setup
def setup():
    # Download the Leffa checkpoints
    snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
    # Initialize the base model
    initialize_fashion_pipe()


# Utility: detect Hangul syllables (U+AC00..U+D7A3) in a prompt
def contains_korean(text):
    return any(ord('가') <= ord(char) <= ord('힣') for char in text)
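
# e.g. contains_korean("드레스") -> True, contains_korean("dress") -> False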
# ๋ฉ”์ธ ๊ธฐ๋Šฅ ํ•จ์ˆ˜๋“ค
@spaces.GPU()
@safe_model_call
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
try:
# ํ•œ๊ธ€ ์ฒ˜๋ฆฌ
if contains_korean(prompt):
translator = get_translator()
translated = translator(prompt)[0]['translation_text']
actual_prompt = translated
else:
actual_prompt = prompt
# ํŒŒ์ดํ”„๋ผ์ธ ๊ฐ€์ ธ์˜ค๊ธฐ
pipe = initialize_fashion_pipe()
# LoRA ์„ค์ •
if mode == "Generate Model":
pipe = load_lora(pipe, MODEL_LORA_REPO)
trigger_word = "fashion photography, professional model"
else:
pipe = load_lora(pipe, CLOTHES_LORA_REPO)
trigger_word = "upper clothing, fashion item"
# ํŒŒ๋ผ๋ฏธํ„ฐ ์ œํ•œ
width = min(width, 768)
height = min(height, 768)
steps = min(steps, 30)
# ์‹œ๋“œ ์„ค์ •
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device="cuda").manual_seed(seed)
# ์ง„ํ–‰๋ฅ  ํ‘œ์‹œ
progress(0, "Starting fashion generation...")
# ์ด๋ฏธ์ง€ ์ƒ์„ฑ
image = pipe(
prompt=f"{actual_prompt} {trigger_word}",
num_inference_steps=steps,
guidance_scale=cfg_scale,
width=width,
height=height,
generator=generator,
joint_attention_kwargs={"scale": lora_scale},
).images[0]
return image, seed
except Exception as e:
print(f"Error in generate_fashion: {str(e)}")
raise
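
# A minimal headless sketch (assumes HF_TOKEN is set and the models are
# reachable; not part of the app flow):
#
#     image, used_seed = generate_fashion(
#         "professional fashion model, studio lighting", "Generate Model",
#         cfg_scale=7.0, steps=30, randomize_seed=True, seed=0,
#         width=512, height=768, lora_scale=0.85,
#     )
#     image.save("fashion.png")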

@safe_model_call
def leffa_predict(src_image_path, ref_image_path, control_type):
    # Initialize the models for the requested task
    if control_type == "virtual_tryon":
        model, inference = get_vt_model()
    else:
        model, inference = get_pt_model()
    mask_pred = get_mask_predictor()
    dense_pred = get_densepose_predictor()

    # Load and preprocess the images
    src_image = Image.open(src_image_path)
    ref_image = Image.open(ref_image_path)
    src_image = resize_and_center(src_image, 768, 1024)
    ref_image = resize_and_center(ref_image, 768, 1024)
    src_image_array = np.array(src_image)

    # Build the mask (garment-agnostic for try-on, full frame for pose transfer)
    if control_type == "virtual_tryon":
        src_image = src_image.convert("RGB")
        mask = mask_pred(src_image, "upper")["mask"]
    else:
        mask = Image.fromarray(np.ones_like(src_image_array) * 255)

    # DensePose prediction
    src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
    src_image_seg_array = dense_pred.predict_seg(src_image_array)
    if control_type == "virtual_tryon":
        densepose = Image.fromarray(src_image_seg_array)
    else:
        densepose = Image.fromarray(src_image_iuv_array)

    # Leffa transform and inference
    transform = LeffaTransform()
    data = {
        "src_image": [src_image],
        "ref_image": [ref_image],
        "mask": [mask],
        "densepose": [densepose],
    }
    data = transform(data)
    output = inference(data)
    return np.array(output["generated_image"][0])
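
# Direct-call sketch (hypothetical file names; checkpoints are downloaded by
# setup() below):
#
#     out = leffa_predict("person.jpg", "garment.jpg", "virtual_tryon")
#     Image.fromarray(out).save("tryon.png")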

@safe_model_call
def leffa_predict_vt(src_image_path, ref_image_path):
    return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")


@safe_model_call
def leffa_predict_pt(src_image_path, ref_image_path):
    return leffa_predict(src_image_path, ref_image_path, "pose_transfer")

# Run initial setup
setup()

# Gradio interface
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
    gr.Markdown("# 🎭 FitGen: Fashion Studio & Virtual Try-on")

    with gr.Tabs():
        # Fashion generation tab
        with gr.Tab("Fashion Generation"):
            with gr.Column():
                mode = gr.Radio(
                    choices=["Generate Model", "Generate Clothes"],
                    label="Generation Mode",
                    value="Generate Model"
                )

                # Example prompts
                example_model_prompts = [
                    "professional fashion model, full body shot, standing pose, natural lighting, studio background, high fashion, elegant pose",
                    "fashion model portrait, upper body, confident pose, fashion photography, neutral background, professional lighting",
                    "stylish fashion model, three-quarter view, editorial pose, high-end fashion magazine style, minimal background"
                ]
                example_clothes_prompts = [
                    "luxury designer sweater, cashmere material, cream color, cable knit pattern, high-end fashion, product photography",
                    "elegant business blazer, tailored fit, charcoal grey, premium wool fabric, professional wear",
                    "modern streetwear hoodie, oversized fit, minimalist design, premium cotton, urban style"
                ]

                prompt = gr.TextArea(
                    label="Fashion Description (Korean or English)",
                    placeholder="Describe a fashion model or garment..."
                )

                # Example prompt picker
                gr.Examples(
                    examples=example_model_prompts + example_clothes_prompts,
                    inputs=prompt,
                    label="Example Prompts"
                )

                with gr.Row():
                    with gr.Column():
                        result = gr.Image(label="Generated Result")
                        generate_button = gr.Button("Generate Fashion")

                with gr.Accordion("Advanced Options", open=False):
                    with gr.Group():
                        with gr.Row():
                            with gr.Column():
                                cfg_scale = gr.Slider(
                                    label="CFG Scale",
                                    minimum=1,
                                    maximum=20,
                                    step=0.5,
                                    value=7.0
                                )
                                steps = gr.Slider(
                                    label="Steps",
                                    minimum=1,
                                    maximum=50,  # clamped to 30 in generate_fashion
                                    step=1,
                                    value=30
                                )
                                lora_scale = gr.Slider(
                                    label="LoRA Scale",
                                    minimum=0,
                                    maximum=1,
                                    step=0.01,
                                    value=0.85
                                )

                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=1024,  # clamped to 768 in generate_fashion
                                step=64,
                                value=512
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=1024,  # clamped to 768 in generate_fashion
                                step=64,
                                value=768
                            )

                        with gr.Row():
                            randomize_seed = gr.Checkbox(
                                value=True,
                                label="Randomize seed"
                            )
                            seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42
                            )

        # Virtual try-on tab
        with gr.Tab("Virtual Try-on"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Person Image")
                    vt_src_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=vt_src_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/person1/01350_00.jpg",
                                  "./ckpts/examples/person1/01376_00.jpg",
                                  "./ckpts/examples/person1/01416_00.jpg",
                                  "./ckpts/examples/person1/05976_00.jpg",
                                  "./ckpts/examples/person1/06094_00.jpg"]
                    )
                with gr.Column():
                    gr.Markdown("#### Garment Image")
                    vt_ref_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Garment Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=vt_ref_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/garment/01449_00.jpg",
                                  "./ckpts/examples/garment/01486_00.jpg",
                                  "./ckpts/examples/garment/01853_00.jpg",
                                  "./ckpts/examples/garment/02070_00.jpg",
                                  "./ckpts/examples/garment/03553_00.jpg"]
                    )
                with gr.Column():
                    gr.Markdown("#### Generated Image")
                    vt_gen_image = gr.Image(
                        label="Generated Image",
                        width=512,
                        height=512,
                    )
                    vt_gen_button = gr.Button("Try-on")

        # Pose transfer tab
        with gr.Tab("Pose Transfer"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Person Image")
                    pt_ref_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=pt_ref_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/person1/01350_00.jpg",
                                  "./ckpts/examples/person1/01376_00.jpg",
                                  "./ckpts/examples/person1/01416_00.jpg",
                                  "./ckpts/examples/person1/05976_00.jpg",
                                  "./ckpts/examples/person1/06094_00.jpg"]
                    )
                with gr.Column():
                    gr.Markdown("#### Target Pose Person Image")
                    pt_src_image = gr.Image(
                        sources=["upload"],
                        type="filepath",
                        label="Target Pose Person Image",
                        width=512,
                        height=512,
                    )
                    gr.Examples(
                        inputs=pt_src_image,
                        examples_per_page=5,
                        examples=["./ckpts/examples/person2/01850_00.jpg",
                                  "./ckpts/examples/person2/01875_00.jpg",
                                  "./ckpts/examples/person2/02532_00.jpg",
                                  "./ckpts/examples/person2/02902_00.jpg",
                                  "./ckpts/examples/person2/05346_00.jpg"]
                    )
                with gr.Column():
                    gr.Markdown("#### Generated Image")
                    pt_gen_image = gr.Image(
                        label="Generated Image",
                        width=512,
                        height=512,
                    )
                    pose_transfer_gen_button = gr.Button("Generate")

    # Event handlers (the generated seed is fed back into the seed slider)
    generate_button.click(
        fn=generate_fashion,
        inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )
    vt_gen_button.click(
        fn=leffa_predict_vt,
        inputs=[vt_src_image, vt_ref_image],
        outputs=[vt_gen_image]
    )
    pose_transfer_gen_button.click(
        fn=leffa_predict_pt,
        inputs=[pt_src_image, pt_ref_image],
        outputs=[pt_gen_image]
    )

# Launch the app
demo.launch(share=True, server_port=7860)