Spaces:
ginipick
/
Running on Zero

ginipick committed on
Commit
f7aa706
·
verified ·
1 Parent(s): a20aa8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -75
app.py CHANGED
@@ -16,10 +16,12 @@ import os
16
  import random
17
  import gc
18
 
19
- # 메모리 최적화 설정
 
 
 
20
  torch.backends.cudnn.benchmark = True
21
  torch.backends.cuda.matmul.allow_tf32 = True
22
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
23
 
24
  # 상수 정의
25
  MAX_SEED = 2**32 - 1
@@ -27,110 +29,159 @@ BASE_MODEL = "black-forest-labs/FLUX.1-dev"
27
  MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
28
  CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
29
 
30
- # Hugging Face 토큰 설정 및 로그인
31
  HF_TOKEN = os.getenv("HF_TOKEN")
32
  if HF_TOKEN is None:
33
  raise ValueError("Please set the HF_TOKEN environment variable")
34
  login(token=HF_TOKEN)
35
 
36
- # 메모리 정리 함수
37
- def clear_memory():
38
- torch.cuda.empty_cache()
39
- gc.collect()
40
-
41
- # 초기 메모리 정리
42
- clear_memory()
43
-
44
- # CUDA 사용 가능 여부 확인
45
  device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
47
- # FLUX 모델 초기화
48
- fashion_pipe = DiffusionPipeline.from_pretrained(
49
- BASE_MODEL,
50
- torch_dtype=torch.float16,
51
- use_auth_token=HF_TOKEN
52
- )
53
- fashion_pipe.enable_model_cpu_offload()
54
-
55
- # 번역기 초기화
56
- translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # Leffa 체크포인트 다운로드
59
  snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
60
 
61
- # Leffa 관련 모델 초기화
62
- mask_predictor = AutoMasker(
63
- densepose_path="./ckpts/densepose",
64
- schp_path="./ckpts/schp",
65
- )
66
-
67
- densepose_predictor = DensePosePredictor(
68
- config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
69
- weights_path="./ckpts/densepose/model_final_162be9.pkl",
70
- )
71
- # Leffa 모델 초기화 수정
72
- vt_model = LeffaModel(
73
- pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
74
- pretrained_model="./ckpts/virtual_tryon.pth"
75
- )
76
- vt_model.to(device) # 모델을 GPU로 이동
77
- vt_inference = LeffaInference(model=vt_model)
78
-
79
- pt_model = LeffaModel(
80
- pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
81
- pretrained_model="./ckpts/pose_transfer.pth"
82
- )
83
- pt_model.to(device) # 모델을 GPU로 이동
84
- pt_inference = LeffaInference(model=pt_model)
85
-
86
-
87
  def contains_korean(text):
88
  return any(ord('가') <= ord(char) <= ord('힣') for char in text)
89
 
90
  @spaces.GPU()
91
  def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
 
 
92
  if contains_korean(prompt):
 
93
  translated = translator(prompt)[0]['translation_text']
94
  actual_prompt = translated
95
  else:
96
  actual_prompt = prompt
97
-
98
 
99
  # 모드에 따른 LoRA 로딩 및 트리거워드 설정
 
100
  if mode == "Generate Model":
101
- fashion_pipe = load_lora(fashion_pipe, model_lora_repo)
102
  trigger_word = "fashion photography, professional model"
103
  else:
104
- fashion_pipe = load_lora(fashion_pipe, clothes_lora_repo)
105
  trigger_word = "upper clothing, fashion item"
106
 
107
  if randomize_seed:
108
  seed = random.randint(0, MAX_SEED)
109
  generator = torch.Generator(device="cuda").manual_seed(seed)
110
 
 
 
 
 
111
  progress(0, "Starting fashion generation...")
112
 
113
  for i in range(1, steps + 1):
