Update app.py
app.py CHANGED
@@ -7,7 +7,10 @@ from leffa.inference import LeffaInference
 from utils.garment_agnostic_mask_predictor import AutoMasker
 from utils.densepose_predictor import DensePosePredictor
 from utils.utils import resize_and_center
-
+import spaces
+import torch
+from diffusers import DiffusionPipeline
+from transformers import pipeline
 import gradio as gr
 
 # Download checkpoints
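For context on the new imports: `spaces` is the Hugging Face package that powers ZeroGPU Spaces, and later in this diff `generate_fashion` is wrapped in its GPU decorator so the function only holds a GPU while it runs. A minimal sketch of that pattern, illustrative only (the `warm_up` helper is hypothetical, and a ZeroGPU Space with CUDA available is assumed):

    import spaces
    import torch

    @spaces.GPU()  # on a ZeroGPU Space, reserves a GPU only for the duration of this call
    def warm_up():
        # trivial stand-in for a GPU-bound function such as generate_fashion
        return torch.zeros(1, device="cuda").item()

Outside of Spaces the decorator is effectively a no-op, so the same code also runs locally on any CUDA machine.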
@@ -35,7 +38,57 @@ pt_model = LeffaModel(
 )
 pt_inference = LeffaInference(model=pt_model)
 
-
+translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
+base_model = "black-forest-labs/FLUX.1-dev"
+model_lora_repo = "Motas/Flux_Fashion_Photography_Style"
+clothes_lora_repo = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
+
+fashion_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
+fashion_pipe.to("cuda")
+
+@spaces.GPU()
+def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
+    # Detect Korean input and translate it to English
+    def contains_korean(text):
+        return any(ord('가') <= ord(char) <= ord('힣') for char in text)
+
+    if contains_korean(prompt):
+        translated = translator(prompt)[0]['translation_text']
+        actual_prompt = translated
+    else:
+        actual_prompt = prompt
+
+    # Select LoRA weights and trigger word based on the mode
+    if mode == "Generate Model":
+        fashion_pipe.load_lora_weights(model_lora_repo)
+        trigger_word = "fashion photography, professional model"
+    else:
+        fashion_pipe.load_lora_weights(clothes_lora_repo)
+        trigger_word = "upper clothing, fashion item"
+
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+
+    progress(0, "Starting fashion generation...")
+
+    for i in range(1, steps + 1):
+        if i % (steps // 10) == 0:
+            progress(i / steps * 100, f"Processing step {i} of {steps}...")
+
+    image = fashion_pipe(
+        prompt=f"{actual_prompt} {trigger_word}",
+        num_inference_steps=steps,
+        guidance_scale=cfg_scale,
+        width=width,
+        height=height,
+        generator=generator,
+        joint_attention_kwargs={"scale": lora_scale},
+    ).images[0]
+
+    progress(100, "Completed!")
+    return image, seed
+
 def leffa_predict(src_image_path, ref_image_path, control_type):
     assert control_type in [
         "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
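`generate_fashion` leans on two names that never appear in the hunks shown here, `random` and `MAX_SEED`, and its prompt handling is easy to exercise on its own. The snippet below is an illustrative sketch, not part of the commit: it reuses the same Helsinki-NLP/opus-mt-ko-en checkpoint, and the `MAX_SEED` value is an assumption standing in for whatever app.py defines (or still needs to define) elsewhere.

    import random
    from transformers import pipeline

    MAX_SEED = 2**32 - 1  # assumed bound; app.py must provide its own MAX_SEED
    translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

    def contains_korean(text):
        # True if any character falls in the Hangul syllables block (가 .. 힣)
        return any('가' <= char <= '힣' for char in text)

    prompt = "검은색 가죽 재킷을 입은 패션 모델"  # "a fashion model in a black leather jacket"
    if contains_korean(prompt):
        prompt = translator(prompt)[0]["translation_text"]

    seed = random.randint(0, MAX_SEED)  # the randomize_seed branch in generate_fashion
    print(prompt, seed)

If `random` is not imported and `MAX_SEED` is not defined in the unchanged part of app.py, the randomize-seed path of `generate_fashion` will raise a NameError at runtime.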
@@ -92,25 +145,87 @@ def leffa_predict_pt(src_image_path, ref_image_path):
     return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
 
 
-if __name__ == "__main__":
-    # import sys
-
-    # src_image_path = sys.argv[1]
-    # ref_image_path = sys.argv[2]
-    # control_type = sys.argv[3]
-    # leffa_predict(src_image_path, ref_image_path, control_type)
-
-    title = "## Leffa: Learning Flow Fields in Attention for Controllable Person Image Generation"
-    link = "[📚 Paper](https://arxiv.org/abs/2412.08486) - [🔥 Demo](https://huggingface.co/spaces/franciszzj/Leffa) - [🤗 Model](https://huggingface.co/franciszzj/Leffa)"
-    description = "Leffa is a unified framework for controllable person image generation that enables precise manipulation of both appearance (i.e., virtual try-on) and pose (i.e., pose transfer)."
-    note = "Note: The models used in the demo are trained solely on academic datasets. Virtual try-on uses VITON-HD, and pose transfer uses DeepFashion."
 
-
-
-
-
-
-    with gr.Tab("Control Appearance (Virtual Try-on)"):
+with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)) as demo:
+    gr.Markdown("# 🎭 Fashion Studio & Virtual Try-on")
+
+    with gr.Tabs():
+        # Fashion generation tab
+        with gr.Tab("Fashion Generation"):
+            with gr.Column():
+                mode = gr.Radio(
+                    choices=["Generate Model", "Generate Clothes"],
+                    label="Generation Mode",
+                    value="Generate Model"
+                )
+
+                prompt = gr.TextArea(
+                    label="Fashion Description (Korean or English)",
+                    placeholder="Describe a fashion model or a clothing item..."
+                )
+
+                with gr.Row():
+                    with gr.Column():
+                        result = gr.Image(label="Generated Result")
+                        generate_button = gr.Button("Generate Fashion")
+
+                with gr.Accordion("Advanced Options", open=False):
+                    with gr.Group():
+                        with gr.Row():
+                            with gr.Column():
+                                cfg_scale = gr.Slider(
+                                    label="CFG Scale",
+                                    minimum=1,
+                                    maximum=20,
+                                    step=0.5,
+                                    value=7.0
+                                )
+                                steps = gr.Slider(
+                                    label="Steps",
+                                    minimum=1,
+                                    maximum=100,
+                                    step=1,
+                                    value=30
+                                )
+                                lora_scale = gr.Slider(
+                                    label="LoRA Scale",
+                                    minimum=0,
+                                    maximum=1,
+                                    step=0.01,
+                                    value=0.85
+                                )
+
+                        with gr.Row():
+                            width = gr.Slider(
+                                label="Width",
+                                minimum=256,
+                                maximum=1536,
+                                step=64,
+                                value=512
+                            )
+                            height = gr.Slider(
+                                label="Height",
+                                minimum=256,
+                                maximum=1536,
+                                step=64,
+                                value=768
+                            )
+
+                        with gr.Row():
+                            randomize_seed = gr.Checkbox(
+                                True,
+                                label="Randomize seed"
+                            )
+                            seed = gr.Slider(
+                                label="Seed",
+                                minimum=0,
+                                maximum=MAX_SEED,
+                                step=1,
+                                value=42
+                            )
+
+        # Virtual try-on tab
+        with gr.Tab("Virtual Try-on"):
             with gr.Row():
                 with gr.Column():
                     gr.Markdown("#### Person Image")
@@ -121,15 +236,14 @@ if __name__ == "__main__":
                         width=512,
                         height=512,
                     )
-
                     gr.Examples(
                         inputs=vt_src_image,
                         examples_per_page=5,
                         examples=["./ckpts/examples/person1/01350_00.jpg",
-
-
-
-
+                                  "./ckpts/examples/person1/01376_00.jpg",
+                                  "./ckpts/examples/person1/01416_00.jpg",
+                                  "./ckpts/examples/person1/05976_00.jpg",
+                                  "./ckpts/examples/person1/06094_00.jpg"]
                     )
 
                 with gr.Column():
@@ -141,15 +255,14 @@ if __name__ == "__main__":
                         width=512,
                        height=512,
                     )
-
                    gr.Examples(
                         inputs=vt_ref_image,
                         examples_per_page=5,
                         examples=["./ckpts/examples/garment/01449_00.jpg",
-
-
-
-
+                                  "./ckpts/examples/garment/01486_00.jpg",
+                                  "./ckpts/examples/garment/01853_00.jpg",
+                                  "./ckpts/examples/garment/02070_00.jpg",
+                                  "./ckpts/examples/garment/03553_00.jpg"]
                     )
 
                 with gr.Column():
@@ -159,14 +272,10 @@ if __name__ == "__main__":
                         width=512,
                         height=512,
                     )
+                    vt_gen_button = gr.Button("Try-on")
 
-
-
-
-        vt_gen_button.click(fn=leffa_predict_vt, inputs=[
-                            vt_src_image, vt_ref_image], outputs=[vt_gen_image])
-
-    with gr.Tab("Control Pose (Pose Transfer)"):
+        # Pose transfer tab
+        with gr.Tab("Pose Transfer"):
             with gr.Row():
                 with gr.Column():
                     gr.Markdown("#### Person Image")
@@ -177,15 +286,14 @@ if __name__ == "__main__":
                         width=512,
                         height=512,
                     )
-
                     gr.Examples(
                         inputs=pt_ref_image,
                         examples_per_page=5,
                         examples=["./ckpts/examples/person1/01350_00.jpg",
-
-
-
-
+                                  "./ckpts/examples/person1/01376_00.jpg",
+                                  "./ckpts/examples/person1/01416_00.jpg",
+                                  "./ckpts/examples/person1/05976_00.jpg",
+                                  "./ckpts/examples/person1/06094_00.jpg"]
                     )
 
                 with gr.Column():
@@ -197,15 +305,14 @@ if __name__ == "__main__":
                         width=512,
                         height=512,
                     )
-
                     gr.Examples(
                         inputs=pt_src_image,
                         examples_per_page=5,
                         examples=["./ckpts/examples/person2/01850_00.jpg",
-
-
-
-
+                                  "./ckpts/examples/person2/01875_00.jpg",
+                                  "./ckpts/examples/person2/02532_00.jpg",
+                                  "./ckpts/examples/person2/02902_00.jpg",
+                                  "./ckpts/examples/person2/05346_00.jpg"]
                     )
 
                 with gr.Column():
@@ -215,13 +322,27 @@ if __name__ == "__main__":
                         width=512,
                         height=512,
                     )
-
-
-
-
-
-
-
-
-
-
+                    pose_transfer_gen_button = gr.Button("Generate")
+
+    gr.Markdown(note)
+
+    # Event handlers
+    generate_button.click(
+        generate_fashion,
+        inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
+        outputs=[result, seed]
+    )
+
+    vt_gen_button.click(
+        fn=leffa_predict_vt,
+        inputs=[vt_src_image, vt_ref_image],
+        outputs=[vt_gen_image]
+    )
+
+    pose_transfer_gen_button.click(
+        fn=leffa_predict_pt,
+        inputs=[pt_src_image, pt_ref_image],
+        outputs=[pt_gen_image]
+    )
+
+demo.launch(share=True, server_port=7860)
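The wiring at the end of the diff is the standard Gradio Blocks pattern: each button's click() maps a list of input components to a callback and routes the callback's return values back into output components; because generate_fashion returns the seed it actually used, the Seed slider is refreshed after every run. A stripped-down sketch of the same pattern, with a hypothetical echo callback standing in for the real generators:

    import gradio as gr

    def echo(text, times):
        # stand-in for generate_fashion / leffa_predict_*: return values are matched
        # positionally to the components listed in outputs=[...]
        return text * int(times), int(times)

    with gr.Blocks() as demo:
        text = gr.Textbox(label="Text")
        times = gr.Slider(1, 5, value=2, step=1, label="Times")
        out = gr.Textbox(label="Result")
        used = gr.Number(label="Times used")
        btn = gr.Button("Run")
        btn.click(fn=echo, inputs=[text, times], outputs=[out, used])

    if __name__ == "__main__":
        demo.launch()

On Hugging Face Spaces, share=True is typically ignored and 7860 is already the default port, so the plain demo.launch() above behaves the same there as the call in the diff.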