Spaces:
ginipick
/
Running on Zero

ginipick commited on
Commit
8fdc0c8
·
verified ·
1 Parent(s): 99686f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -11
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import numpy as np
2
  from PIL import Image
3
- from huggingface_hub import snapshot_download
4
  from leffa.transform import LeffaTransform
5
  from leffa.model import LeffaModel
6
  from leffa.inference import LeffaInference
@@ -13,11 +13,19 @@ from diffusers import DiffusionPipeline
13
  from transformers import pipeline
14
  import gradio as gr
15
  import os
16
- from huggingface_hub import login
17
  import random
 
 
 
 
 
 
18
 
19
  # 상수 정의
20
  MAX_SEED = 2**32 - 1
 
 
 
21
 
22
  # Hugging Face 토큰 설정 및 로그인
23
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -25,20 +33,26 @@ if HF_TOKEN is None:
25
  raise ValueError("Please set the HF_TOKEN environment variable")
26
  login(token=HF_TOKEN)
27
 
28
- # 모델 설정 (한 번만 선언)
29
- base_model = "black-forest-labs/FLUX.1-dev"
30
- model_lora_repo = "Motas/Flux_Fashion_Photography_Style"
31
- clothes_lora_repo = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
 
 
 
32
 
33
- # FLUX 모델 초기화 (한 번만 초기화)
 
 
 
34
  fashion_pipe = DiffusionPipeline.from_pretrained(
35
- base_model,
36
- torch_dtype=torch.bfloat16,
37
  use_auth_token=HF_TOKEN
38
  )
39
- fashion_pipe.to("cuda")
40
 
41
- # 번역기 초기화 (한 번만)
42
  translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
43
 
44
  # Leffa 체크포인트 다운로드
@@ -55,18 +69,22 @@ densepose_predictor = DensePosePredictor(
55
  weights_path="./ckpts/densepose/model_final_162be9.pkl",
56
  )
57
 
 
58
  vt_model = LeffaModel(
59
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
60
  pretrained_model="./ckpts/virtual_tryon.pth",
 
61
  )
62
  vt_inference = LeffaInference(model=vt_model)
63
 
64
  pt_model = LeffaModel(
65
  pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
66
  pretrained_model="./ckpts/pose_transfer.pth",
 
67
  )
68
  pt_inference = LeffaInference(model=pt_model)
69
 
 
70
  def contains_korean(text):
71
  return any(ord('가') <= ord(char) <= ord('힣') for char in text)
72
 
 
1
  import numpy as np
2
  from PIL import Image
3
+ from huggingface_hub import snapshot_download, login
4
  from leffa.transform import LeffaTransform
5
  from leffa.model import LeffaModel
6
  from leffa.inference import LeffaInference
 
13
  from transformers import pipeline
14
  import gradio as gr
15
  import os
 
16
  import random
17
+ import gc
18
+
19
+ # 메모리 최적화 설정
20
+ torch.backends.cudnn.benchmark = True
21
+ torch.backends.cuda.matmul.allow_tf32 = True
22
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
23
 
24
  # 상수 정의
25
  MAX_SEED = 2**32 - 1
26
+ BASE_MODEL = "black-forest-labs/FLUX.1-dev"
27
+ MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
28
+ CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
29
 
30
  # Hugging Face 토큰 설정 및 로그인
31
  HF_TOKEN = os.getenv("HF_TOKEN")
 
33
  raise ValueError("Please set the HF_TOKEN environment variable")
34
  login(token=HF_TOKEN)
35
 
36
+ # 메모리 정리 함수
37
+ def clear_memory():
38
+ torch.cuda.empty_cache()
39
+ gc.collect()
40
+
41
+ # 초기 메모리 정리
42
+ clear_memory()
43
 
44
+ # CUDA 사용 가능 여부 확인
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
+
47
+ # FLUX 모델 초기화
48
  fashion_pipe = DiffusionPipeline.from_pretrained(
49
+ BASE_MODEL,
50
+ torch_dtype=torch.float16,
51
  use_auth_token=HF_TOKEN
52
  )
53
+ fashion_pipe.enable_model_cpu_offload()
54
 
55
+ # 번역기 초기화
56
  translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
57
 
58
  # Leffa 체크포인트 다운로드
 
69
  weights_path="./ckpts/densepose/model_final_162be9.pkl",
70
  )
71
 
72
+ # Leffa 모델 초기화
73
  vt_model = LeffaModel(
74
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
75
  pretrained_model="./ckpts/virtual_tryon.pth",
76
+ device_map="auto"
77
  )
78
  vt_inference = LeffaInference(model=vt_model)
79
 
80
  pt_model = LeffaModel(
81
  pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
82
  pretrained_model="./ckpts/pose_transfer.pth",
83
+ device_map="auto"
84
  )
85
  pt_inference = LeffaInference(model=pt_model)
86
 
87
+
88
  def contains_korean(text):
89
  return any(ord('가') <= ord(char) <= ord('힣') for char in text)
90