114
  if i % (steps // 10) == 0:
115
  progress(i / steps * 100, f"Processing step {i} of {steps}...")
116
 
117
- image = fashion_pipe(
118
  prompt=f"{actual_prompt} {trigger_word}",
119
  num_inference_steps=steps,
120
  guidance_scale=cfg_scale,
121
  width=width,
122
  height=height,
123
  generator=generator,
124
- use_auth_token=HF_TOKEN, # 인증 토큰 추가
125
  joint_attention_kwargs={"scale": lora_scale},
126
  ).images[0]
127
 
128
  progress(100, "Completed!")
129
  return image, seed
130
-
 
131
  def leffa_predict(src_image_path, ref_image_path, control_type):
 
 
132
  assert control_type in [
133
  "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
 
 
134
  src_image = Image.open(src_image_path)
135
  ref_image = Image.open(ref_image_path)
136
  src_image = resize_and_center(src_image, 768, 1024)
@@ -139,26 +190,30 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
139
  src_image_array = np.array(src_image)
140
  ref_image_array = np.array(ref_image)
141
 
142
- # Mask
143
  if control_type == "virtual_tryon":
 
144
  src_image = src_image.convert("RGB")
145
- mask = mask_predictor(src_image, "upper")["mask"]
146
  elif control_type == "pose_transfer":
147
  mask = Image.fromarray(np.ones_like(src_image_array) * 255)
148
 
149
- # DensePose
150
- src_image_iuv_array = densepose_predictor.predict_iuv(src_image_array)
151
- src_image_seg_array = densepose_predictor.predict_seg(src_image_array)
 
152
  src_image_iuv = Image.fromarray(src_image_iuv_array)
153
  src_image_seg = Image.fromarray(src_image_seg_array)
 
154
  if control_type == "virtual_tryon":
155
  densepose = src_image_seg
 
156
  elif control_type == "pose_transfer":
157
  densepose = src_image_iuv
 
158
 
159
- # Leffa
160
  transform = LeffaTransform()
161
-
162
  data = {
163
  "src_image": [src_image],
164
  "ref_image": [ref_image],
@@ -166,25 +221,21 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
166
  "densepose": [densepose],
167
  }
168
  data = transform(data)
169
- if control_type == "virtual_tryon":
170
- inference = vt_inference
171
- elif control_type == "pose_transfer":
172
- inference = pt_inference
173
  output = inference(data)
174
  gen_image = output["generated_image"][0]
175
- # gen_image.save("gen_image.png")
 
176
  return np.array(gen_image)
177
 
178
-
179
  def leffa_predict_vt(src_image_path, ref_image_path):
180
  return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
181
 
182
-
183
  def leffa_predict_pt(src_image_path, ref_image_path):
184
  return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
185
 
186
 
187
-
188
  with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)) as demo:
189
  gr.Markdown("# 🎭 Fashion Studio & Virtual Try-on")
190
 
@@ -222,7 +273,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, second
222
  steps = gr.Slider(
223
  label="Steps",
224
  minimum=1,
225
- maximum=100,
226
  step=1,
227
  value=30
228
  )
@@ -238,14 +289,14 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, second
238
  width = gr.Slider(
239
  label="Width",
240
  minimum=256,
241
- maximum=1536,
242
  step=64,
243
  value=512
244
  )
245
  height = gr.Slider(
246
  label="Height",
247
  minimum=256,
248
- maximum=1536,
249
  step=64,
250
  value=768
251
  )
@@ -363,8 +414,6 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, second
363
  )
364
  pose_transfer_gen_button = gr.Button("Generate")
365
 
366
- gr.Markdown(note)
367
-
368
  # 이벤트 핸들러
369
  generate_button.click(
370
  generate_fashion,
@@ -384,4 +433,5 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, second
384
  outputs=[pt_gen_image]
385
  )
386
 
387
- demo.launch(share=True, server_port=7860)
 
 
16
  import random
17
  import gc
18
 
