import gc
import os
import random
from contextlib import contextmanager
from functools import wraps

import spaces
import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import login, snapshot_download
from PIL import Image
from transformers import pipeline

from leffa.inference import LeffaInference
from leffa.model import LeffaModel
from leffa.transform import LeffaTransform
from utils.densepose_predictor import DensePosePredictor
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.utils import resize_and_center
# ์ƒ์ˆ˜ ์ •์˜
MAX_SEED = 2**32 - 1
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
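# The two LoRA adapters specialize the FLUX base model: MODEL_LORA_REPO adds
# a fashion-photography style for model shots, while CLOTHES_LORA_REPO is
# tuned for standalone garment images.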

# Decorator for memory management: clear caches before and after each call,
# and clean up again if the wrapped function raises.
def safe_model_call(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            clear_memory()
            result = func(*args, **kwargs)
            clear_memory()
            return result
        except Exception as e:
            clear_memory()
            print(f"Error in {func.__name__}: {str(e)}")
            raise
    return wrapper

# Context manager for memory management
@contextmanager
def torch_gc():
    try:
        yield
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
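# Usage sketch: wrap any GPU-heavy block in torch_gc() so the CUDA cache is
# released even if the block raises, e.g.
#     with torch_gc():
#         image = pipe(prompt).images[0]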

def setup_environment():
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
    HF_TOKEN = os.getenv("HF_TOKEN")
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN not found in environment variables")
    login(token=HF_TOKEN)
    return HF_TOKEN

def contains_korean(text):
    return any('가' <= char <= '힣' for char in text)
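# '가' (U+AC00) through '힣' (U+D7A3) is the precomposed Hangul Syllables
# block, so this detects Korean words but not isolated jamo characters.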

# Run environment setup
setup_environment()

# Lazily initialized globals used by the standalone predictor getters below
mask_predictor = None
densepose_predictor = None

def setup():
    # Download only the Leffa checkpoints
    snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")

@spaces.GPU()
def get_translator():
    with torch_gc():
        return pipeline("translation",
                        model="Helsinki-NLP/opus-mt-ko-en",
                        device="cuda")

@safe_model_call
def get_densepose_predictor():
    global densepose_predictor
    if densepose_predictor is None:
        densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
    return densepose_predictor

@spaces.GPU()
def get_vt_model():
    with torch_gc():
        model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
            pretrained_model="./ckpts/virtual_tryon.pth"
        )
        model = model.half()
        return model.to("cuda"), LeffaInference(model=model)
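# Note: .half() keeps the try-on model in float16, roughly halving its VRAM
# use, and nn.Module.to() moves the module in place, so LeffaInference wraps
# the same CUDA-resident weights.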

def load_lora(pipe, lora_path):
    try:
        pipe.unload_lora_weights()
    except Exception:
        # No LoRA weights were loaded yet; nothing to unload.
        pass
    try:
        pipe.load_lora_weights(lora_path)
        return pipe
    except Exception as e:
        print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
        return pipe
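# Example: load_lora(pipe, MODEL_LORA_REPO) for model shots, then
# load_lora(pipe, CLOTHES_LORA_REPO) to switch to garment generation;
# unloading first keeps adapter weights from stacking between swaps.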

@spaces.GPU()
def get_mask_predictor():
    global mask_predictor
    if mask_predictor is None:
        mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
    return mask_predictor
# ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ํ•จ์ˆ˜ ์ˆ˜์ •
@spaces.GPU()
def initialize_fashion_pipe():
try:
pipe = DiffusionPipeline.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16,
safety_checker=None,
requires_safety_checker=False
).to("cuda")
pipe.enable_model_cpu_offload()
return pipe
except Exception as e:
print(f"Error initializing fashion pipe: {e}")
raise
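# Note: enable_model_cpu_offload() (backed by accelerate) moves each
# submodule to the GPU only for its forward pass, trading some speed for a
# much smaller peak VRAM footprint than keeping the whole pipeline resident.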

@spaces.GPU()
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    try:
        # Translate Korean prompts to English
        if contains_korean(prompt):
            with torch.inference_mode():
                translator = get_translator()
                translated = translator(prompt)[0]['translation_text']
            actual_prompt = translated
        else:
            actual_prompt = prompt

        # Initialize the pipeline
        pipe = initialize_fashion_pipe()

        # Configure LoRA
        if mode == "Generate Model":
            pipe.load_lora_weights(MODEL_LORA_REPO)
            trigger_word = "fashion photography, professional model"
        else:
            pipe.load_lora_weights(CLOTHES_LORA_REPO)
            trigger_word = "upper clothing, fashion item"

        # Clamp parameters
        width = min(width, 768)
        height = min(height, 768)
        steps = min(steps, 30)

        # Seed setup
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        generator = torch.Generator("cuda").manual_seed(seed)

        # Generate the image (FLUX pipelines take the LoRA scale via
        # joint_attention_kwargs, not the SD-style cross_attention_kwargs)
        with torch.inference_mode():
            output = pipe(
                prompt=f"{actual_prompt} {trigger_word}",
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                joint_attention_kwargs={"scale": lora_scale},
            )
            image = output.images[0]

        # Free memory
        del pipe
        clear_memory()
        return image, seed
    except Exception as e:
        print(f"Error in generate_fashion: {str(e)}")
        raise gr.Error(f"Generation failed: {str(e)}")

class ModelManager:
    def __init__(self):
        self.mask_predictor = None
        self.densepose_predictor = None
        self.translator = None

    @spaces.GPU()
    def get_mask_predictor(self):
        if self.mask_predictor is None:
            self.mask_predictor = AutoMasker(
                densepose_path="./ckpts/densepose",
                schp_path="./ckpts/schp",
            )
        return self.mask_predictor

    @spaces.GPU()
    def get_densepose_predictor(self):
        if self.densepose_predictor is None:
            self.densepose_predictor = DensePosePredictor(
                config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
                weights_path="./ckpts/densepose/model_final_162be9.pkl",
            )
        return self.densepose_predictor

    @spaces.GPU()
    def get_translator(self):
        if self.translator is None:
            self.translator = pipeline("translation",
                                       model="Helsinki-NLP/opus-mt-ko-en",
                                       device="cuda")
        return self.translator
# ๋ชจ๋ธ ๋งค๋‹ˆ์ € ์ธ์Šคํ„ด์Šค ์ƒ์„ฑ
model_manager = ModelManager()
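# ModelManager lazy-loads each model on first use and caches it for the life
# of the process, so repeated requests skip the weight-loading cost.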

@spaces.GPU()
def leffa_predict(src_image_path, ref_image_path, control_type):
    try:
        with torch_gc():
            # Initialize models
            model, inference = get_vt_model()

            # Image preprocessing
            src_image = Image.open(src_image_path)
            ref_image = Image.open(ref_image_path)
            src_image = resize_and_center(src_image, 768, 1024)
            ref_image = resize_and_center(ref_image, 768, 1024)
            src_image_array = np.array(src_image)

            # Mask and DensePose processing
            with torch.inference_mode():
                src_image = src_image.convert("RGB")
                mask_pred = model_manager.get_mask_predictor()
                mask = mask_pred(src_image, "upper")["mask"]
                dense_pred = model_manager.get_densepose_predictor()
                src_image_seg_array = dense_pred.predict_seg(src_image_array)
                densepose = Image.fromarray(src_image_seg_array)

            # Leffa transform and inference
            transform = LeffaTransform()
            data = {
                "src_image": [src_image],
                "ref_image": [ref_image],
                "mask": [mask],
                "densepose": [densepose],
            }
            data = transform(data)
            with torch.inference_mode():
                output = inference(data)

            # Free memory
            del model
            del inference
            clear_memory()
            return np.array(output["generated_image"][0])
    except Exception as e:
        print(f"Error in leffa_predict: {str(e)}")
        raise

@spaces.GPU()
def leffa_predict_vt(src_image_path, ref_image_path):
    try:
        return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
    except Exception as e:
        print(f"Error in leffa_predict_vt: {str(e)}")
        raise

@spaces.GPU()
def generate_image(prompt, mode, cfg_scale=7.0, steps=30, randomize_seed=True, seed=None, width=512, height=768, lora_scale=0.85):
    try:
        with torch_gc():
            # Translate Korean prompts to English
            if contains_korean(prompt):
                translator = model_manager.get_translator()
                with torch.inference_mode():
                    translated = translator(prompt)[0]['translation_text']
                actual_prompt = translated
            else:
                actual_prompt = prompt

            # Initialize the pipeline
            pipe = DiffusionPipeline.from_pretrained(
                BASE_MODEL,
                torch_dtype=torch.float16,
            )
            pipe = pipe.to("cuda")

            # Configure LoRA
            if mode == "Generate Model":
                pipe.load_lora_weights(MODEL_LORA_REPO)
                trigger_word = "fashion photography, professional model"
            else:
                pipe.load_lora_weights(CLOTHES_LORA_REPO)
                trigger_word = "upper clothing, fashion item"

            # Resolve the seed up front so the value returned to the UI
            # matches the one actually used
            if randomize_seed or seed is None:
                seed = random.randint(0, MAX_SEED)

            # Generate the image
            with torch.inference_mode():
                result = pipe(
                    prompt=f"{actual_prompt} {trigger_word}",
                    num_inference_steps=steps,
                    guidance_scale=cfg_scale,
                    width=width,
                    height=height,
                    generator=torch.Generator("cuda").manual_seed(seed),
                    joint_attention_kwargs={"scale": lora_scale},
                ).images[0]

            # Free memory
            del pipe
            return result, seed
    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}")

# Run initial setup
setup()

def create_interface():
    with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
        gr.Markdown("# 🎭 FitGen: Fashion Studio & Virtual Try-on")

        with gr.Tabs():
            # Fashion generation tab
            with gr.Tab("Fashion Generation"):
                with gr.Column():
                    mode = gr.Radio(
                        choices=["Generate Model", "Generate Clothes"],
                        label="Generation Mode",
                        value="Generate Model"
                    )

                    # Example prompts
                    example_model_prompts = [
                        "professional fashion model, full body shot, standing pose, natural lighting, studio background, high fashion, elegant pose",
                        "fashion model portrait, upper body, confident pose, fashion photography, neutral background, professional lighting",
                        "stylish fashion model, three-quarter view, editorial pose, high-end fashion magazine style, minimal background"
                    ]
                    example_clothes_prompts = [
                        "luxury designer sweater, cashmere material, cream color, cable knit pattern, high-end fashion, product photography",
                        "elegant business blazer, tailored fit, charcoal grey, premium wool fabric, professional wear",
                        "modern streetwear hoodie, oversized fit, minimalist design, premium cotton, urban style"
                    ]

                    prompt = gr.TextArea(
                        label="Fashion Description (Korean or English)",
                        placeholder="Describe a fashion model or clothing item..."
                    )

                    # Examples section
                    gr.Examples(
                        examples=example_model_prompts + example_clothes_prompts,
                        inputs=prompt,
                        label="Example Prompts"
                    )
                    with gr.Row():
                        with gr.Column():
                            result = gr.Image(label="Generated Result")
                            generate_button = gr.Button("Generate Fashion")

                    with gr.Accordion("Advanced Options", open=False):
                        with gr.Group():
                            with gr.Row():
                                with gr.Column():
                                    cfg_scale = gr.Slider(
                                        label="CFG Scale",
                                        minimum=1,
                                        maximum=20,
                                        step=0.5,
                                        value=7.0
                                    )
                                    steps = gr.Slider(
                                        label="Steps",
                                        minimum=1,
                                        maximum=30,
                                        step=1,
                                        value=30
                                    )
                                    lora_scale = gr.Slider(
                                        label="LoRA Scale",
                                        minimum=0,
                                        maximum=1,
                                        step=0.01,
                                        value=0.85
                                    )
                            with gr.Row():
                                width = gr.Slider(
                                    label="Width",
                                    minimum=256,
                                    maximum=768,
                                    step=64,
                                    value=512
                                )
                                height = gr.Slider(
                                    label="Height",
                                    minimum=256,
                                    maximum=768,
                                    step=64,
                                    value=768
                                )
                            with gr.Row():
                                randomize_seed = gr.Checkbox(
                                    True,
                                    label="Randomize seed"
                                )
                                seed = gr.Slider(
                                    label="Seed",
                                    minimum=0,
                                    maximum=MAX_SEED,
                                    step=1,
                                    value=42
                                )
            # Virtual try-on tab
            with gr.Tab("Virtual Try-on"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Person Image")
                        vt_src_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Person Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=vt_src_image,
                            examples_per_page=5,
                            examples=[f"a{i}.webp" for i in range(1, 6)]
                        )
                    with gr.Column():
                        gr.Markdown("#### Garment Image")
                        vt_ref_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Garment Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=vt_ref_image,
                            examples_per_page=5,
                            examples=(["b1.webp", "b2.webp", "b3.webp", "b4.webp"]
                                      + [f"c{i}.png" for i in range(1, 17)]
                                      + ["b5.webp"])
                        )
                    with gr.Column():
                        gr.Markdown("#### Generated Image")
                        vt_gen_image = gr.Image(
                            label="Generated Image",
                            width=512,
                            height=512,
                        )
                        vt_gen_button = gr.Button("Try-on")

                vt_gen_button.click(
                    fn=leffa_predict_vt,
                    inputs=[vt_src_image, vt_ref_image],
                    outputs=[vt_gen_image]
                )
        generate_button.click(
            fn=generate_image,
            inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
            outputs=[result, seed]
        ).success(
            fn=lambda: clear_memory(),  # free memory after a successful run
            inputs=None,
            outputs=None
        )

    return demo

if __name__ == "__main__":
    setup_environment()
    demo = create_interface()
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
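# Local run sketch (assumes HF_TOKEN is set and the FLUX / Leffa weights can
# be downloaded):
#   HF_TOKEN=hf_... python app.py
# then open http://localhost:7860 in a browser.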