Spaces:
ginipick
/
Running on Zero

ginipick commited on
Commit
6fca6aa
·
verified ·
1 Parent(s): e696492

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -1
app.py CHANGED
@@ -328,6 +328,56 @@ def leffa_predict_pt(src_image_path, ref_image_path):
328
  print(f"Error in leffa_predict_pt: {str(e)}")
329
  raise
330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  # 초기 설정 실행
332
  setup()
333
  def create_interface():
@@ -561,7 +611,7 @@ def create_interface():
561
  if __name__ == "__main__":
562
  setup_environment()
563
  demo = create_interface()
564
- demo.queue() # 큐 활성화
565
  demo.launch(
566
  server_name="0.0.0.0",
567
  server_port=7860,
 
328
  print(f"Error in leffa_predict_pt: {str(e)}")
329
  raise
330
 
331
+
332
+ @spaces.GPU()
333
+ def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
334
+ try:
335
+ with torch_gc():
336
+ # 한글 처리
337
+ if contains_korean(prompt):
338
+ translator = get_translator()
339
+ with torch.inference_mode():
340
+ translated = translator(prompt)[0]['translation_text']
341
+ actual_prompt = translated
342
+ else:
343
+ actual_prompt = prompt
344
+
345
+ # 파이프라인 초기화
346
+ pipe = DiffusionPipeline.from_pretrained(
347
+ BASE_MODEL,
348
+ torch_dtype=torch.float16,
349
+ )
350
+ pipe = pipe.to("cuda")
351
+
352
+ # LoRA 설정
353
+ if mode == "Generate Model":
354
+ pipe.load_lora_weights(MODEL_LORA_REPO)
355
+ trigger_word = "fashion photography, professional model"
356
+ else:
357
+ pipe.load_lora_weights(CLOTHES_LORA_REPO)
358
+ trigger_word = "upper clothing, fashion item"
359
+
360
+ # 이미지 생성
361
+ with torch.inference_mode():
362
+ result = pipe(
363
+ prompt=f"{actual_prompt} {trigger_word}",
364
+ num_inference_steps=steps,
365
+ guidance_scale=cfg_scale,
366
+ width=width,
367
+ height=height,
368
+ generator=torch.Generator("cuda").manual_seed(
369
+ seed if seed is not None else torch.randint(0, 2**32 - 1, (1,)).item()
370
+ ),
371
+ joint_attention_kwargs={"scale": lora_scale},
372
+ ).images[0]
373
+
374
+ # 메모리 정리
375
+ del pipe
376
+ return result, seed
377
+
378
+ except Exception as e:
379
+ raise gr.Error(f"Generation failed: {str(e)}")
380
+
381
  # 초기 설정 실행
382
  setup()
383
  def create_interface():
 
611
  if __name__ == "__main__":
612
  setup_environment()
613
  demo = create_interface()
614
+ demo.queue()
615
  demo.launch(
616
  server_name="0.0.0.0",
617
  server_port=7860,