# Memory management settings: release cached CUDA memory, run a garbage
# collection pass, then set the allocator option `max_split_size_mb`.
torch.cuda.empty_cache()
gc.collect()
os.environ.update(PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:128')

# Backend flags: allow TF32 matmuls and turn on cuDNN benchmark mode.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
 
25
 
26
  # 상수 정의
27
  MAX_SEED = 2**32 - 1
 
29
  MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
30
  CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
31
 
32
+ # Hugging Face 토큰 설정
33
  HF_TOKEN = os.getenv("HF_TOKEN")
34
  if HF_TOKEN is None:
35
  raise ValueError("Please set the HF_TOKEN environment variable")
36
  login(token=HF_TOKEN)
37
 
38
+ # CUDA 설정
 
 
 
 
 
 
 
 
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
 
# Model loading helper.
def load_model_with_optimization(model_class, *args, **kwargs):
    """Instantiate ``model_class`` and move it to the module-level ``device``.

    Cached CUDA memory is released and a garbage-collection pass is run
    before the model is constructed. When ``device`` is ``"cuda"`` the model
    is converted to FP16 via ``.half()`` before being moved.

    Args:
        model_class: Callable that builds the model.
        *args: Positional arguments forwarded to ``model_class``.
        **kwargs: Keyword arguments forwarded to ``model_class``.

    Returns:
        The result of calling ``.to(device)`` on the (possibly halved) model.
    """
    torch.cuda.empty_cache()
    gc.collect()
    instance = model_class(*args, **kwargs)
    on_gpu = device == "cuda"
    prepared = instance.half() if on_gpu else instance
    return prepared.to(device)
49
+
# LoRA loading helper.
def load_lora(pipe, lora_path):
    """Make ``lora_path`` the LoRA loaded on ``pipe`` and return ``pipe``.

    The pipeline is cached and reused across requests, so calling
    ``load_lora_weights`` on every request would keep piling adapters onto
    the same object. Any previously loaded LoRA weights are therefore
    unloaded first, when the pipeline offers ``unload_lora_weights``.

    Args:
        pipe: Pipeline object exposing ``load_lora_weights``.
        lora_path: Repository id or path handed to ``load_lora_weights``.

    Returns:
        The same ``pipe`` object, for call chaining.
    """
    # Looked up defensively so pipeline objects without an unload hook
    # still work exactly as before.
    unload = getattr(pipe, "unload_lora_weights", None)
    if callable(unload):
        unload()
    pipe.load_lora_weights(lora_path)
    return pipe
54
+
# FLUX pipeline, created lazily on first use.
fashion_pipe = None


def get_fashion_pipe():
    """Return the shared FLUX pipeline, building it on the first call.

    The pipeline is loaded from ``BASE_MODEL`` in FP16 with sequential CPU
    offload enabled, then cached in the module-level ``fashion_pipe``.

    Returns:
        The cached ``DiffusionPipeline`` instance.
    """
    global fashion_pipe
    if fashion_pipe is None:
        torch.cuda.empty_cache()
        # Build into a local first so a failure during setup never leaves a
        # half-configured pipeline cached in the global.
        pipe = DiffusionPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            use_auth_token=HF_TOKEN,
        )
        # The earlier `enable_memory_efficient_attention()` call was removed:
        # DiffusionPipeline does not define that method, so it raised
        # AttributeError on first use. NOTE(review): if xformers attention is
        # wanted, the documented name is
        # `enable_xformers_memory_efficient_attention()` — confirm xformers
        # is installed before adding it.
        pipe.enable_sequential_cpu_offload()
        fashion_pipe = pipe
    return fashion_pipe
69
+
# Translator, created lazily on first use.
translator = None


def get_translator():
    """Return the shared translation pipeline, creating it on first call.

    The pipeline uses the ``Helsinki-NLP/opus-mt-ko-en`` model. It is given
    the module-level ``device`` when that is ``"cuda"`` and ``-1`` otherwise.

    Returns:
        The cached translation pipeline.
    """
    global translator
    if translator is not None:
        return translator
    target_device = device if device == "cuda" else -1
    translator = pipeline(
        "translation",
        model="Helsinki-NLP/opus-mt-ko-en",
        device=target_device,
    )
    return translator
79
+
80
+
# Leffa helpers, created lazily on first use.
# The cache must exist at module level before the accessor runs; without it
# the `is None` check below raises NameError on the first call.
mask_predictor = None


def get_mask_predictor():
    """Return the shared ``AutoMasker``, creating it on the first call.

    The masker is built from the checkpoints under ``./ckpts/densepose`` and
    ``./ckpts/schp`` and cached in the module-level ``mask_predictor``.

    Returns:
        The cached ``AutoMasker`` instance.
    """
    global mask_predictor
    if mask_predictor is None:
        mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
    return mask_predictor
90
+
# Cache for the DensePose predictor; it must be defined before the accessor
# runs, otherwise the `is None` check raises NameError on the first call.
densepose_predictor = None


