Spaces:
ginipick
/
Running on Zero

ginipick committed on
Commit
7778e32
·
verified ·
1 Parent(s): 39aa1d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -57
app.py CHANGED
@@ -7,7 +7,10 @@ from leffa.inference import LeffaInference
7
  from utils.garment_agnostic_mask_predictor import AutoMasker
8
  from utils.densepose_predictor import DensePosePredictor
9
  from utils.utils import resize_and_center
10
-
 
 
 
11
  import gradio as gr
12
 
13
  # Download checkpoints
@@ -35,7 +38,57 @@ pt_model = LeffaModel(
35
  )
36
  pt_inference = LeffaInference(model=pt_model)
37
 
38
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def leffa_predict(src_image_path, ref_image_path, control_type):
40
  assert control_type in [
41
  "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
@@ -92,25 +145,87 @@ def leffa_predict_pt(src_image_path, ref_image_path):
92
  return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
93
 
94
 
95
- if __name__ == "__main__":
96
- # import sys
97
-
98
- # src_image_path = sys.argv[1]
99
- # ref_image_path = sys.argv[2]
100
- # control_type = sys.argv[3]
101
- # leffa_predict(src_image_path, ref_image_path, control_type)
102
-
103
- title = "## Leffa: Learning Flow Fields in Attention for Controllable Person Image Generation"
104
- link = "[๐Ÿ“š Paper](https://arxiv.org/abs/2412.08486) - [๐Ÿ”ฅ Demo](https://huggingface.co/spaces/franciszzj/Leffa) - [๐Ÿค— Model](https://huggingface.co/franciszzj/Leffa)"
105
- description = "Leffa is a unified framework for controllable person image generation that enables precise manipulation of both appearance (i.e., virtual try-on) and pose (i.e., pose transfer)."
106
- note = "Note: The models used in the demo are trained solely on academic datasets. Virtual try-on uses VITON-HD, and pose transfer uses DeepFashion."
107
 
108
- with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)).queue() as demo:
109
- gr.Markdown(title)
110
- gr.Markdown(link)
111
- gr.Markdown(description)
112
-
113
- with gr.Tab("Control Appearance (Virtual Try-on)"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  with gr.Row():
115
  with gr.Column():
116
  gr.Markdown("#### Person Image")
@@ -121,15 +236,14 @@ if __name__ == "__main__":
121
  width=512,
122
  height=512,
123
  )
124
-
125
  gr.Examples(
126
  inputs=vt_src_image,
127
  examples_per_page=5,
128
  examples=["./ckpts/examples/person1/01350_00.jpg",
129
- "./ckpts/examples/person1/01376_00.jpg",
130
- "./ckpts/examples/person1/01416_00.jpg",
131
- "./ckpts/examples/person1/05976_00.jpg",
132
- "./ckpts/examples/person1/06094_00.jpg",],
133
  )
134
 
135
  with gr.Column():
@@ -141,15 +255,14 @@ if __name__ == "__main__":
141
  width=512,
142
  height=512,
143
  )
144
-
145
  gr.Examples(
146
  inputs=vt_ref_image,
147
  examples_per_page=5,
148
  examples=["./ckpts/examples/garment/01449_00.jpg",
149
- "./ckpts/examples/garment/01486_00.jpg",
150
- "./ckpts/examples/garment/01853_00.jpg",
151
- "./ckpts/examples/garment/02070_00.jpg",
152
- "./ckpts/examples/garment/03553_00.jpg",],
153
  )
154
 
155
  with gr.Column():
@@ -159,14 +272,10 @@ if __name__ == "__main__":
159
  width=512,
160
  height=512,
161
  )
 
162
 
163
- with gr.Row():
164
- vt_gen_button = gr.Button("Generate")
165
-
166
- vt_gen_button.click(fn=leffa_predict_vt, inputs=[
167
- vt_src_image, vt_ref_image], outputs=[vt_gen_image])
168
-
169
- with gr.Tab("Control Pose (Pose Transfer)"):
170
  with gr.Row():
