Spaces:
ginipick
/
Running on Zero

ginipick commited on
Commit
5999e9b
ยท
verified ยท
1 Parent(s): dcb1878

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -57
app.py CHANGED
@@ -54,10 +54,18 @@ def load_model_with_optimization(model_class, *args, **kwargs):
54
  model = model.half() # FP16์œผ๋กœ ๋ณ€ํ™˜
55
  return model.to(device)
56
 
57
- # LoRA ๋กœ๋“œ ํ•จ์ˆ˜
58
  def load_lora(pipe, lora_path):
59
- pipe.load_lora_weights(lora_path)
60
- return pipe
 
 
 
 
 
 
 
 
 
61
 
62
  # FLUX ๋ชจ๋ธ ์ดˆ๊ธฐํ™” (ํ•„์š”ํ•  ๋•Œ๋งŒ ๋กœ๋“œ)
63
  fashion_pipe = None
@@ -151,6 +159,14 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
151
 
152
  pipe = get_fashion_pipe()
153
 
 
 
 
 
 
 
 
 
154
  # ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ์ œํ•œ์„ ์œ„ํ•œ ํฌ๊ธฐ ์กฐ์ •
155
  width = min(width, 768) # ์ตœ๋Œ€ ํฌ๊ธฐ ์ œํ•œ
156
  height = min(height, 768) # ์ตœ๋Œ€ ํฌ๊ธฐ ์ œํ•œ
@@ -178,59 +194,6 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
178
  clear_memory() # ์˜ค๋ฅ˜ ๋ฐœ์ƒ ์‹œ์—๋„ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
179
  raise e
180
 
181
- def leffa_predict(src_image_path, ref_image_path, control_type):
182
- torch.cuda.empty_cache()
183
-
184
- assert control_type in [
185
- "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
186
-
187
- # ์ด๋ฏธ์ง€ ๋กœ๋“œ ๋ฐ ํฌ๊ธฐ ์กฐ์ •
188
- src_image = Image.open(src_image_path)
189
- ref_image = Image.open(ref_image_path)
190
- src_image = resize_and_center(src_image, 768, 1024)
191
- ref_image = resize_and_center(ref_image, 768, 1024)
192
-
193
- src_image_array = np.array(src_image)
194
- ref_image_array = np.array(ref_image)
195
-
196
- # Mask ์ƒ์„ฑ
197
- if control_type == "virtual_tryon":
198
- mask_pred = get_mask_predictor()
199
- src_image = src_image.convert("RGB")
200
- mask = mask_pred(src_image, "upper")["mask"]
201
- elif control_type == "pose_transfer":
202
- mask = Image.fromarray(np.ones_like(src_image_array) * 255)
203
-
204
- # DensePose ์˜ˆ์ธก
205
- dense_pred = get_densepose_predictor()
206
- src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
207
- src_image_seg_array = dense_pred.predict_seg(src_image_array)
208
- src_image_iuv = Image.fromarray(src_image_iuv_array)
209
- src_image_seg = Image.fromarray(src_image_seg_array)
210
-
211
- if control_type == "virtual_tryon":
212
- densepose = src_image_seg
213
- model, inference = get_vt_model()
214
- elif control_type == "pose_transfer":
215
- densepose = src_image_iuv
216
- model, inference = get_pt_model()
217
-
218
- # Leffa ๋ณ€ํ™˜ ๋ฐ ์ถ”๋ก 
219
- transform = LeffaTransform()
220
- data = {
221
- "src_image": [src_image],
222
- "ref_image": [ref_image],
223
- "mask": [mask],
224
- "densepose": [densepose],
225
- }
226
- data = transform(data)
227
-
228
- output = inference(data)
229
- gen_image = output["generated_image"][0]
230
-
231
- torch.cuda.empty_cache()
232
- return np.array(gen_image)
233
-
234
  def leffa_predict_vt(src_image_path, ref_image_path):
235
  return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
236
 
@@ -239,7 +202,7 @@ def leffa_predict_pt(src_image_path, ref_image_path):
239
 
240
 
241
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค
242
- with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)) as demo:
243
  gr.Markdown("# ๐ŸŽญ Fashion Studio & Virtual Try-on")
244
 
245
  with gr.Tabs():
 
54
  model = model.half() # FP16์œผ๋กœ ๋ณ€ํ™˜
55
  return model.to(device)
56
 
 
57
  def load_lora(pipe, lora_path):
58
+ try:
59
+ pipe.unload_lora_weights() # ๊ธฐ์กด LoRA ๊ฐ€์ค‘์น˜ ์ œ๊ฑฐ
60
+ except:
61
+ pass
62
+
63
+ try:
64
+ pipe.load_lora_weights(lora_path)
65
+ return pipe
66
+ except Exception as e:
67
+ print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
68
+ return pipe
69
 
70
  # FLUX ๋ชจ๋ธ ์ดˆ๊ธฐํ™” (ํ•„์š”ํ•  ๋•Œ๋งŒ ๋กœ๋“œ)
71
  fashion_pipe = None
 
159
 
160
  pipe = get_fashion_pipe()
161
 
162
+ # ๋ชจ๋“œ์— ๋”ฐ๋ฅธ LoRA ๋กœ๋”ฉ ๋ฐ ํŠธ๋ฆฌ๊ฑฐ์›Œ๋“œ ์„ค์ •
163
+ if mode == "Generate Model":
164
+ pipe = load_lora(pipe, MODEL_LORA_REPO)
165
+ trigger_word = "fashion photography, professional model"
166
+ else:
167
+ pipe = load_lora(pipe, CLOTHES_LORA_REPO)
168
+ trigger_word = "upper clothing, fashion item"
169
+
170
  # ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ์ œํ•œ์„ ์œ„ํ•œ ํฌ๊ธฐ ์กฐ์ •
171
  width = min(width, 768) # ์ตœ๋Œ€ ํฌ๊ธฐ ์ œํ•œ
172
  height = min(height, 768) # ์ตœ๋Œ€ ํฌ๊ธฐ ์ œํ•œ
 
194
  clear_memory() # ์˜ค๋ฅ˜ ๋ฐœ์ƒ ์‹œ์—๋„ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
195
  raise e
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  def leffa_predict_vt(src_image_path, ref_image_path):
198
  return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
199
 
 
202
 
203
 
204
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค
205
+ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
206
  gr.Markdown("# ๐ŸŽญ Fashion Studio & Virtual Try-on")
207
 
208
  with gr.Tabs():