Spaces:
ginipick
/
Running on Zero

ginipick commited on
Commit
26d2e48
·
verified ·
1 Parent(s): 9fbd4b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -39
app.py CHANGED
@@ -172,29 +172,33 @@ def get_densepose_predictor():
172
  )
173
  return densepose_predictor
174
 
175
- @safe_model_call
176
  def get_vt_model():
177
- global vt_model, vt_inference
178
- if vt_model is None:
179
- vt_model = LeffaModel(
180
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
181
  pretrained_model="./ckpts/virtual_tryon.pth"
182
  )
183
- vt_model = vt_model.half().to(device)
184
- vt_inference = LeffaInference(model=vt_model)
185
- return vt_model, vt_inference
 
 
 
186
 
187
- @safe_model_call
188
  def get_pt_model():
189
- global pt_model, pt_inference
190
- if pt_model is None:
191
- pt_model = LeffaModel(
192
  pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
193
  pretrained_model="./ckpts/pose_transfer.pth"
194
  )
195
- pt_model = pt_model.half().to(device)
196
- pt_inference = LeffaInference(model=pt_model)
197
- return pt_model, pt_inference
 
 
 
198
 
199
  def load_lora(pipe, lora_path):
200
  try:
@@ -298,7 +302,7 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
298
  print(f"Error in generate_fashion: {str(e)}")
299
  raise gr.Error(f"Generation failed: {str(e)}")
300
 
301
- @safe_model_call
302
  def leffa_predict(src_image_path, ref_image_path, control_type):
303
  try:
304
  # 모델 초기화
@@ -307,10 +311,7 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
307
  else:
308
  model, inference = get_pt_model()
309
 
310
- mask_pred = get_mask_predictor()
311
- dense_pred = get_densepose_predictor()
312
-
313
- # 이미지 로드 및 전처리
314
  src_image = Image.open(src_image_path)
315
  ref_image = Image.open(ref_image_path)
316
  src_image = resize_and_center(src_image, 768, 1024)
@@ -319,21 +320,23 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
319
  src_image_array = np.array(src_image)
320
  ref_image_array = np.array(ref_image)
321
 
322
- # Mask 생성
323
- if control_type == "virtual_tryon":
324
- src_image = src_image.convert("RGB")
325
- mask = mask_pred(src_image, "upper")["mask"]
326
- else:
327
- mask = Image.fromarray(np.ones_like(src_image_array) * 255)
328
-
329
- # DensePose 예측
330
- src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
331
- src_image_seg_array = dense_pred.predict_seg(src_image_array)
332
-
333
- if control_type == "virtual_tryon":
334
- densepose = Image.fromarray(src_image_seg_array)
335
- else:
336
- densepose = Image.fromarray(src_image_iuv_array)
 
 
337
 
338
  # Leffa 변환 및 추론
339
  transform = LeffaTransform()
@@ -345,20 +348,36 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
345
  }
346
  data = transform(data)
347
 
348
- output = inference(data)
 
 
 
 
 
 
 
 
349
  return np.array(output["generated_image"][0])
350
 
351
  except Exception as e:
352
  print(f"Error in leffa_predict: {str(e)}")
353
  raise
354
 
355
- @safe_model_call
356
  def leffa_predict_vt(src_image_path, ref_image_path):
357
- return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
 
 
 
 
358
 
359
- @safe_model_call
360
  def leffa_predict_pt(src_image_path, ref_image_path):
361
- return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
 
 
 
 
362
 
363
  # 초기 설정 실행
364
  setup()
 
172
  )
173
  return densepose_predictor
174
 
175
+ @spaces.GPU()
176
  def get_vt_model():
177
+ try:
178
+ model = LeffaModel(
 
179
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
180
  pretrained_model="./ckpts/virtual_tryon.pth"
181
  )
182
+ model = model.half().to("cuda")
183
+ inference = LeffaInference(model=model)
184
+ return model, inference
185
+ except Exception as e:
186
+ print(f"Error in get_vt_model: {str(e)}")
187
+ raise
188
 
189
+ @spaces.GPU()
190
  def get_pt_model():
191
+ try:
192
+ model = LeffaModel(
 
193
  pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
194
  pretrained_model="./ckpts/pose_transfer.pth"
195
  )
196
+ model = model.half().to("cuda")
197
+ inference = LeffaInference(model=model)
198
+ return model, inference
199
+ except Exception as e:
200
+ print(f"Error in get_pt_model: {str(e)}")
201
+ raise
202
 
203
  def load_lora(pipe, lora_path):
204
  try:
 
302
  print(f"Error in generate_fashion: {str(e)}")
303
  raise gr.Error(f"Generation failed: {str(e)}")
304
 
305
+ @spaces.GPU()
306
  def leffa_predict(src_image_path, ref_image_path, control_type):
307
  try:
308
  # 모델 초기화
 
311
  else:
312
  model, inference = get_pt_model()
313
 
314
+ # 이미지 처리
 
 
 
315
  src_image = Image.open(src_image_path)
316
  ref_image = Image.open(ref_image_path)
317
  src_image = resize_and_center(src_image, 768, 1024)
 
320
  src_image_array = np.array(src_image)
321
  ref_image_array = np.array(ref_image)
322
 
323
+ # Mask 및 DensePose 처리
324
+ with torch.inference_mode():
325
+ if control_type == "virtual_tryon":
326
+ src_image = src_image.convert("RGB")
327
+ mask_pred = get_mask_predictor()
328
+ mask = mask_pred(src_image, "upper")["mask"]
329
+ else:
330
+ mask = Image.fromarray(np.ones_like(src_image_array) * 255)
331
+
332
+ dense_pred = get_densepose_predictor()
333
+ src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
334
+ src_image_seg_array = dense_pred.predict_seg(src_image_array)
335
+
336
+ if control_type == "virtual_tryon":
337
+ densepose = Image.fromarray(src_image_seg_array)
338
+ else:
339
+ densepose = Image.fromarray(src_image_iuv_array)
340
 
341
  # Leffa 변환 및 추론
342
  transform = LeffaTransform()
 
348
  }
349
  data = transform(data)
350
 
351
+ with torch.inference_mode():
352
+ output = inference(data)
353
+
354
+ # 메모리 정리
355
+ del model
356
+ del inference
357
+ torch.cuda.empty_cache()
358
+ gc.collect()
359
+
360
  return np.array(output["generated_image"][0])
361
 
362
  except Exception as e:
363
  print(f"Error in leffa_predict: {str(e)}")
364
  raise
365
 
366
+ @spaces.GPU()
367
  def leffa_predict_vt(src_image_path, ref_image_path):
368
+ try:
369
+ return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
370
+ except Exception as e:
371
+ print(f"Error in leffa_predict_vt: {str(e)}")
372
+ raise
373
 
374
+ @spaces.GPU()
375
  def leffa_predict_pt(src_image_path, ref_image_path):
376
+ try:
377
+ return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
378
+ except Exception as e:
379
+ print(f"Error in leffa_predict_pt: {str(e)}")
380
+ raise
381
 
382
  # 초기 설정 실행
383
  setup()