Spaces:
ginipick
/
Running on Zero

fantaxy commited on
Commit
501340d
·
verified ·
1 Parent(s): 2a0877a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -128
app.py CHANGED
@@ -15,9 +15,7 @@ import gradio as gr
15
  import os
16
  import random
17
  import gc
18
- from contextlib import contextmanager # 이 줄 추가
19
-
20
-
21
 
22
  # 상수 정의
23
  MAX_SEED = 2**32 - 1
@@ -39,7 +37,6 @@ def safe_model_call(func):
39
  raise
40
  return wrapper
41
 
42
-
43
  # 메모리 관리를 위한 컨텍스트 매니저
44
  @contextmanager
45
  def torch_gc():
@@ -54,7 +51,6 @@ def torch_gc():
54
  def clear_memory():
55
  gc.collect()
56
 
57
-
58
  def setup_environment():
59
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
60
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -63,17 +59,6 @@ def setup_environment():
63
  login(token=HF_TOKEN)
64
  return HF_TOKEN
65
 
66
- @contextmanager
67
- def torch_gc():
68
- try:
69
- yield
70
- finally:
71
- gc.collect()
72
- if torch.cuda.is_available() and torch.cuda.current_device() >= 0:
73
- with torch.cuda.device('cuda'):
74
- torch.cuda.empty_cache()
75
-
76
-
77
  def contains_korean(text):
78
  return any(ord('가') <= ord(char) <= ord('힣') for char in text)
79
 
@@ -81,7 +66,6 @@ def contains_korean(text):
81
  def get_translator():
82
  return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cuda")
83
 
84
-
85
  # 환경 설정 실행
86
  setup_environment()
87
 
@@ -97,7 +81,7 @@ def initialize_fashion_pipe():
97
  def setup():
98
  # Leffa 체크포인트 다운로드만 수행
99
  snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
100
-
101
  @spaces.GPU()
102
  def get_translator():
103
  with torch_gc():
@@ -135,20 +119,6 @@ def get_vt_model():
135
  model = model.half()
136
  return model.to("cuda"), LeffaInference(model=model)
137
 
138
- @spaces.GPU()
139
- def get_pt_model():
140
- try:
141
- model = LeffaModel(
142
- pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
143
- pretrained_model="./ckpts/pose_transfer.pth"
144
- )
145
- model = model.half().to("cuda")
146
- inference = LeffaInference(model=model)
147
- return model, inference
148
- except Exception as e:
149
- print(f"Error in get_pt_model: {str(e)}")
150
- raise
151
-
152
  def load_lora(pipe, lora_path):
153
  try:
154
  pipe.unload_lora_weights()
@@ -170,11 +140,6 @@ def get_mask_predictor():
170
  schp_path="./ckpts/schp",
171
  )
172
  return mask_predictor
173
-
174
- # 유틸리티 함수
175
- def contains_korean(text):
176
- return any(ord('가') <= ord(char) <= ord('힣') for char in text)
177
-
178
 
179
  # 모델 초기화 함수 수정
180
  @spaces.GPU()
@@ -192,7 +157,6 @@ def initialize_fashion_pipe():
192
  print(f"Error initializing fashion pipe: {e}")
193
  raise
194
 
195
- # 생성 함수 수정
196
  @spaces.GPU()
197
  def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
198
  try:
@@ -291,10 +255,7 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
291
  try:
292
  with torch_gc():
293
  # 모델 초기화
294
- if control_type == "virtual_tryon":
295
- model, inference = get_vt_model()
296
- else:
297
- model, inference = get_pt_model()
298
 
299
  # 이미지 처리
300
  src_image = Image.open(src_image_path)
@@ -307,21 +268,13 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
307
 
308
  # Mask 및 DensePose 처리
309
  with torch.inference_mode():
310
- if control_type == "virtual_tryon":
311
- src_image = src_image.convert("RGB")
312
- mask_pred = model_manager.get_mask_predictor()
313
- mask = mask_pred(src_image, "upper")["mask"]
314
- else:
315
- mask = Image.fromarray(np.ones_like(src_image_array) * 255)
316
 