def get_densepose_predictor():
    """Return the shared ``DensePosePredictor``, creating it on first call.

    The predictor is built from the config and weights under
    ``./ckpts/densepose`` and cached in the module-level
    ``densepose_predictor``.

    Returns:
        The cached ``DensePosePredictor`` instance.
    """
    global densepose_predictor
    if densepose_predictor is None:
        densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
    return densepose_predictor
99
+
# Caches for the virtual try-on model and its inference wrapper; they must be
# defined before the accessor runs, otherwise the `is None` check raises
# NameError on the first call.
vt_model = None
vt_inference = None


def get_vt_model():
    """Return the virtual try-on model and inference wrapper, loading lazily.

    On the first call the ``LeffaModel`` is built through
    ``load_model_with_optimization`` from the stable-diffusion-inpainting
    checkpoint and ``virtual_tryon.pth``, then wrapped in ``LeffaInference``.
    Both objects are cached at module level.

    Returns:
        A ``(vt_model, vt_inference)`` tuple.
    """
    global vt_model, vt_inference
    if vt_model is None:
        torch.cuda.empty_cache()
        vt_model = load_model_with_optimization(
            LeffaModel,
            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
            pretrained_model="./ckpts/virtual_tryon.pth",
        )
        vt_inference = LeffaInference(model=vt_model)
    return vt_model, vt_inference
111
+
# Caches for the pose-transfer model and its inference wrapper; they must be
# defined before the accessor runs, otherwise the `is None` check raises
# NameError on the first call.
pt_model = None
pt_inference = None


def get_pt_model():
    """Return the pose-transfer model and inference wrapper, loading lazily.

    On the first call the ``LeffaModel`` is built through
    ``load_model_with_optimization`` from the SDXL inpainting checkpoint and
    ``pose_transfer.pth``, then wrapped in ``LeffaInference``. Both objects
    are cached at module level.

    Returns:
        A ``(pt_model, pt_inference)`` tuple.
    """
    global pt_model, pt_inference
    if pt_model is None:
        torch.cuda.empty_cache()
        pt_model = load_model_with_optimization(
            LeffaModel,
            pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
            pretrained_model="./ckpts/pose_transfer.pth",
        )
        pt_inference = LeffaInference(model=pt_model)
    return pt_model, pt_inference
123
 
124
  # Leffa 체크포인트 다운로드
125
  snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def contains_korean(text):
    """Return True when ``text`` has a character in the range '가' to '힣'.

    Characters outside that code-point range do not count. An empty string
    yields False.
    """
    lower, upper = ord('가'), ord('힣')
    for ch in text:
        if lower <= ord(ch) <= upper:
            return True
    return False
129
 