171
  with gr.Column():
172
  gr.Markdown("#### Person Image")
@@ -177,15 +286,14 @@ if __name__ == "__main__":
177
  width=512,
178
  height=512,
179
  )
180
-
181
  gr.Examples(
182
  inputs=pt_ref_image,
183
  examples_per_page=5,
184
  examples=["./ckpts/examples/person1/01350_00.jpg",
185
- "./ckpts/examples/person1/01376_00.jpg",
186
- "./ckpts/examples/person1/01416_00.jpg",
187
- "./ckpts/examples/person1/05976_00.jpg",
188
- "./ckpts/examples/person1/06094_00.jpg",],
189
  )
190
 
191
  with gr.Column():
@@ -197,15 +305,14 @@ if __name__ == "__main__":
197
  width=512,
198
  height=512,
199
  )
200
-
201
  gr.Examples(
202
  inputs=pt_src_image,
203
  examples_per_page=5,
204
  examples=["./ckpts/examples/person2/01850_00.jpg",
205
- "./ckpts/examples/person2/01875_00.jpg",
206
- "./ckpts/examples/person2/02532_00.jpg",
207
- "./ckpts/examples/person2/02902_00.jpg",
208
- "./ckpts/examples/person2/05346_00.jpg",],
209
  )
210
 
211
  with gr.Column():
@@ -215,13 +322,27 @@ if __name__ == "__main__":
215
  width=512,
216
  height=512,
217
  )
218
-
219
- with gr.Row():
220
- pose_transfer_gen_button = gr.Button("Generate")
221
-
222
- pose_transfer_gen_button.click(fn=leffa_predict_pt, inputs=[
223
- pt_src_image, pt_ref_image], outputs=[pt_gen_image])
224
-
225
- gr.Markdown(note)
226
-
227
- demo.launch(share=True, server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from utils.garment_agnostic_mask_predictor import AutoMasker
8
  from utils.densepose_predictor import DensePosePredictor
9
  from utils.utils import resize_and_center
10
+ import spaces
11
+ import torch
12
+ from diffusers import DiffusionPipeline
13
+ from transformers import pipeline
14
  import gradio as gr
15
 
16
  # Download checkpoints
 
38
  )
39
  pt_inference = LeffaInference(model=pt_model)
40
 
41
# Korean -> English translator, applied to prompts that contain Hangul.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
base_model = "black-forest-labs/FLUX.1-dev"
model_lora_repo = "Motas/Flux_Fashion_Photography_Style"
clothes_lora_repo = "prithivMLmods/Canopus-Clothing-Flux-LoRA"

# Inclusive upper bound for seeds: used for random draws below and by the
# "Seed" slider in the UI.
MAX_SEED = 2**32 - 1

fashion_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
fashion_pipe.to("cuda")


def _contains_korean(text):
    """Return True if *text* contains at least one precomposed Hangul syllable."""
    # U+AC00..U+D7A3 is the Unicode "Hangul Syllables" block. Escapes are used
    # instead of literal characters so the check survives encoding mishaps.
    return any("\uac00" <= char <= "\ud7a3" for char in text)