317
  dense_pred = model_manager.get_densepose_predictor()
318
- src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
319
  src_image_seg_array = dense_pred.predict_seg(src_image_array)
320
-
321
- if control_type == "virtual_tryon":
322
- densepose = Image.fromarray(src_image_seg_array)
323
- else:
324
- densepose = Image.fromarray(src_image_iuv_array)
325
 
326
  # Leffa 변환 및 추론
327
  transform = LeffaTransform()
@@ -348,8 +301,6 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
348
  print(f"Error in leffa_predict: {str(e)}")
349
  raise
350
 
351
-
352
-
353
  @spaces.GPU()
354
  def leffa_predict_vt(src_image_path, ref_image_path):
355
  try:
@@ -358,14 +309,6 @@ def leffa_predict_vt(src_image_path, ref_image_path):
358
  print(f"Error in leffa_predict_vt: {str(e)}")
359
  raise
360
 
361
- @spaces.GPU()
362
- def leffa_predict_pt(src_image_path, ref_image_path):
363
- try:
364
- return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
365
- except Exception as e:
366
- print(f"Error in leffa_predict_pt: {str(e)}")
367
- raise
368
-
369
  @spaces.GPU()
370
  def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
371
  try:
@@ -379,7 +322,6 @@ def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512,
379
  else:
380
  actual_prompt = prompt
381
 
382
-
383
  # 파이프라인 초기화
384
  pipe = DiffusionPipeline.from_pretrained(
385
  BASE_MODEL,
@@ -413,16 +355,14 @@ def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512,
413
  del pipe
414
  return result, seed
415
 
416
-
417
  except Exception as e:
418
  raise gr.Error(f"Generation failed: {str(e)}")
419
 
420
-
421
  # 초기 설정 실행
422
  setup()
 
423
  def create_interface():
424
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
425
-
426
  gr.Markdown("# 🎭 FitGen:Fashion Studio & Virtual Try-on")
427
 
428
  with gr.Tabs():
@@ -570,71 +510,11 @@ def create_interface():
570
  )
571
  vt_gen_button = gr.Button("Try-on")
572
 
573
- # 포즈 전송 탭
574
- with gr.Tab("Pose Transfer"):
575
- with gr.Row():
576
- with gr.Column():
577
- gr.Markdown("#### Person Image")
578
- pt_ref_image = gr.Image(
579
- sources=["upload"],
580
- type="filepath",
581
- label="Person Image",
582
- width=512,
583
- height=512,
584
- )
585
- gr.Examples(
586
- inputs=pt_ref_image,
587
- examples_per_page=5,
588
- examples=["a1.webp",
589
- "a2.webp",
590
- "a3.webp",
591
- "a4.webp",
592
- "a5.webp"]
593
- )
594
-
595
- with gr.Column():
596
- gr.Markdown("#### Target Pose Person Image")
597
- pt_src_image = gr.Image(
598
- sources=["upload"],
599
- type="filepath",
600
- label="Target Pose Person Image",
601
- width=512,
602
- height=512,
603
- )
604
- gr.Examples(
605
- inputs=pt_src_image,
606
- examples_per_page=5,
607
- examples=["d1.webp",
608
- "d2.webp",
609
- "d3.webp",
610
- "d4.webp",
611
- "d5.webp"]
612
- )
613
-
614
- with gr.Column():
615
- gr.Markdown("#### Generated Image")
616
- pt_gen_image = gr.Image(
617
- label="Generated Image",
618
- width=512,
619
- height=512,
620
- )
621
- pose_transfer_gen_button = gr.Button("Generate")
622
-
623
-
624
-
625
  vt_gen_button.click(
626
  fn=leffa_predict_vt,
627
  inputs=[vt_src_image, vt_ref_image],
628
  outputs=[vt_gen_image]
629
  )
630
-
631
- pose_transfer_gen_button.click(
632
- fn=leffa_predict_pt,
633
- inputs=[pt_src_image, pt_ref_image],
634
- outputs=[pt_gen_image]
635
- )
636
-
637
-
638
 
