Update app.py
Browse files
app.py
CHANGED
@@ -251,67 +251,105 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
|
|
251 |
print(f"Error in generate_fashion: {str(e)}")
|
252 |
raise gr.Error(f"Generation failed: {str(e)}")
|
253 |
|
254 |
-
|
255 |
-
def
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
|
|
|
|
268 |
|
269 |
-
|
270 |
-
|
271 |
|
272 |
-
|
273 |
-
|
|
|
|
|
|
|
274 |
if control_type == "virtual_tryon":
|
275 |
-
|
276 |
-
mask_pred = get_mask_predictor()
|
277 |
-
mask = mask_pred(src_image, "upper")["mask"]
|
278 |
else:
|
279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
|
281 |
-
|
282 |
-
|
283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
|
285 |
-
|
286 |
-
|
287 |
-
else:
|
288 |
-
densepose = Image.fromarray(src_image_iuv_array)
|
289 |
-
|
290 |
-
# Leffa 변환 및 추론
|
291 |
-
transform = LeffaTransform()
|
292 |
-
data = {
|
293 |
-
"src_image": [src_image],
|
294 |
-
"ref_image": [ref_image],
|
295 |
-
"mask": [mask],
|
296 |
-
"densepose": [densepose],
|
297 |
-
}
|
298 |
-
data = transform(data)
|
299 |
-
|
300 |
-
with torch.inference_mode():
|
301 |
-
output = inference(data)
|
302 |
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
|
309 |
-
|
310 |
-
|
311 |
except Exception as e:
|
312 |
print(f"Error in leffa_predict: {str(e)}")
|
313 |
raise
|
314 |
|
|
|
|
|
315 |
@spaces.GPU()
|
316 |
def leffa_predict_vt(src_image_path, ref_image_path):
|
317 |
try:
|
@@ -328,20 +366,20 @@ def leffa_predict_pt(src_image_path, ref_image_path):
|
|
328 |
print(f"Error in leffa_predict_pt: {str(e)}")
|
329 |
raise
|
330 |
|
331 |
-
|
332 |
@spaces.GPU()
|
333 |
def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
|
334 |
try:
|
335 |
with torch_gc():
|
336 |
# 한글 처리
|
337 |
if contains_korean(prompt):
|
338 |
-
translator = get_translator()
|
339 |
with torch.inference_mode():
|
340 |
translated = translator(prompt)[0]['translation_text']
|
341 |
actual_prompt = translated
|
342 |
else:
|
343 |
actual_prompt = prompt
|
344 |
|
|
|
345 |
# 파이프라인 초기화
|
346 |
pipe = DiffusionPipeline.from_pretrained(
|
347 |
BASE_MODEL,
|
@@ -375,8 +413,10 @@ def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512,
|
|
375 |
del pipe
|
376 |
return result, seed
|
377 |
|
|
|
378 |
except Exception as e:
|
379 |
raise gr.Error(f"Generation failed: {str(e)}")
|
|
|
380 |
|
381 |
# 초기 설정 실행
|
382 |
setup()
|
|
|
251 |
print(f"Error in generate_fashion: {str(e)}")
|
252 |
raise gr.Error(f"Generation failed: {str(e)}")
|
253 |
|
254 |
+
class ModelManager:
    """Lazily constructs and caches the heavyweight vision/translation models.

    Each getter builds its model exactly once, on first call, and hands back
    the cached instance on every subsequent call.
    """

    def __init__(self):
        # Caches start empty; the models are built on demand in the getters.
        self.mask_predictor = None
        self.densepose_predictor = None
        self.translator = None

    @spaces.GPU()
    def get_mask_predictor(self):
        """Return the cached AutoMasker, creating it on first use."""
        if self.mask_predictor is not None:
            return self.mask_predictor
        self.mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
        return self.mask_predictor

    @spaces.GPU()
    def get_densepose_predictor(self):
        """Return the cached DensePosePredictor, creating it on first use."""
        if self.densepose_predictor is not None:
            return self.densepose_predictor
        self.densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
        return self.densepose_predictor

    @spaces.GPU()
    def get_translator(self):
        """Return the cached ko->en translation pipeline, creating it on first use."""
        if self.translator is not None:
            return self.translator
        self.translator = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ko-en",
            device="cuda",
        )
        return self.translator
