Spaces: ginipick / Running on Zero

ginipick committed · verified · Commit 1f17448 · 1 Parent(s): 5aecaf0

Update app.py

Files changed (1): app.py (+26, -25)
app.py CHANGED
@@ -36,21 +36,17 @@ def safe_model_call(func):
             raise
     return wrapper
 
-# Memory management function
+# Memory management function (revised)
 def clear_memory():
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
     gc.collect()
+    if torch.cuda.is_available() and torch.cuda.current_device() >= 0:
+        torch.cuda.empty_cache()
 
+# Environment setup function (revised)
 def setup_environment():
     # Memory management settings
-    torch.cuda.empty_cache()
     gc.collect()
     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
-    torch.backends.cudnn.benchmark = True
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cuda.max_split_size_mb = 128
 
     # Hugging Face token setup
     global HF_TOKEN
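Reviewer note: on a ZeroGPU Space the GPU is only attached while a @spaces.GPU()-decorated function runs, which is presumably why this hunk drops the module-level torch.cuda and torch.backends calls and guards empty_cache(). A minimal, self-contained sketch of that pattern (gpu_task is an illustrative name, not from the commit):

import gc
import torch
import spaces  # ZeroGPU helper used by this Space

def clear_memory():
    # Safe to call anywhere: gc always runs, CUDA work only if a GPU is visible.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

@spaces.GPU()
def gpu_task():
    # Inside the decorated call a GPU is attached, so "cuda" is usable here.
    x = torch.ones(1, device="cuda")
    clear_memory()
    return x.cpu()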
@@ -59,9 +55,11 @@ def setup_environment():
         raise ValueError("Please set the HF_TOKEN environment variable")
     login(token=HF_TOKEN)
 
-    # CUDA setup
+    # device setup removed (the spaces.GPU() decorator handles it)
     global device
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = "cpu"  # default to CPU
+
+
 
 # Initialize global variables
 fashion_pipe = None
@@ -78,31 +76,32 @@ HF_TOKEN = None
 # Run environment setup
 setup_environment()
 
-
-# Model management functions
+@spaces.GPU()
 def initialize_fashion_pipe():
     global fashion_pipe
     if fashion_pipe is None:
-        clear_memory()
         fashion_pipe = DiffusionPipeline.from_pretrained(
             BASE_MODEL,
             torch_dtype=torch.float16,
             use_auth_token=HF_TOKEN
-        )
+        ).to("cuda")
         try:
             fashion_pipe.enable_xformers_memory_efficient_attention()
         except Exception as e:
             print(f"Warning: Could not enable memory efficient attention: {e}")
-        fashion_pipe.enable_sequential_cpu_offload()
     return fashion_pipe
 
-@safe_model_call
+def setup():
+    # Only download the Leffa checkpoints
+    snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
+
+@spaces.GPU()
 def get_translator():
     global translator
     if translator is None:
         translator = pipeline("translation",
                               model="Helsinki-NLP/opus-mt-ko-en",
-                              device=device if device == "cuda" else -1)
+                              device="cuda")
     return translator
 
 @safe_model_call
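For reference, a small usage sketch of the lazily created translator; Helsinki-NLP/opus-mt-ko-en is the model the diff itself loads, and device=-1 keeps this example on CPU (the Space passes device="cuda" inside the GPU context):

from transformers import pipeline

translator = pipeline("translation",
                      model="Helsinki-NLP/opus-mt-ko-en",
                      device=-1)  # CPU here; the Space uses device="cuda"
result = translator("안녕하세요")  # -> [{"translation_text": "..."}]
print(result[0]["translation_text"])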
@@ -161,13 +160,16 @@ def load_lora(pipe, lora_path):
         print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
     return pipe
 
-# Initial setup function
-def setup():
-    # Download the Leffa checkpoints
-    snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
-    # Initialize the base model
-    initialize_fashion_pipe()
-
+@spaces.GPU()
+def get_mask_predictor():
+    global mask_predictor
+    if mask_predictor is None:
+        mask_predictor = AutoMasker(
+            densepose_path="./ckpts/densepose",
+            schp_path="./ckpts/schp",
+        )
+    return mask_predictor
+
 # Utility functions
 def contains_korean(text):
     return any(ord('가') <= ord(char) <= ord('힣') for char in text)
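The contains_korean helper kept above matches precomposed Hangul syllables only ('가' is U+AC00, '힣' is U+D7A3; standalone jamo fall outside this range). A quick self-check:

def contains_korean(text):
    return any(ord('가') <= ord(ch) <= ord('힣') for ch in text)

assert contains_korean("패션 스타일")        # Hangul present -> True
assert not contains_korean("fashion style")  # ASCII only -> False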
@@ -175,7 +177,6 @@ def contains_korean(text):
 
 # Main feature functions
 @spaces.GPU()
-@safe_model_call
 def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
     try:
         # Handle Korean text
 
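With setup() reduced to a checkpoint download, the heavy pipelines now load lazily inside @spaces.GPU() calls. The download step on its own, as a hedged sketch (snapshot_download comes from huggingface_hub; the subdirectories match the densepose/schp paths that get_mask_predictor uses):

from huggingface_hub import snapshot_download

# Fetch the full franciszzj/Leffa repo into ./ckpts; files already
# present locally are reused rather than downloaded again.
ckpt_dir = snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
print(ckpt_dir)  # expected to contain densepose/ and schp/ among others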