Update app.py
Browse files
app.py
CHANGED
@@ -54,10 +54,18 @@ def load_model_with_optimization(model_class, *args, **kwargs):
|
|
54 |
model = model.half() # FP16์ผ๋ก ๋ณํ
|
55 |
return model.to(device)
|
56 |
|
57 |
-
# LoRA ๋ก๋ ํจ์
|
58 |
def load_lora(pipe, lora_path):
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
# FLUX ๋ชจ๋ธ ์ด๊ธฐํ (ํ์ํ ๋๋ง ๋ก๋)
|
63 |
fashion_pipe = None
|
@@ -151,6 +159,14 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
|
|
151 |
|
152 |
pipe = get_fashion_pipe()
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
# ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ์ ํ์ ์ํ ํฌ๊ธฐ ์กฐ์
|
155 |
width = min(width, 768) # ์ต๋ ํฌ๊ธฐ ์ ํ
|
156 |
height = min(height, 768) # ์ต๋ ํฌ๊ธฐ ์ ํ
|
@@ -178,59 +194,6 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
|
|
178 |
clear_memory() # ์ค๋ฅ ๋ฐ์ ์์๋ ๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ
|
179 |
raise e
|
180 |
|
181 |
-
def leffa_predict(src_image_path, ref_image_path, control_type):
    """Run Leffa inference for virtual try-on or pose transfer.

    Args:
        src_image_path: Path to the source (person) image.
        ref_image_path: Path to the reference image (garment for try-on,
            target pose for pose transfer).
        control_type: Either "virtual_tryon" or "pose_transfer".

    Returns:
        The generated image as a numpy array.

    Raises:
        ValueError: If ``control_type`` is not one of the supported modes.
    """
    # Validate first with an explicit raise: `assert` is stripped under -O,
    # which would silently allow an invalid mode through.
    if control_type not in ("virtual_tryon", "pose_transfer"):
        raise ValueError("Invalid control type: {}".format(control_type))

    torch.cuda.empty_cache()

    # Load both images and normalize them to the model's 768x1024 canvas.
    src_image = resize_and_center(Image.open(src_image_path), 768, 1024)
    ref_image = resize_and_center(Image.open(ref_image_path), 768, 1024)

    src_image_array = np.array(src_image)

    # Build the mask: predicted upper-garment mask for try-on, or a
    # full-coverage (all-255) mask for pose transfer.
    if control_type == "virtual_tryon":
        mask_pred = get_mask_predictor()
        src_image = src_image.convert("RGB")
        mask = mask_pred(src_image, "upper")["mask"]
    else:  # pose_transfer — validated above
        mask = Image.fromarray(np.ones_like(src_image_array) * 255)

    # DensePose prediction on the source image (IUV and segmentation maps).
    dense_pred = get_densepose_predictor()
    src_image_iuv = Image.fromarray(dense_pred.predict_iuv(src_image_array))
    src_image_seg = Image.fromarray(dense_pred.predict_seg(src_image_array))

    # Pick the densepose representation and model matching the mode:
    # try-on consumes the segmentation map, pose transfer the IUV map.
    if control_type == "virtual_tryon":
        densepose = src_image_seg
        _model, inference = get_vt_model()
    else:
        densepose = src_image_iuv
        _model, inference = get_pt_model()

    # Leffa transform and inference. The model expects batched (list) inputs.
    transform = LeffaTransform()
    data = transform({
        "src_image": [src_image],
        "ref_image": [ref_image],
        "mask": [mask],
        "densepose": [densepose],
    })

    output = inference(data)
    gen_image = output["generated_image"][0]

    torch.cuda.empty_cache()
    return np.array(gen_image)
|
233 |
-
|
234 |
def leffa_predict_vt(src_image_path, ref_image_path):
    """Convenience wrapper: run Leffa in virtual try-on mode."""
    mode = "virtual_tryon"
    return leffa_predict(src_image_path, ref_image_path, mode)
|
236 |
|
@@ -239,7 +202,7 @@ def leffa_predict_pt(src_image_path, ref_image_path):
|
|
239 |
|
240 |
|
241 |
# Gradio ์ธํฐํ์ด์ค
|
242 |
-
with gr.Blocks(theme=
|
243 |
gr.Markdown("# ๐ญ Fashion Studio & Virtual Try-on")
|
244 |
|
245 |
with gr.Tabs():
|
|
|
54 |
model = model.half() # FP16์ผ๋ก ๋ณํ
|
55 |
return model.to(device)
|
56 |
|
|
|
57 |
def load_lora(pipe, lora_path):
    """Attach LoRA weights from ``lora_path`` to the pipeline.

    Any previously loaded LoRA weights are unloaded first so adapters do
    not stack across mode switches. Loading is best-effort: on failure a
    warning is printed and the pipeline is returned unchanged.

    Args:
        pipe: A diffusers-style pipeline exposing ``unload_lora_weights()``
            and ``load_lora_weights()``.
        lora_path: Repository id or local path of the LoRA weights.

    Returns:
        The same pipeline object, with the new LoRA weights attached when
        loading succeeded.
    """
    try:
        # Remove any previously attached LoRA weights; this may raise when
        # none are loaded, which is fine to ignore.
        pipe.unload_lora_weights()
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate instead of being swallowed.
        pass

    try:
        pipe.load_lora_weights(lora_path)
    except Exception as e:
        print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
    return pipe
|
69 |
|
70 |
# FLUX ๋ชจ๋ธ ์ด๊ธฐํ (ํ์ํ ๋๋ง ๋ก๋)
|
71 |
fashion_pipe = None
|
|
|
159 |
|
160 |
pipe = get_fashion_pipe()
|
161 |
|
162 |
+
# ๋ชจ๋์ ๋ฐ๋ฅธ LoRA ๋ก๋ฉ ๋ฐ ํธ๋ฆฌ๊ฑฐ์๋ ์ค์
|
163 |
+
if mode == "Generate Model":
|
164 |
+
pipe = load_lora(pipe, MODEL_LORA_REPO)
|
165 |
+
trigger_word = "fashion photography, professional model"
|
166 |
+
else:
|
167 |
+
pipe = load_lora(pipe, CLOTHES_LORA_REPO)
|
168 |
+
trigger_word = "upper clothing, fashion item"
|
169 |
+
|
170 |
# ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ์ ํ์ ์ํ ํฌ๊ธฐ ์กฐ์
|
171 |
width = min(width, 768) # ์ต๋ ํฌ๊ธฐ ์ ํ
|
172 |
height = min(height, 768) # ์ต๋ ํฌ๊ธฐ ์ ํ
|
|
|
194 |
clear_memory() # ์ค๋ฅ ๋ฐ์ ์์๋ ๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ
|
195 |
raise e
|
196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
def leffa_predict_vt(src_image_path, ref_image_path):
    """Convenience wrapper: run Leffa in virtual try-on mode."""
    mode = "virtual_tryon"
    return leffa_predict(src_image_path, ref_image_path, mode)
|
199 |
|
|
|
202 |
|
203 |
|
204 |
# Gradio ์ธํฐํ์ด์ค
|
205 |
+
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
|
206 |
gr.Markdown("# ๐ญ Fashion Studio & Virtual Try-on")
|
207 |
|
208 |
with gr.Tabs():
|