|
285 |
|
286 |
# Module-level singleton that hands out the lazily-initialized models.
model_manager = ModelManager()
|
288 |
|
289 |
+
@spaces.GPU()
def leffa_predict(src_image_path, ref_image_path, control_type):
    """Run Leffa virtual-try-on or pose-transfer inference.

    Args:
        src_image_path: Path to the person (source) image.
        ref_image_path: Path to the reference image (garment for
            "virtual_tryon", target pose image otherwise).
        control_type: "virtual_tryon" selects the VT model plus an
            upper-body garment mask; any other value selects the
            pose-transfer model with an all-white (full-coverage) mask.

    Returns:
        The generated image as a numpy array.

    Raises:
        Exception: re-raised after logging, so callers see the original error.
    """
    try:
        with torch_gc():
            # Pick the model/inference pair for the requested task.
            if control_type == "virtual_tryon":
                model, inference = get_vt_model()
            else:
                model, inference = get_pt_model()

            # Load both images and normalize them to the 768x1024 canvas
            # the models expect.
            src_image = Image.open(src_image_path)
            ref_image = Image.open(ref_image_path)
            src_image = resize_and_center(src_image, 768, 1024)
            ref_image = resize_and_center(ref_image, 768, 1024)

            # NOTE(review): taken before the RGB conversion below, so this
            # array may still carry an alpha channel — confirm upstream inputs.
            src_image_array = np.array(src_image)

            # Mask and DensePose extraction; no gradients needed.
            with torch.inference_mode():
                if control_type == "virtual_tryon":
                    src_image = src_image.convert("RGB")
                    mask_pred = model_manager.get_mask_predictor()
                    mask = mask_pred(src_image, "upper")["mask"]
                else:
                    # Pose transfer conditions on the whole body: all-white mask.
                    mask = Image.fromarray(np.ones_like(src_image_array) * 255)

                dense_pred = model_manager.get_densepose_predictor()
                src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
                src_image_seg_array = dense_pred.predict_seg(src_image_array)

                # VT conditions on the segmentation map, pose transfer on IUV.
                if control_type == "virtual_tryon":
                    densepose = Image.fromarray(src_image_seg_array)
                else:
                    densepose = Image.fromarray(src_image_iuv_array)

            # Pack inputs through the Leffa transform and run the model.
            transform = LeffaTransform()
            data = {
                "src_image": [src_image],
                "ref_image": [ref_image],
                "mask": [mask],
                "densepose": [densepose],
            }
            data = transform(data)

            with torch.inference_mode():
                output = inference(data)

            # Free GPU memory eagerly; these models are large.
            del model
            del inference
            torch.cuda.empty_cache()
            gc.collect()

            return np.array(output["generated_image"][0])

    except Exception as e:
        print(f"Error in leffa_predict: {str(e)}")
        raise
|
350 |
|
351 |
+
|
352 |
+
|
353 |
@spaces.GPU()
|
354 |
def leffa_predict_vt(src_image_path, ref_image_path):
|
355 |
try:
|
|
|
366 |
print(f"Error in leffa_predict_pt: {str(e)}")
|
367 |
raise
|
368 |
|
|
|
369 |
@spaces.GPU()
|
370 |
def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
|
371 |
try:
|
372 |
with torch_gc():
|
373 |
# 한글 처리
|
374 |
if contains_korean(prompt):
|
375 |
+
translator = model_manager.get_translator()
|
376 |
with torch.inference_mode():
|
377 |
translated = translator(prompt)[0]['translation_text']
|
378 |
actual_prompt = translated
|
379 |
else:
|
380 |
actual_prompt = prompt
|
381 |
|
382 |
+
|
383 |
# 파이프라인 초기화
|
384 |
pipe = DiffusionPipeline.from_pretrained(
|
385 |
BASE_MODEL,
|
|
|
413 |
del pipe
|
414 |
return result, seed
|
415 |
|
416 |
+
|
417 |
except Exception as e:
|
418 |
raise gr.Error(f"Generation failed: {str(e)}")
|
419 |
+
|
420 |
|
421 |
# 초기 설정 실행
|
422 |
setup()
|