import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, login
from leffa.transform import LeffaTransform
from leffa.model import LeffaModel
from leffa.inference import LeffaInference
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.densepose_predictor import DensePosePredictor
from utils.utils import resize_and_center
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import pipeline
import gradio as gr
import os
import random
import gc
# Memory management settings.
# The allocator's max split size is configured via the PYTORCH_CUDA_ALLOC_CONF
# environment variable (torch.backends.cuda has no max_split_size_mb attribute);
# it must be set before the first CUDA allocation to take effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
torch.cuda.empty_cache()
gc.collect()
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
def clear_memory():
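    """Release cached CUDA memory and run the Python garbage collector."""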
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()
# Constants
MAX_SEED = 2**32 - 1
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
# Hugging Face token setup
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
raise ValueError("Please set the HF_TOKEN environment variable")
login(token=HF_TOKEN)
# Device selection
device = "cuda" if torch.cuda.is_available() else "cpu"
# Model loading helper
def load_model_with_optimization(model_class, *args, **kwargs):
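    """Instantiate `model_class`, cast it to FP16 on CUDA, and move it to the active device."""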
torch.cuda.empty_cache()
gc.collect()
model = model_class(*args, **kwargs)
if device == "cuda":
        model = model.half()  # convert weights to FP16
return model.to(device)
# LoRA loading helper
def load_lora(pipe, lora_path):
pipe.load_lora_weights(lora_path)
return pipe
# FLUX pipeline (loaded lazily on first use)
fashion_pipe = None
def get_fashion_pipe():
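    """Lazily load the FLUX.1-dev pipeline as a module-level singleton."""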
global fashion_pipe
if fashion_pipe is None:
torch.cuda.empty_cache()
fashion_pipe = DiffusionPipeline.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16,
            token=HF_TOKEN  # use_auth_token is deprecated in recent diffusers
)
try:
            fashion_pipe.enable_xformers_memory_efficient_attention()  # requires the xformers package
except Exception as e:
print(f"Warning: Could not enable memory efficient attention: {e}")
fashion_pipe.enable_sequential_cpu_offload()
return fashion_pipe
# Korean-to-English translator (loaded lazily on first use)
translator = None
def get_translator():
global translator
if translator is None:
translator = pipeline("translation",
model="Helsinki-NLP/opus-mt-ko-en",
device=device if device == "cuda" else -1)
return translator
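# Usage sketch (assuming the Helsinki-NLP/opus-mt-ko-en weights are available):
#   get_translator()("안녕하세요")[0]["translation_text"]  # -> roughly "Hello"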
# Leffa model helpers (loaded lazily on first use); the globals below must be
# initialized to None, otherwise the `is None` checks raise NameError.
mask_predictor = None
densepose_predictor = None
vt_model = None
vt_inference = None
pt_model = None
pt_inference = None
def get_mask_predictor():
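    """Lazily load the garment-agnostic mask predictor from the DensePose and SCHP checkpoints."""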
global mask_predictor
if mask_predictor is None:
mask_predictor = AutoMasker(
densepose_path="./ckpts/densepose",
schp_path="./ckpts/schp",
)
return mask_predictor
def get_densepose_predictor():
global densepose_predictor
if densepose_predictor is None:
densepose_predictor = DensePosePredictor(
config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
weights_path="./ckpts/densepose/model_final_162be9.pkl",
)
return densepose_predictor
def get_vt_model():
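    """Lazily load the Leffa virtual try-on model and its inference wrapper."""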
global vt_model, vt_inference
if vt_model is None:
torch.cuda.empty_cache()
vt_model = load_model_with_optimization(
LeffaModel,
pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
pretrained_model="./ckpts/virtual_tryon.pth"
)
vt_inference = LeffaInference(model=vt_model)
return vt_model, vt_inference
def get_pt_model():
global pt_model, pt_inference
if pt_model is None:
torch.cuda.empty_cache()
pt_model = load_model_with_optimization(
LeffaModel,
pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
pretrained_model="./ckpts/pose_transfer.pth"
)
pt_inference = LeffaInference(model=pt_model)
return pt_model, pt_inference
# Download the Leffa checkpoints into ./ckpts (runs once at startup)
snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
def contains_korean(text):
    return any(ord('가') <= ord(char) <= ord('힣') for char in text)
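# e.g. contains_korean("드레스") -> True, contains_korean("dress") -> False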
@spaces.GPU()
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
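    """Generate a fashion image: translate Korean prompts to English, lazily load
    the FLUX pipeline, and run text-to-image generation with the given settings."""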
    clear_memory()  # free memory before starting
try:
if contains_korean(prompt):
translator = get_translator()
translated = translator(prompt)[0]['translation_text']
actual_prompt = translated
else:
actual_prompt = prompt
pipe = get_fashion_pipe()
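        # `trigger_word` is used in the prompt below but was never defined in this
        # script. A minimal sketch based on the `mode` radio button, assuming each
        # LoRA repo responds to a short trigger phrase (the exact phrases below
        # are assumptions, not confirmed by the repos):
        if mode == "Generate Model":
            pipe = load_lora(pipe, MODEL_LORA_REPO)
            trigger_word = "fashion photography"  # assumed trigger phrase
        else:
            pipe = load_lora(pipe, CLOTHES_LORA_REPO)
            trigger_word = "clothing"  # assumed trigger phrase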
        # Cap the output size to limit memory usage (the UI sliders allow up to 1024)
        width = min(width, 768)
        height = min(height, 768)
if randomize_seed:
seed = random.randint(0, MAX_SEED)
        generator = torch.Generator(device=device).manual_seed(seed)
progress(0, "Starting fashion generation...")
image = pipe(
prompt=f"{actual_prompt} {trigger_word}",
            num_inference_steps=min(steps, 30),  # cap the step count
guidance_scale=cfg_scale,
width=width,
height=height,
generator=generator,
joint_attention_kwargs={"scale": lora_scale},
).images[0]
        clear_memory()  # free memory after generation
return image, seed
    except Exception:
        clear_memory()  # also free memory on failure
        raise
def leffa_predict(src_image_path, ref_image_path, control_type):
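    """Run Leffa on a source/reference image pair for virtual try-on or pose transfer."""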
torch.cuda.empty_cache()
    assert control_type in [
        "virtual_tryon", "pose_transfer"], f"Invalid control type: {control_type}"
    # Load and resize the input images
src_image = Image.open(src_image_path)
ref_image = Image.open(ref_image_path)
src_image = resize_and_center(src_image, 768, 1024)
ref_image = resize_and_center(ref_image, 768, 1024)
src_image_array = np.array(src_image)
ref_image_array = np.array(ref_image)
    # Build the inpainting mask
if control_type == "virtual_tryon":
mask_pred = get_mask_predictor()
src_image = src_image.convert("RGB")
mask = mask_pred(src_image, "upper")["mask"]
elif control_type == "pose_transfer":
mask = Image.fromarray(np.ones_like(src_image_array) * 255)
    # Predict DensePose maps
dense_pred = get_densepose_predictor()
src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
src_image_seg_array = dense_pred.predict_seg(src_image_array)
src_image_iuv = Image.fromarray(src_image_iuv_array)
src_image_seg = Image.fromarray(src_image_seg_array)
if control_type == "virtual_tryon":
densepose = src_image_seg
model, inference = get_vt_model()
elif control_type == "pose_transfer":
densepose = src_image_iuv
model, inference = get_pt_model()
    # Leffa transform and inference
transform = LeffaTransform()
data = {
"src_image": [src_image],
"ref_image": [ref_image],
"mask": [mask],
"densepose": [densepose],
}
data = transform(data)
output = inference(data)
gen_image = output["generated_image"][0]
torch.cuda.empty_cache()
return np.array(gen_image)
def leffa_predict_vt(src_image_path, ref_image_path):
return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
def leffa_predict_pt(src_image_path, ref_image_path):
return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
# Gradio interface
with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)) as demo:
    gr.Markdown("# 🎭 Fashion Studio & Virtual Try-on")
with gr.Tabs():
        # Fashion generation tab
with gr.Tab("Fashion Generation"):
with gr.Column():
mode = gr.Radio(
choices=["Generate Model", "Generate Clothes"],
label="Generation Mode",
value="Generate Model"
)
                prompt = gr.TextArea(
                    label="Fashion Description (Korean or English)",
                    placeholder="Describe a fashion model or garment..."
                )
with gr.Row():
with gr.Column():
result = gr.Image(label="Generated Result")
generate_button = gr.Button("Generate Fashion")
with gr.Accordion("Advanced Options", open=False):
with gr.Group():
with gr.Row():
with gr.Column():
cfg_scale = gr.Slider(
label="CFG Scale",
minimum=1,
maximum=20,
step=0.5,
value=7.0
)
steps = gr.Slider(
label="Steps",
minimum=1,
                                maximum=50,  # reduced maximum
step=1,
value=30
)
lora_scale = gr.Slider(
label="LoRA Scale",
minimum=0,
maximum=1,
step=0.01,
value=0.85
)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
                            maximum=1024,  # reduced maximum
step=64,
value=512
)
height = gr.Slider(
label="Height",
minimum=256,
                            maximum=1024,  # reduced maximum
step=64,
value=768
)
with gr.Row():
randomize_seed = gr.Checkbox(
True,
label="Randomize seed"
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42
)
        # Virtual try-on tab
with gr.Tab("Virtual Try-on"):
with gr.Row():
with gr.Column():
gr.Markdown("#### Person Image")
vt_src_image = gr.Image(
sources=["upload"],
type="filepath",
label="Person Image",
width=512,
height=512,
)
gr.Examples(
inputs=vt_src_image,
examples_per_page=5,
examples=["./ckpts/examples/person1/01350_00.jpg",
"./ckpts/examples/person1/01376_00.jpg",
"./ckpts/examples/person1/01416_00.jpg",
"./ckpts/examples/person1/05976_00.jpg",
"./ckpts/examples/person1/06094_00.jpg"]
)
with gr.Column():
gr.Markdown("#### Garment Image")
vt_ref_image = gr.Image(
sources=["upload"],
type="filepath",
label="Garment Image",
width=512,
height=512,
)
gr.Examples(
inputs=vt_ref_image,
examples_per_page=5,
examples=["./ckpts/examples/garment/01449_00.jpg",
"./ckpts/examples/garment/01486_00.jpg",
"./ckpts/examples/garment/01853_00.jpg",
"./ckpts/examples/garment/02070_00.jpg",
"./ckpts/examples/garment/03553_00.jpg"]
)
with gr.Column():
gr.Markdown("#### Generated Image")
vt_gen_image = gr.Image(
label="Generated Image",
width=512,
height=512,
)
vt_gen_button = gr.Button("Try-on")
        # Pose transfer tab
with gr.Tab("Pose Transfer"):
with gr.Row():
with gr.Column():
gr.Markdown("#### Person Image")
pt_ref_image = gr.Image(
sources=["upload"],
type="filepath",
label="Person Image",
width=512,
height=512,
)
gr.Examples(
inputs=pt_ref_image,
examples_per_page=5,
examples=["./ckpts/examples/person1/01350_00.jpg",
"./ckpts/examples/person1/01376_00.jpg",
"./ckpts/examples/person1/01416_00.jpg",
"./ckpts/examples/person1/05976_00.jpg",
"./ckpts/examples/person1/06094_00.jpg"]
)
with gr.Column():
gr.Markdown("#### Target Pose Person Image")
pt_src_image = gr.Image(
sources=["upload"],
type="filepath",
label="Target Pose Person Image",
width=512,
height=512,
)
gr.Examples(
inputs=pt_src_image,
examples_per_page=5,
examples=["./ckpts/examples/person2/01850_00.jpg",
"./ckpts/examples/person2/01875_00.jpg",
"./ckpts/examples/person2/02532_00.jpg",
"./ckpts/examples/person2/02902_00.jpg",
"./ckpts/examples/person2/05346_00.jpg"]
)
with gr.Column():
gr.Markdown("#### Generated Image")
pt_gen_image = gr.Image(
label="Generated Image",
width=512,
height=512,
)
pose_transfer_gen_button = gr.Button("Generate")
    # Event handlers
generate_button.click(
generate_fashion,
inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
outputs=[result, seed]
)
vt_gen_button.click(
fn=leffa_predict_vt,
inputs=[vt_src_image, vt_ref_image],
outputs=[vt_gen_image]
)
pose_transfer_gen_button.click(
fn=leffa_predict_pt,
inputs=[pt_src_image, pt_ref_image],
outputs=[pt_gen_image]
)
# Launch the app
demo.launch(share=True, server_port=7860)