Spaces:
ginipick
/
Running on Zero

ginipick commited on
Commit
af38e9b
·
verified ·
1 Parent(s): 5999e9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -2
app.py CHANGED
@@ -18,6 +18,32 @@ import gc
18
  # 메모리 관리 설정 추가
19
  import torch.backends.cuda
20
  torch.backends.cuda.max_split_size_mb = 128 # 메모리 분할 크기 제한
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # 메모리 관리 설정
22
  torch.cuda.empty_cache()
23
  gc.collect()
@@ -30,6 +56,16 @@ def clear_memory():
30
  torch.cuda.synchronize()
31
  gc.collect()
32
 
 
 
 
 
 
 
 
 
 
 
33
  # 상수 정의
34
  MAX_SEED = 2**32 - 1
35
  BASE_MODEL = "black-forest-labs/FLUX.1-dev"
@@ -194,13 +230,98 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
194
  clear_memory() # 오류 발생 시에도 메모리 정리
195
  raise e
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  def leffa_predict_vt(src_image_path, ref_image_path):
198
  return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
199
 
200
  def leffa_predict_pt(src_image_path, ref_image_path):
201
  return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
202
-
203
-
204
  # Gradio 인터페이스
205
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
206
  gr.Markdown("# 🎭 Fashion Studio & Virtual Try-on")
 
18
  # 메모리 관리 설정 추가
19
  import torch.backends.cuda
20
  torch.backends.cuda.max_split_size_mb = 128 # 메모리 분할 크기 제한
21
+
22
+
23
+ # 전역 변수로 모델들을 선언
24
+ fashion_pipe = None
25
+ translator = None
26
+ mask_predictor = None
27
+ densepose_predictor = None
28
+ vt_model = None
29
+ pt_model = None
30
+ vt_inference = None
31
+ pt_inference = None
32
+
33
+ # 초기화 함수
34
+ def initialize_models():
35
+ global fashion_pipe
36
+ if fashion_pipe is None:
37
+ fashion_pipe = DiffusionPipeline.from_pretrained(
38
+ BASE_MODEL,
39
+ torch_dtype=torch.float16,
40
+ use_auth_token=HF_TOKEN
41
+ )
42
+ fashion_pipe.to(device)
43
+
44
+ # 앱 시작 시 모델 초기화
45
+ initialize_models()
46
+
47
  # 메모리 관리 설정
48
  torch.cuda.empty_cache()
49
  gc.collect()
 
56
  torch.cuda.synchronize()
57
  gc.collect()
58
 
59
+ # 모델 사용 후 메모리 해제
60
+ def unload_models():
61
+ global fashion_pipe, translator, mask_predictor, densepose_predictor, vt_model, pt_model
62
+ fashion_pipe = None
63
+ translator = None
64
+ mask_predictor = None
65
+ densepose_predictor = None
66
+ vt_model = None
67
+ pt_model = None
68
+ clear_memory()
69
  # 상수 정의
70
  MAX_SEED = 2**32 - 1
71
  BASE_MODEL = "black-forest-labs/FLUX.1-dev"
 
230
  clear_memory() # 오류 발생 시에도 메모리 정리
231
  raise e
232
 
233
+
234
+
235
+ def leffa_predict(src_image_path, ref_image_path, control_type):
236
+ global mask_predictor, densepose_predictor, vt_model, pt_model, vt_inference, pt_inference
237
+
238
+ clear_memory()
239
+
240
+ try:
241
+ # 필요한 모델 초기화
242
+ if control_type == "virtual_tryon" and vt_model is None:
243
+ vt_model = LeffaModel(
244
+ pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
245
+ pretrained_model="./ckpts/virtual_tryon.pth"
246
+ )
247
+ vt_model.to(device)
248
+ vt_inference = LeffaInference(model=vt_model)
249
+
250
+ elif control_type == "pose_transfer" and pt_model is None:
251
+ pt_model = LeffaModel(
252
+ pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
253
+ pretrained_model="./ckpts/pose_transfer.pth"
254
+ )
255
+ pt_model.to(device)
256
+ pt_inference = LeffaInference(model=pt_model)
257
+
258
+ if mask_predictor is None:
259
+ mask_predictor = AutoMasker(
260
+ densepose_path="./ckpts/densepose",
261
+ schp_path="./ckpts/schp",
262
+ )
263
+
264
+ if densepose_predictor is None:
265
+ densepose_predictor = DensePosePredictor(
266
+ config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
267
+ weights_path="./ckpts/densepose/model_final_162be9.pkl",
268
+ )
269
+
270
+ # 이미지 처리
271
+ src_image = Image.open(src_image_path)
272
+ ref_image = Image.open(ref_image_path)
273
+ src_image = resize_and_center(src_image, 768, 1024)
274
+ ref_image = resize_and_center(ref_image, 768, 1024)
275
+
276
+ src_image_array = np.array(src_image)
277
+ ref_image_array = np.array(ref_image)
278
+
279
+ # Mask 생성
280
+ if control_type == "virtual_tryon":
281
+ src_image = src_image.convert("RGB")
282
+ mask = mask_predictor(src_image, "upper")["mask"]
283
+ else:
284
+ mask = Image.fromarray(np.ones_like(src_image_array) * 255)
285
+
286
+ # DensePose 예측
287
+ src_image_iuv_array = densepose_predictor.predict_iuv(src_image_array)
288
+ src_image_seg_array = densepose_predictor.predict_seg(src_image_array)
289
+ src_image_iuv = Image.fromarray(src_image_iuv_array)
290
+ src_image_seg = Image.fromarray(src_image_seg_array)
291
+
292
+ if control_type == "virtual_tryon":
293
+ densepose = src_image_seg
294
+ inference = vt_inference
295
+ else:
296
+ densepose = src_image_iuv
297
+ inference = pt_inference
298
+
299
+ # Leffa 변환 및 추론
300
+ transform = LeffaTransform()
301
+ data = {
302
+ "src_image": [src_image],
303
+ "ref_image": [ref_image],
304
+ "mask": [mask],
305
+ "densepose": [densepose],
306
+ }
307
+ data = transform(data)
308
+
309
+ output = inference(data)
310
+ gen_image = output["generated_image"][0]
311
+
312
+ clear_memory()
313
+ return np.array(gen_image)
314
+
315
+ except Exception as e:
316
+ clear_memory()
317
+ raise e
318
+
319
  def leffa_predict_vt(src_image_path, ref_image_path):
320
  return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
321
 
322
  def leffa_predict_pt(src_image_path, ref_image_path):
323
  return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
324
+
 
325
  # Gradio 인터페이스
326
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
327
  gr.Markdown("# 🎭 Fashion Studio & Virtual Try-on")