639
  generate_button.click(
640
  fn=generate_image,
 
15
  import os
16
  import random
17
  import gc
18
+ from contextlib import contextmanager
 
 
19
 
20
  # 상수 정의
21
  MAX_SEED = 2**32 - 1
 
37
  raise
38
  return wrapper
39
 
 
40
  # 메모리 관리를 위한 컨텍스트 매니저
41
  @contextmanager
42
  def torch_gc():
 
51
  def clear_memory():
52
  gc.collect()
53
 
 
54
  def setup_environment():
55
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
56
  HF_TOKEN = os.getenv("HF_TOKEN")
 
59
  login(token=HF_TOKEN)
60
  return HF_TOKEN
61
 
 
 
 
 
 
 
 
 
 
 
 
62
  def contains_korean(text):
63
  return any(ord('가') <= ord(char) <= ord('힣') for char in text)
64
 
 
66
  def get_translator():
67
  return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cuda")
68
 
 
69
  # 환경 설정 실행
70
  setup_environment()
71
 
 
81
  def setup():
82
  # Leffa 체크포인트 다운로드만 수행
83
  snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
84
+
85
  @spaces.GPU()
86
  def get_translator():
87
  with torch_gc():
 
119
  model = model.half()
120
  return model.to("cuda"), LeffaInference(model=model)
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  def load_lora(pipe, lora_path):
123
  try:
124
  pipe.unload_lora_weights()
 
140
  schp_path="./ckpts/schp",
141
  )
142
  return mask_predictor
 
 
 
 
 
143
 
144
  # 모델 초기화 함수 수정
145
  @spaces.GPU()
 
157
  print(f"Error initializing fashion pipe: {e}")
158
  raise
159
 
 
160
  @spaces.GPU()
161
  def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
162
  try:
 
255
  try:
256
  with torch_gc():
257
  # 모델 초기화
258
+ model, inference = get_vt_model()
 
 
 
259
 
260
  # 이미지 처리
261
  src_image = Image.open(src_image_path)
 
268
 
269
  # Mask 및 DensePose 처리
270
  with torch.inference_mode():
271
+ src_image = src_image.convert("RGB")
272
+ mask_pred = model_manager.get_mask_predictor()
273
+ mask = mask_pred(src_image, "upper")["mask"]
 
 
 
274
 
275
  dense_pred = model_manager.get_densepose_predictor()
 
276
  src_image_seg_array = dense_pred.predict_seg(src_image_array)
277
+ densepose = Image.fromarray(src_image_seg_array)
 
 
 
 
278
 
279
  # Leffa 변환 및 추론
280
  transform = LeffaTransform()
 
301
  print(f"Error in leffa_predict: {str(e)}")
302
  raise
303
 
 
 
304
  @spaces.GPU()
305
  def leffa_predict_vt(src_image_path, ref_image_path):
306
  try:
 
309
  print(f"Error in leffa_predict_vt: {str(e)}")
310
  raise
311
 
 
 
 
 
 
 
 
 
312
  @spaces.GPU()
313
  def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
314
  try:
 
322
  else:
323
  actual_prompt = prompt
324
 
 
325
  # 파이프라인 초기화
326
  pipe = DiffusionPipeline.from_pretrained(
327
  BASE_MODEL,
 
355
  del pipe
356
  return result, seed
357
 
 
358
  except Exception as e:
359
  raise gr.Error(f"Generation failed: {str(e)}")
360
 
 
361
  # 초기 설정 실행
362
  setup()
363
+
364
  def create_interface():
365
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
 
366
  gr.Markdown("# 🎭 FitGen:Fashion Studio & Virtual Try-on")
367
 
368
  with gr.Tabs():
 
510
  )
511
  vt_gen_button = gr.Button("Try-on")
512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  vt_gen_button.click(
514
  fn=leffa_predict_vt,
515
  inputs=[vt_src_image, vt_ref_image],
516
  outputs=[vt_gen_image]
517
  )
 
 
 
 
 
 
 
 
518
 
519
  generate_button.click(
520
  fn=generate_image,