Spaces:
ginipick
/
Running on Zero

ginipick commited on
Commit
83b8d5b
·
verified ·
1 Parent(s): 26d2e48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -91
app.py CHANGED
@@ -36,11 +36,21 @@ def safe_model_call(func):
36
  raise
37
  return wrapper
38
 
39
- # 메모리 관리 함수 수정
 
 
 
 
 
 
 
 
 
 
 
40
  def clear_memory():
41
  gc.collect()
42
- if torch.cuda.is_available() and torch.cuda.current_device() >= 0:
43
- torch.cuda.empty_cache()
44
 
45
  def setup_environment():
46
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
@@ -53,52 +63,48 @@ def setup_environment():
53
  @spaces.GPU()
54
  def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
55
  try:
56
- # 한글 처리
57
- if contains_korean(prompt):
58
- translator = get_translator()
59
- translated = translator(prompt)[0]['translation_text']
60
- actual_prompt = translated
61
- else:
62
- actual_prompt = prompt
63
-
64
- # 파이프라인 초기화
65
- pipe = DiffusionPipeline.from_pretrained(
66
- BASE_MODEL,
67
- torch_dtype=torch.float16,
68
- )
69
- pipe.to("cuda")
70
-
71
- # LoRA 설정
72
- if mode == "Generate Model":
73
- pipe.load_lora_weights(MODEL_LORA_REPO)
74
- trigger_word = "fashion photography, professional model"
75
- else:
76
- pipe.load_lora_weights(CLOTHES_LORA_REPO)
77
- trigger_word = "upper clothing, fashion item"
78
-
79
- # 생성 설정
80
- generator = torch.Generator("cuda").manual_seed(seed if seed is not None else torch.randint(0, 2**32 - 1, (1,)).item())
81
 
82
- # 이미지 생성
83
- with torch.inference_mode():
84
- result = pipe(
85
- prompt=f"{actual_prompt} {trigger_word}",
86
- num_inference_steps=steps,
87
- guidance_scale=cfg_scale,
88
- width=width,
89
- height=height,
90
- generator=generator,
91
- joint_attention_kwargs={"scale": lora_scale},
92
- ).images[0]
93
 
94
- # 메모리 정리
95
- del pipe
96
- clear_memory()
 
 
 
 
97
 
98
- return result, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  except Exception as e:
101
- clear_memory()
102
  raise gr.Error(f"Generation failed: {str(e)}")
103
 
104
  def contains_korean(text):
@@ -109,35 +115,17 @@ def get_translator():
109
  return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cuda")
110
 
111
 
112
- # 전역 변수 초기화
113
- fashion_pipe = None
114
- translator = None
115
- mask_predictor = None
116
- densepose_predictor = None
117
- vt_model = None
118
- pt_model = None
119
- vt_inference = None
120
- pt_inference = None
121
- device = None
122
- HF_TOKEN = None
123
-
124
  # 환경 설정 실행
125
  setup_environment()
126
 
127
  @spaces.GPU()
128
  def initialize_fashion_pipe():
129
- global fashion_pipe
130
- if fashion_pipe is None:
131
- fashion_pipe = DiffusionPipeline.from_pretrained(
132
  BASE_MODEL,
133
  torch_dtype=torch.float16,
134
- use_auth_token=HF_TOKEN
135
- ).to("cuda")
136
- try:
137
- fashion_pipe.enable_xformers_memory_efficient_attention()
138
- except Exception as e:
139
- print(f"Warning: Could not enable memory efficient attention: {e}")
140
- return fashion_pipe
141
 
142
  def setup():
143
  # Leffa 체크포���트 다운로드만 수행
@@ -145,12 +133,10 @@ def setup():
145
 
146
  @spaces.GPU()
147
  def get_translator():
148
- global translator
149
- if translator is None:
150
- translator = pipeline("translation",
151
- model="Helsinki-NLP/opus-mt-ko-en",
152
- device="cuda")
153
- return translator
154
 
155
  @safe_model_call
156
  def get_mask_predictor():
@@ -174,17 +160,13 @@ def get_densepose_predictor():
174
 
175
  @spaces.GPU()
176
  def get_vt_model():
177
- try:
178
  model = LeffaModel(
179
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
180
  pretrained_model="./ckpts/virtual_tryon.pth"
181
  )
182
- model = model.half().to("cuda")
183
- inference = LeffaInference(model=model)
184
- return model, inference
185
- except Exception as e:
186
- print(f"Error in get_vt_model: {str(e)}")
187
- raise
188
 
189
  @spaces.GPU()
190
  def get_pt_model():
@@ -381,9 +363,9 @@ def leffa_predict_pt(src_image_path, ref_image_path):
381
 
382
  # 초기 설정 실행
383
  setup()
384
-
385
  def create_interface():
386
- with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
 
387
  gr.Markdown("# 🎭 FitGen:Fashion Studio & Virtual Try-on")
388
 
389
  with gr.Tabs():
@@ -581,12 +563,7 @@ def create_interface():
581
  )
582
  pose_transfer_gen_button = gr.Button("Generate")
583
 
584
- # 이벤트 핸들러
585
- generate_button.click(
586
- fn=generate_image,
587
- inputs=[prompt, mode, cfg_scale, steps, seed, width, height, lora_scale],
588
- outputs=[result, seed]
589
- )
590
 
591
  vt_gen_button.click(
592
  fn=leffa_predict_vt,
@@ -600,16 +577,27 @@ def create_interface():
600
  outputs=[pt_gen_image]
601
  )
