Spaces:
ginipick
/
Running on Zero

ginipick commited on
Commit
e696492
·
verified ·
1 Parent(s): d9d2fce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -46
app.py CHANGED
@@ -63,52 +63,16 @@ def setup_environment():
63
  login(token=HF_TOKEN)
64
  return HF_TOKEN
65
 
66
- @spaces.GPU()
67
- def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
68
  try:
69
- with torch_gc():
70
- # 한글 처리
71
- if contains_korean(prompt):
72
- translator = get_translator()
73
- with torch.inference_mode():
74
- translated = translator(prompt)[0]['translation_text']
75
- actual_prompt = translated
76
- else:
77
- actual_prompt = prompt
78
-
79
- # 파이프라인 초기화
80
- pipe = DiffusionPipeline.from_pretrained(
81
- BASE_MODEL,
82
- torch_dtype=torch.float16,
83
- )
84
- pipe = pipe.to("cuda")
85
-
86
- # LoRA 설정
87
- if mode == "Generate Model":
88
- pipe.load_lora_weights(MODEL_LORA_REPO)
89
- trigger_word = "fashion photography, professional model"
90
- else:
91
- pipe.load_lora_weights(CLOTHES_LORA_REPO)
92
- trigger_word = "upper clothing, fashion item"
93
-
94
- # 이미지 생성
95
- with torch.inference_mode():
96
- result = pipe(
97
- prompt=f"{actual_prompt} {trigger_word}",
98
- num_inference_steps=steps,
99
- guidance_scale=cfg_scale,
100
- width=width,
101
- height=height,
102
- generator=torch.Generator("cuda").manual_seed(
103
- seed if seed is not None else torch.randint(0, 2**32 - 1, (1,)).item()
104
- ),
105
- joint_attention_kwargs={"scale": lora_scale},
106
- ).images[0]
107
-
108
- return result, seed
109
 
110
- except Exception as e:
111
- raise gr.Error(f"Generation failed: {str(e)}")
112
 
113
  def contains_korean(text):
114
  return any(ord('가') <= ord(char) <= ord('힣') for char in text)
@@ -601,6 +565,5 @@ if __name__ == "__main__":
601
  demo.launch(
602
  server_name="0.0.0.0",
603
  server_port=7860,
604
- share=False,
605
- memory_target_gb=0.5 # 메모리 제한 설정
606
  )
 
63
  login(token=HF_TOKEN)
64
  return HF_TOKEN
65
 
66
+ @contextmanager
67
+ def torch_gc():
68
  try:
69
+ yield
70
+ finally:
71
+ gc.collect()
72
+ if torch.cuda.is_available() and torch.cuda.current_device() >= 0:
73
+ with torch.cuda.device('cuda'):
74
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
 
 
76
 
77
  def contains_korean(text):
78
  return any(ord('가') <= ord(char) <= ord('힣') for char in text)
 
565
  demo.launch(
566
  server_name="0.0.0.0",
567
  server_port=7860,
568
+ share=False
 
569
  )