Spaces:
ginipick
/
Running on Zero

ginipick commited on
Commit
235a7ab
·
verified ·
1 Parent(s): 7ac3a5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -13
app.py CHANGED
@@ -42,23 +42,55 @@ def clear_memory():
42
  if torch.cuda.is_available() and torch.cuda.current_device() >= 0:
43
  torch.cuda.empty_cache()
44
 
45
- # 환경 설정 함수 수정
46
  def setup_environment():
47
- # 메모리 관리 설정
48
- gc.collect()
49
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
50
-
51
- # Hugging Face 토큰 설정
52
- global HF_TOKEN
53
  HF_TOKEN = os.getenv("HF_TOKEN")
54
- if HF_TOKEN is None:
55
- raise ValueError("Please set the HF_TOKEN environment variable")
56
  login(token=HF_TOKEN)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # device 설정 제거 (spaces.GPU() 데코레이터가 처리)
59
- global device
60
- device = "cpu" # 기본값으로 CPU 설정
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
  # 전역 변수 초기화
@@ -537,5 +569,15 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
537
  outputs=[pt_gen_image]
538
  )
539
 
540
- # 앱 실행
541
- demo.launch(share=True, server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
42
  if torch.cuda.is_available() and torch.cuda.current_device() >= 0:
43
  torch.cuda.empty_cache()
44
 
 
45
  def setup_environment():
 
 
46
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
 
 
 
47
  HF_TOKEN = os.getenv("HF_TOKEN")
48
+ if not HF_TOKEN:
49
+ raise ValueError("HF_TOKEN not found in environment variables")
50
  login(token=HF_TOKEN)
51
+ return HF_TOKEN
52
+ @spaces.GPU()
53
+ def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
54
+ try:
55
+ # 파이프라인 초기화
56
+ pipe = DiffusionPipeline.from_pretrained(
57
+ BASE_MODEL,
58
+ torch_dtype=torch.float16,
59
+ )
60
+ pipe.to("cuda")
61
+
62
+ # LoRA 설정
63
+ if mode == "Generate Model":
64
+ pipe.load_lora_weights(MODEL_LORA_REPO)
65
+ trigger_word = "fashion photography, professional model"
66
+ else:
67
+ pipe.load_lora_weights(CLOTHES_LORA_REPO)
68
+ trigger_word = "upper clothing, fashion item"
69
 
70
+ # 생성 설정
71
+ generator = torch.Generator("cuda").manual_seed(seed if seed is not None else torch.randint(0, 2**32 - 1, (1,)).item())
 
72
 
73
+ # 이미지 생성
74
+ with torch.inference_mode():
75
+ result = pipe(
76
+ prompt=f"{prompt} {trigger_word}",
77
+ num_inference_steps=steps,
78
+ guidance_scale=cfg_scale,
79
+ width=width,
80
+ height=height,
81
+ generator=generator,
82
+ cross_attention_kwargs={"scale": lora_scale},
83
+ ).images[0]
84
+
85
+ # 메모리 정리
86
+ del pipe
87
+ clear_memory()
88
+
89
+ return result, seed
90
+
91
+ except Exception as e:
92
+ clear_memory()
93
+ raise gr.Error(f"Generation failed: {str(e)}")
94
 
95
 
96
  # 전역 변수 초기화
 
569
  outputs=[pt_gen_image]
570
  )
571
 
572
+
573
+ if __name__ == "__main__":
574
+ # 환경 설정
575
+ setup_environment()
576
+
577
+ # 인터페이스 생성 및 실행
578
+ demo = create_interface()
579
+ demo.launch(
580
+ server_name="0.0.0.0",
581
+ server_port=7860,
582
+ share=False
583
+ )