602
 
 
 
 
 
 
 
 
 
 
 
 
 
603
  return demo
604
 
605
  if __name__ == "__main__":
606
- # 환경 설정
607
  setup_environment()
608
-
609
- # 인터페이스 생성 및 실행
610
  demo = create_interface()
 
611
  demo.launch(
612
  server_name="0.0.0.0",
613
  server_port=7860,
614
- share=False
 
615
  )
 
36
  raise
37
  return wrapper
38
 
39
+
40
+ # 메모리 관리를 위한 컨텍스트 매니저
41
+ @contextmanager
42
+ def torch_gc():
43
+ try:
44
+ yield
45
+ finally:
46
+ gc.collect()
47
+ if torch.cuda.is_available() and torch.cuda.current_device() >= 0:
48
+ with torch.cuda.device('cuda'):
49
+ torch.cuda.empty_cache()
50
+
51
  def clear_memory():
52
  gc.collect()
53
+
 
54
 
55
  def setup_environment():
56
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
 
63
  @spaces.GPU()
64
  def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
65
  try:
66
+ with torch_gc():
67
+ # 한글 처리
68
+ if contains_korean(prompt):
69
+ translator = get_translator()
70
+ with torch.inference_mode():
71
+ translated = translator(prompt)[0]['translation_text']
72
+ actual_prompt = translated
73
+ else:
74
+ actual_prompt = prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ # 파이프라인 초기화
77
+ pipe = DiffusionPipeline.from_pretrained(
78
+ BASE_MODEL,
79
+ torch_dtype=torch.float16,
80
+ )
81
+ pipe = pipe.to("cuda")
 
 
 
 
 
82
 
83
+ # LoRA 설정
84
+ if mode == "Generate Model":
85
+ pipe.load_lora_weights(MODEL_LORA_REPO)
86
+ trigger_word = "fashion photography, professional model"
87
+ else:
88
+ pipe.load_lora_weights(CLOTHES_LORA_REPO)
89
+ trigger_word = "upper clothing, fashion item"
90
 
91
+ # 이미지 생성
92
+ with torch.inference_mode():
93
+ result = pipe(
94
+ prompt=f"{actual_prompt} {trigger_word}",
95
+ num_inference_steps=steps,
96
+ guidance_scale=cfg_scale,
97
+ width=width,
98
+ height=height,
99
+ generator=torch.Generator("cuda").manual_seed(
100
+ seed if seed is not None else torch.randint(0, 2**32 - 1, (1,)).item()
101
+ ),
102
+ joint_attention_kwargs={"scale": lora_scale},
103
+ ).images[0]
104
+
105
+ return result, seed
106
 
107
  except Exception as e:
 
108
  raise gr.Error(f"Generation failed: {str(e)}")
109
 
110
  def contains_korean(text):
 
115
  return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cuda")
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  # 환경 설정 실행
119
  setup_environment()
120
 
121
  @spaces.GPU()
122
  def initialize_fashion_pipe():
123
+ with torch_gc():
124
+ pipe = DiffusionPipeline.from_pretrained(
 
125
  BASE_MODEL,
126
  torch_dtype=torch.float16,
127
+ )
128
+ return pipe.to("cuda")
 
 
 
 
 
129
 
130
  def setup():
131
  # Leffa 체크포���트 다운로드만 수행
 
133
 
134
  @spaces.GPU()
135
  def get_translator():
136
+ with torch_gc():
137
+ return pipeline("translation",
138
+ model="Helsinki-NLP/opus-mt-ko-en",
139
+ device="cuda")
 
 
140
 
141
  @safe_model_call
142
  def get_mask_predictor():
 
160
 
161
  @spaces.GPU()
162
  def get_vt_model():
163
+ with torch_gc():
164
  model = LeffaModel(
165
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
166
  pretrained_model="./ckpts/virtual_tryon.pth"
167
  )
168
+ model = model.half()
169
+ return model.to("cuda"), LeffaInference(model=model)
 
 
 
 
170
 
171
  @spaces.GPU()
172
  def get_pt_model():
 
363
 
364
  # 초기 설정 실행
365
  setup()
 
366
  def create_interface():
367
+ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
368
+
369
  gr.Markdown("# 🎭 FitGen:Fashion Studio & Virtual Try-on")
370
 
371
  with gr.Tabs():
 
563
  )
564
  pose_transfer_gen_button = gr.Button("Generate")
565
 
566
+
 
 
 
 
 
567
 
568
  vt_gen_button.click(
569
  fn=leffa_predict_vt,
 
577
  outputs=[pt_gen_image]
578
  )
579
 
580
+
581
+
582
+ generate_button.click(
583
+ fn=generate_image,
584
+ inputs=[prompt, mode, cfg_scale, steps, seed, width, height, lora_scale],
585
+ outputs=[result, seed]
586
+ ).success(
587
+ fn=lambda: gc.collect(), # 성공 후 메모리 정리
588
+ inputs=None,
589
+ outputs=None
590
+ )
591
+
592
  return demo
593
 
594
  if __name__ == "__main__":
 
595
  setup_environment()
 
 
596
  demo = create_interface()
597
+ demo.queue() # 큐 활성화
598
  demo.launch(
599
  server_name="0.0.0.0",
600
  server_port=7860,
601
+ share=False,
602
+ memory_target_gb=0.5 # 메모리 제한 설정
603
  )