@spaces.GPU()
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one fashion image with the FLUX pipeline and a mode-specific LoRA.

    Args:
        prompt: Text description; translated to English first if it contains
            Hangul.
        mode: "Generate Model" selects the fashion-photography LoRA; any other
            value selects the clothing LoRA.
        cfg_scale: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps (coerced to int).
        randomize_seed: When truthy, ``seed`` is replaced by a random value in
            ``[0, MAX_SEED]``.
        seed: Seed used when ``randomize_seed`` is falsy (coerced to int).
        width: Output width in pixels (coerced to int).
        height: Output height in pixels (coerced to int).
        lora_scale: LoRA strength, passed as ``joint_attention_kwargs["scale"]``.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets it follow
            the pipeline's own tqdm bar during denoising.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed that
        was actually used.
    """
    # Function-scope import: `random` is not among the file's visible imports.
    import random

    # Translate Korean prompts to English before generation.
    if _contains_korean(prompt):
        actual_prompt = translator(prompt)[0]["translation_text"]
    else:
        actual_prompt = prompt

    # Pick the LoRA repo and trigger words for the requested mode.
    if mode == "Generate Model":
        lora_repo = model_lora_repo
        trigger_word = "fashion photography, professional model"
    else:
        lora_repo = clothes_lora_repo
        trigger_word = "upper clothing, fashion item"

    # Drop whatever adapter a previous call loaded; otherwise adapters pile up
    # on the shared pipeline across requests and modes.
    fashion_pipe.unload_lora_weights()
    fashion_pipe.load_lora_weights(lora_repo)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Sliders may deliver floats; manual_seed() and the size arguments need ints.
    seed = int(seed)
    steps = int(steps)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Progress values are fractions in [0, 1]. Per-step updates come from the
    # pipeline's tqdm bar via track_tqdm, so no manual step loop is needed.
    progress(0, desc="Starting fashion generation...")

    image = fashion_pipe(
        prompt=f"{actual_prompt} {trigger_word}",
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=int(width),
        height=int(height),
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]

    progress(1.0, desc="Completed!")
    return image, seed
91
+
92
  def leffa_predict(src_image_path, ref_image_path, control_type):
93
  assert control_type in [
94
  "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
 
145
  return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
146
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)) as demo:
150
+ gr.Markdown("# ๐ŸŽญ Fashion Studio & Virtual Try-on")
151
+
152
+ with gr.Tabs():
153
+ # Fashion generation tab
154
+ with gr.Tab("Fashion Generation"):
155
+ with gr.Column():
156
+ mode = gr.Radio(
157
+ choices=["Generate Model", "Generate Clothes"],
158
+ label="Generation Mode",
159
+ value="Generate Model"
160
+ )
161
+
162
+ prompt = gr.TextArea(
163
+ label="Fashion Description (ํ•œ๊ธ€ ๋˜๋Š” ์˜์–ด)",
164
+ placeholder="ํŒจ์…˜ ๋ชจ๋ธ์ด๋‚˜ ์˜๋ฅ˜๋ฅผ ์„ค๋ช…ํ•˜์„ธ์š”..."
165
+ )
166
+
167
+ with gr.Row():
168
+ with gr.Column():
169
+ result = gr.Image(label="Generated Result")
170
+ generate_button = gr.Button("Generate Fashion")
171
+
172
+ with gr.Accordion("Advanced Options", open=False):
173
+ with gr.Group():
174
+ with gr.Row():
175
+ with gr.Column():
176
+ cfg_scale = gr.Slider(
177
+ label="CFG Scale",
178
+ minimum=1,
179
+ maximum=20,
180
+ step=0.5,
181
+ value=7.0
182
+ )
183
+ steps = gr.Slider(
184
+ label="Steps",
185
+ minimum=1,
186
+ maximum=100,
187
+ step=1,
188
+ value=30
189
+ )
190
+ lora_scale = gr.Slider(
191
+ label="LoRA Scale",
192
+ minimum=0,
193
+ maximum=1,
194
+ step=0.01,
195
+ value=0.85
196
+ )
197
+
198
+ with gr.Row():
199
+ width = gr.Slider(
200
+ label="Width",
201
+ minimum=256,
202
+ maximum=1536,
203
+ step=64,
204
+ value=512
205
+ )
206
+ height = gr.Slider(
207
+ label="Height",
208
+ minimum=256,
209
+ maximum=1536,
210
+ step=64,
211
+ value=768
212
+ )
213
+
214
+ with gr.Row():
215
+ randomize_seed = gr.Checkbox(
216
+ True,
217
+ label="Randomize seed"
218
+ )
219
+ seed = gr.Slider(
220
+ label="Seed",
221
+ minimum=0,
222
+ maximum=MAX_SEED,
223
+ step=1,
224
+ value=42
225
+ )
226
+
227
+ # Virtual try-on tab
228
+ with gr.Tab("Virtual Try-on"):
229
  with gr.Row():
230
  with gr.Column():
231
  gr.Markdown("#### Person Image")
 
236
  width=512,
237
  height=512,
238
  )
 
239
  gr.Examples(
240
  inputs=vt_src_image,
241
  examples_per_page=5,
242
  examples=["./ckpts/examples/person1/01350_00.jpg",
243
+ "./ckpts/examples/person1/01376_00.jpg",
244
+ "./ckpts/examples/person1/01416_00.jpg",
245
+ "./ckpts/examples/person1/05976_00.jpg",
246
+ "./ckpts/examples/person1/06094_00.jpg"]
247
  )
248
 
249
  with gr.Column():
 
255
  width=512,
256
  height=512,
257
  )
 
258
  gr.Examples(
259
  inputs=vt_ref_image,
260
  examples_per_page=5,
261
  examples=["./ckpts/examples/garment/01449_00.jpg",
262
+ "./ckpts/examples/garment/01486_00.jpg",
263
+ "./ckpts/examples/garment/01853_00.jpg",
264
+ "./ckpts/examples/garment/02070_00.jpg",
265
+ "./ckpts/examples/garment/03553_00.jpg"]
266
  )
267
 
268
  with gr.Column():
 
272
  width=512,
273
  height=512,
274
  )
275
+ vt_gen_button = gr.Button("Try-on")
276
 
277
+ # Pose transfer tab
278
+ with gr.Tab("Pose Transfer"):
 
 
 
 
 
279
  with gr.Row():
280
  with gr.Column():
281
  gr.Markdown("#### Person Image")
 
286
  width=512,
287
  height=512,
288
  )
 
289
  gr.Examples(
290
  inputs=pt_ref_image,
291
  examples_per_page=5,
292
  examples=["./ckpts/examples/person1/01350_00.jpg",
293
+ "./ckpts/examples/person1/01376_00.jpg",
294
+ "./ckpts/examples/person1/01416_00.jpg",
295
+ "./ckpts/examples/person1/05976_00.jpg",
296
+ "./ckpts/examples/person1/06094_00.jpg"]
297
  )
298
 
299
  with gr.Column():
 
305
  width=512,
306
  height=512,
307
  )
 
308
  gr.Examples(
309
  inputs=pt_src_image,
310
  examples_per_page=5,
311
  examples=["./ckpts/examples/person2/01850_00.jpg",
312
+ "./ckpts/examples/person2/01875_00.jpg",
313
+ "./ckpts/examples/person2/02532_00.jpg",
314
+ "./ckpts/examples/person2/02902_00.jpg",
315
+ "./ckpts/examples/person2/05346_00.jpg"]
316
  )
317
 
318
  with gr.Column():
 
322
  width=512,
323
  height=512,
324
  )
325
+ pose_transfer_gen_button = gr.Button("Generate")
326
+
327
+ gr.Markdown(note)
328
+
329
+ # Event handlers
330
+ generate_button.click(
331
+ generate_fashion,
332
+ inputs=[prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
333
+ outputs=[result, seed]
334
+ )
335
+
336
+ vt_gen_button.click(
337
+ fn=leffa_predict_vt,
338
+ inputs=[vt_src_image, vt_ref_image],
339
+ outputs=[vt_gen_image]
340
+ )
341
+
342
+ pose_transfer_gen_button.click(
343
+ fn=leffa_predict_pt,
344
+ inputs=[pt_src_image, pt_ref_image],
345
+ outputs=[pt_gen_image]
346
+ )
347
+
348
+ demo.launch(share=True, server_port=7860)