@spaces.GPU()
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one fashion image with FLUX and the LoRA selected by ``mode``.

    Args:
        prompt: Text prompt. If ``contains_korean`` reports Korean characters
            it is run through the translator first.
        mode: ``"Generate Model"`` selects ``MODEL_LORA_REPO``; any other
            value selects ``CLOTHES_LORA_REPO``.
        cfg_scale: Passed to the pipeline as ``guidance_scale``.
        steps: Passed to the pipeline as ``num_inference_steps``.
        randomize_seed: When truthy, ``seed`` is replaced by a random integer
            in ``[0, MAX_SEED]``.
        seed: Seed for the random generator.
        width: Requested width, clamped to at most 1024.
        height: Requested height, clamped to at most 1024.
        lora_scale: Passed as ``joint_attention_kwargs={"scale": ...}``.
        progress: Gradio progress tracker.

    Returns:
        A ``(image, seed)`` tuple, where ``seed`` is the value actually used.
    """
    torch.cuda.empty_cache()

    if contains_korean(prompt):
        actual_prompt = get_translator()(prompt)[0]['translation_text']
    else:
        actual_prompt = prompt

    # Load the LoRA for the selected mode and pick its trigger words.
    pipe = get_fashion_pipe()
    if mode == "Generate Model":
        pipe = load_lora(pipe, MODEL_LORA_REPO)
        trigger_word = "fashion photography, professional model"
    else:
        pipe = load_lora(pipe, CLOTHES_LORA_REPO)
        trigger_word = "upper clothing, fashion item"

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Follow the module-level device instead of hard-coding "cuda", so the
    # generator matches wherever the rest of the app runs.
    generator = torch.Generator(device=device).manual_seed(seed)

    # Limit the image size.
    width = min(width, 1024)
    height = min(height, 1024)

    # Progress values are fractions in [0, 1]. The earlier pre-run loop was
    # removed: it reported "steps" before any work happened and evaluated
    # `i % (steps // 10)`, which divides by zero whenever steps < 10.
    # The tracker is created with track_tqdm=True, so per-step progress is
    # expected to come from the pipeline's own tqdm bar.
    progress(0, "Starting fashion generation...")

    image = pipe(
        prompt=f"{actual_prompt} {trigger_word}",
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]

    progress(1.0, "Completed!")
    return image, seed
176
+
177
+
178
  def leffa_predict(src_image_path, ref_image_path, control_type):
179
+ torch.cuda.empty_cache()
180
+
181
  assert control_type in [
182
  "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
183
+
184
+ # 이미지 로드 및 크기 조정
185
  src_image = Image.open(src_image_path)
186
  ref_image = Image.open(ref_image_path)
187
  src_image = resize_and_center(src_image, 768, 1024)
 
190
  src_image_array = np.array(src_image)
191
  ref_image_array = np.array(ref_image)
192
 
193
+ # Mask 생성
194
  if control_type == "virtual_tryon":
195
+ mask_pred = get_mask_predictor()
196
  src_image = src_image.convert("RGB")
197
+ mask = mask_pred(src_image, "upper")["mask"]
198
  elif control_type == "pose_transfer":
199
  mask = Image.fromarray(np.ones_like(src_image_array) * 255)
200
 
201
+ # DensePose 예측
202
+ dense_pred = get_densepose_predictor()
203
+ src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
204
+ src_image_seg_array = dense_pred.predict_seg(src_image_array)
205
  src_image_iuv = Image.fromarray(src_image_iuv_array)
206
  src_image_seg = Image.fromarray(src_image_seg_array)
207
+
208
  if control_type == "virtual_tryon":
209
  densepose = src_image_seg
210
+ model, inference = get_vt_model()
211
  elif control_type == "pose_transfer":
212
  densepose = src_image_iuv
213
+ model, inference = get_pt_model()
214
 
215
+ # Leffa 변환 및 추론
216
  transform = LeffaTransform()
 
217
  data = {
218
  "src_image": [src_image],
219
  "ref_image": [ref_image],
 
221
  "densepose": [densepose],
222
  }
223
  data = transform(data)
224
+
 
 
 
225
  output = inference(data)
226
  gen_image = output["generated_image"][0]
227
+
228
+ torch.cuda.empty_cache()
229
  return np.array(gen_image)
230
 
 
def leffa_predict_vt(src_image_path, ref_image_path):
    """Call ``leffa_predict`` with control type ``"virtual_tryon"``."""
    control_type = "virtual_tryon"
    return leffa_predict(src_image_path, ref_image_path, control_type)
233
 
 
def leffa_predict_pt(src_image_path, ref_image_path):
    """Call ``leffa_predict`` with control type ``"pose_transfer"``."""
    control_type = "pose_transfer"
    return leffa_predict(src_image_path, ref_image_path, control_type)
236
 
237
 
238
+ # Gradio 인터페이스
239
  with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)) as demo:
240
  gr.Markdown("# 🎭 Fashion Studio & Virtual Try-on")
241
 
 
273
  steps = gr.Slider(
274
  label="Steps",
275
  minimum=1,
276
+ maximum=50, # 최대값 감소
277
  step=1,
278
  value=30
279
  )
 
289
  width = gr.Slider(
290
  label="Width",
291
  minimum=256,
292
+ maximum=1024, # 최대값 감소
293
  step=64,
294
  value=512
295
  )
296
  height = gr.Slider(
297
  label="Height",
298
  minimum=256,
299
+ maximum=1024, # 최대값 감소
300
  step=64,
301
  value=768
302
  )
 
414
  )
415
  pose_transfer_gen_button = gr.Button("Generate")
416
 
 
 
417
  # 이벤트 핸들러
418
  generate_button.click(
419
  generate_fashion,
 
433
  outputs=[pt_gen_image]
434
  )
435
 
436
+ # 앱 실행
437
+ demo.launch(share=True, server_port=7860)