Spaces:
Running
on
Zero
Running
on
Zero
try off version
Browse files- .gitignore +58 -0
- README.md +10 -7
- app.py +185 -3
- example/person/00008_00.jpg +0 -0
- example/person/00008_00_mask.png +0 -0
- example/person/00055_00.jpg +0 -0
- example/person/00055_00_mask.png +0 -0
- example/person/00057_00.jpg +0 -0
- example/person/00057_00_mask.png +0 -0
- example/person/00064_00.jpg +0 -0
- example/person/00064_00_mask.png +0 -0
- example/person/00067_00.jpg +0 -0
- example/person/00067_00_mask.png +0 -0
- example/person/00069_00.jpg +0 -0
- example/person/00069_00_mask.png +0 -0
- example/person/1.jpg +0 -0
- example/person/1_mask.png +0 -0
- requirements.txt +14 -0
- tryoff.sh +7 -0
- tryoff_inference.py +117 -0
.gitignore
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# Distribution / packaging
|
7 |
+
dist/
|
8 |
+
build/
|
9 |
+
*.egg-info/
|
10 |
+
|
11 |
+
# Virtual environments
|
12 |
+
venv/
|
13 |
+
env/
|
14 |
+
.env/
|
15 |
+
.venv/
|
16 |
+
|
17 |
+
# IDE specific files
|
18 |
+
.idea/
|
19 |
+
.vscode/
|
20 |
+
*.swp
|
21 |
+
*.swo
|
22 |
+
|
23 |
+
# Unit test / coverage reports
|
24 |
+
htmlcov/
|
25 |
+
.tox/
|
26 |
+
.coverage
|
27 |
+
.coverage.*
|
28 |
+
coverage.xml
|
29 |
+
*.cover
|
30 |
+
|
31 |
+
# Jupyter Notebook
|
32 |
+
.ipynb_checkpoints
|
33 |
+
|
34 |
+
# Local development settings
|
35 |
+
.env
|
36 |
+
.env.local
|
37 |
+
|
38 |
+
# Logs
|
39 |
+
*.log
|
40 |
+
|
41 |
+
# Database files
|
42 |
+
*.db
|
43 |
+
*.sqlite3
|
44 |
+
|
45 |
+
# OS generated files
|
46 |
+
.DS_Store
|
47 |
+
.DS_Store?
|
48 |
+
._*
|
49 |
+
.Spotlight-V100
|
50 |
+
.Trashes
|
51 |
+
ehthumbs.db
|
52 |
+
Thumbs.db
|
53 |
+
|
54 |
+
# Gradio cache
|
55 |
+
.gradio/example/github.mp4
|
56 |
+
|
57 |
+
aws/
|
58 |
+
checkpoints/
|
README.md
CHANGED
@@ -1,14 +1,17 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: yellow
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license: cc-by-nc-4.0
|
11 |
-
short_description: Extract and reconstruct the front view of clothing
|
12 |
---
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: cat-tryoff-flux
|
3 |
+
emoji: 🖥️
|
4 |
colorFrom: yellow
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 5.0.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
|
|
10 |
---
|
11 |
|
12 |
+
|
13 |
+
# cat-tryoff-flux
|
14 |
+
|
15 |
+
CAT-Tryoff-Flux is an advanced tryoff model. This model can extract and reconstruct the front view of clothing items from images of people wearing them. It used the same method of (CATVTON-FLUX)[https://huggingface.co/xiaozaa/catvton-flux-alpha].
|
16 |
+
|
17 |
+
The github repo is [here](https://github.com/nftblackmagic/catvton-flux).
|
app.py
CHANGED
@@ -1,7 +1,189 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
def greet(name):
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
7 |
demo.launch()
|
|
|
1 |
+
import spaces
|
2 |
+
|
3 |
import gradio as gr
|
4 |
+
from tryoff_inference import run_inference
|
5 |
+
import os
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
import tempfile
|
9 |
+
import torch
|
10 |
+
from diffusers import FluxTransformer2DModel, FluxFillPipeline
|
11 |
+
import subprocess
|
12 |
+
|
13 |
+
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
|
14 |
+
dtype = torch.bfloat16
|
15 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
+
|
17 |
+
print('Loading diffusion model ...')
|
18 |
+
transformer = FluxTransformer2DModel.from_pretrained(
|
19 |
+
"xiaozaa/cat-tryoff-flux",
|
20 |
+
torch_dtype=dtype
|
21 |
+
)
|
22 |
+
pipe = FluxFillPipeline.from_pretrained(
|
23 |
+
"black-forest-labs/FLUX.1-dev",
|
24 |
+
transformer=transformer,
|
25 |
+
torch_dtype=dtype
|
26 |
+
).to(device)
|
27 |
+
print('Loading Finished!')
|
28 |
+
|
29 |
+
@spaces.GPU(duration=120)
|
30 |
+
def gradio_inference(
|
31 |
+
image_data,
|
32 |
+
garment,
|
33 |
+
num_steps=50,
|
34 |
+
guidance_scale=30.0,
|
35 |
+
seed=-1,
|
36 |
+
width=768,
|
37 |
+
height=1024
|
38 |
+
):
|
39 |
+
"""Wrapper function for Gradio interface"""
|
40 |
+
# Check if mask has been drawn
|
41 |
+
if image_data is None or "layers" not in image_data or not image_data["layers"]:
|
42 |
+
raise gr.Error("Please draw a mask over the clothing area before generating!")
|
43 |
+
|
44 |
+
# Check if mask is empty (all black)
|
45 |
+
mask = image_data["layers"][0]
|
46 |
+
mask_array = np.array(mask)
|
47 |
+
if np.all(mask_array < 10):
|
48 |
+
raise gr.Error("The mask is empty! Please draw over the clothing area you want to replace.")
|
49 |
+
|
50 |
+
# Use temporary directory
|
51 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
52 |
+
# Save inputs to temp directory
|
53 |
+
temp_image = os.path.join(tmp_dir, "image.png")
|
54 |
+
temp_mask = os.path.join(tmp_dir, "mask.png")
|
55 |
+
|
56 |
+
# Extract image and mask from ImageEditor data
|
57 |
+
image = image_data["background"]
|
58 |
+
mask = image_data["layers"][0] # First layer contains the mask
|
59 |
+
|
60 |
+
# Convert to numpy array and process mask
|
61 |
+
mask_array = np.array(mask)
|
62 |
+
is_black = np.all(mask_array < 10, axis=2)
|
63 |
+
mask = Image.fromarray(((~is_black) * 255).astype(np.uint8))
|
64 |
+
|
65 |
+
# Save files to temp directory
|
66 |
+
image.save(temp_image)
|
67 |
+
mask.save(temp_mask)
|
68 |
+
|
69 |
+
try:
|
70 |
+
# Run inference
|
71 |
+
garment_result, _ = run_inference(
|
72 |
+
pipe=pipe,
|
73 |
+
image_path=temp_image,
|
74 |
+
mask_path=temp_mask,
|
75 |
+
num_steps=num_steps,
|
76 |
+
guidance_scale=guidance_scale,
|
77 |
+
seed=seed,
|
78 |
+
size=(width, height)
|
79 |
+
)
|
80 |
+
return garment_result
|
81 |
+
except Exception as e:
|
82 |
+
raise gr.Error(f"Error during inference: {str(e)}")
|
83 |
+
|
84 |
+
with gr.Blocks() as demo:
|
85 |
+
gr.Markdown("""
|
86 |
+
# CAT-TRYOFF-FLUX Virtual Try-Off Demo
|
87 |
+
Upload a model image, draw a mask, and a garment image to generate virtual try-off results.
|
88 |
+
|
89 |
+
""")
|
90 |
+
|
91 |
+
# gr.Video("example/github.mp4", label="Demo Video: How to use the tool")
|
92 |
+
|
93 |
+
with gr.Column():
|
94 |
+
gr.Markdown("""
|
95 |
+
### ⚠️ Important:
|
96 |
+
1. Choose a model image or upload your own
|
97 |
+
2. Use the Pen tool to draw a mask over the clothing area you want to restore
|
98 |
+
""")
|
99 |
+
|
100 |
+
with gr.Row():
|
101 |
+
with gr.Column():
|
102 |
+
image_input = gr.ImageMask(
|
103 |
+
label="Model Image (Click 'Edit' and draw mask over the clothing area)",
|
104 |
+
type="pil",
|
105 |
+
height=600,
|
106 |
+
width=300
|
107 |
+
)
|
108 |
+
gr.Examples(
|
109 |
+
examples=[
|
110 |
+
["./example/person/00008_00.jpg"],
|
111 |
+
["./example/person/00055_00.jpg"],
|
112 |
+
["./example/person/00064_00.jpg"],
|
113 |
+
["./example/person/00067_00.jpg"],
|
114 |
+
["./example/person/00069_00.jpg"],
|
115 |
+
],
|
116 |
+
inputs=[image_input],
|
117 |
+
label="Person Images",
|
118 |
+
)
|
119 |
+
with gr.Column():
|
120 |
+
garment_output = gr.Image(label="Try-On Result", height=600, width=300)
|
121 |
+
|
122 |
+
with gr.Row():
|
123 |
+
num_steps = gr.Slider(
|
124 |
+
minimum=1,
|
125 |
+
maximum=100,
|
126 |
+
value=30,
|
127 |
+
step=1,
|
128 |
+
label="Number of Steps"
|
129 |
+
)
|
130 |
+
guidance_scale = gr.Slider(
|
131 |
+
minimum=1.0,
|
132 |
+
maximum=50.0,
|
133 |
+
value=30.0,
|
134 |
+
step=0.5,
|
135 |
+
label="Guidance Scale"
|
136 |
+
)
|
137 |
+
seed = gr.Slider(
|
138 |
+
minimum=-1,
|
139 |
+
maximum=2147483647,
|
140 |
+
step=1,
|
141 |
+
value=-1,
|
142 |
+
label="Seed (-1 for random)"
|
143 |
+
)
|
144 |
+
width = gr.Slider(
|
145 |
+
minimum=256,
|
146 |
+
maximum=1024,
|
147 |
+
step=64,
|
148 |
+
value=768,
|
149 |
+
label="Width"
|
150 |
+
)
|
151 |
+
height = gr.Slider(
|
152 |
+
minimum=256,
|
153 |
+
maximum=1024,
|
154 |
+
step=64,
|
155 |
+
value=1024,
|
156 |
+
label="Height"
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
submit_btn = gr.Button("Generate Try-On", variant="primary")
|
161 |
+
|
162 |
+
|
163 |
+
with gr.Row():
|
164 |
+
gr.Markdown("""
|
165 |
+
### Notes:
|
166 |
+
- The model is trained on VITON-HD dataset. It focuses on the woman upper body try-on generation.
|
167 |
+
- The mask should indicate the region where the garment will be placed.
|
168 |
+
- The garment image should be on a clean background.
|
169 |
+
- The model is not perfect. It may generate some artifacts.
|
170 |
+
- The model is slow. Please be patient.
|
171 |
+
- The model is just for research purpose.
|
172 |
+
""")
|
173 |
+
|
174 |
+
submit_btn.click(
|
175 |
+
fn=gradio_inference,
|
176 |
+
inputs=[
|
177 |
+
image_input,
|
178 |
+
num_steps,
|
179 |
+
guidance_scale,
|
180 |
+
seed,
|
181 |
+
width,
|
182 |
+
height
|
183 |
+
],
|
184 |
+
outputs=[garment_output],
|
185 |
+
api_name="try-off"
|
186 |
+
)
|
187 |
|
|
|
|
|
188 |
|
|
|
189 |
demo.launch()
|
example/person/00008_00.jpg
ADDED
example/person/00008_00_mask.png
ADDED
example/person/00055_00.jpg
ADDED
example/person/00055_00_mask.png
ADDED
example/person/00057_00.jpg
ADDED
example/person/00057_00_mask.png
ADDED
example/person/00064_00.jpg
ADDED
example/person/00064_00_mask.png
ADDED
example/person/00067_00.jpg
ADDED
example/person/00067_00_mask.png
ADDED
example/person/00069_00.jpg
ADDED
example/person/00069_00_mask.png
ADDED
example/person/1.jpg
ADDED
example/person/1_mask.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate
|
2 |
+
git+https://github.com/huggingface/diffusers.git
|
3 |
+
gradio==5.6.0
|
4 |
+
gradio_client==1.4.3
|
5 |
+
torch==2.4.0
|
6 |
+
torchvision==0.19.0
|
7 |
+
tqdm==4.66.5
|
8 |
+
transformers==4.43.3
|
9 |
+
numpy==1.26.4
|
10 |
+
sentencepiece
|
11 |
+
peft==0.13.2
|
12 |
+
huggingface-hub
|
13 |
+
spaces
|
14 |
+
protobuf
|
tryoff.sh
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python tryoff_inference.py \
|
2 |
+
--image ./example/person/00069_00.jpg \
|
3 |
+
--mask ./example/person/00069_00_mask.png \
|
4 |
+
--seed 41 \
|
5 |
+
--output_tryon test_original.png \
|
6 |
+
--output_garment restored_garment6.png \
|
7 |
+
--steps 30
|
tryoff_inference.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
from diffusers.utils import load_image, check_min_version
|
4 |
+
from diffusers import FluxPriorReduxPipeline, FluxFillPipeline
|
5 |
+
from diffusers import FluxTransformer2DModel
|
6 |
+
import numpy as np
|
7 |
+
from torchvision import transforms
|
8 |
+
|
9 |
+
def run_inference(
|
10 |
+
image_path,
|
11 |
+
mask_path,
|
12 |
+
size=(576, 768),
|
13 |
+
num_steps=50,
|
14 |
+
guidance_scale=30,
|
15 |
+
seed=42,
|
16 |
+
pipe=None
|
17 |
+
):
|
18 |
+
# Build pipeline
|
19 |
+
if pipe is None:
|
20 |
+
transformer = FluxTransformer2DModel.from_pretrained(
|
21 |
+
"xiaozaa/cat-tryoff-flux",
|
22 |
+
torch_dtype=torch.bfloat16
|
23 |
+
)
|
24 |
+
pipe = FluxFillPipeline.from_pretrained(
|
25 |
+
"black-forest-labs/FLUX.1-dev",
|
26 |
+
transformer=transformer,
|
27 |
+
torch_dtype=torch.bfloat16
|
28 |
+
).to("cuda")
|
29 |
+
else:
|
30 |
+
pipe.to("cuda")
|
31 |
+
|
32 |
+
pipe.transformer.to(torch.bfloat16)
|
33 |
+
|
34 |
+
# Add transform
|
35 |
+
transform = transforms.Compose([
|
36 |
+
transforms.ToTensor(),
|
37 |
+
transforms.Normalize([0.5], [0.5]) # For RGB images
|
38 |
+
])
|
39 |
+
mask_transform = transforms.Compose([
|
40 |
+
transforms.ToTensor()
|
41 |
+
])
|
42 |
+
|
43 |
+
# Load and process images
|
44 |
+
# print("image_path", image_path)
|
45 |
+
image = load_image(image_path).convert("RGB").resize(size)
|
46 |
+
mask = load_image(mask_path).convert("RGB").resize(size)
|
47 |
+
|
48 |
+
# Transform images using the new preprocessing
|
49 |
+
image_tensor = transform(image)
|
50 |
+
mask_tensor = mask_transform(mask)[:1] # Take only first channel
|
51 |
+
garment_tensor = torch.zeros_like(image_tensor)
|
52 |
+
image_tensor = image_tensor * mask_tensor
|
53 |
+
|
54 |
+
# Create concatenated images
|
55 |
+
inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width
|
56 |
+
garment_mask = torch.zeros_like(mask_tensor)
|
57 |
+
extended_mask = torch.cat([1 - garment_mask, garment_mask], dim=2)
|
58 |
+
|
59 |
+
prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
|
60 |
+
f"[IMAGE1] Detailed product shot of a clothing" \
|
61 |
+
f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
|
62 |
+
|
63 |
+
generator = torch.Generator(device="cuda").manual_seed(seed)
|
64 |
+
|
65 |
+
result = pipe(
|
66 |
+
height=size[1],
|
67 |
+
width=size[0] * 2,
|
68 |
+
image=inpaint_image,
|
69 |
+
mask_image=extended_mask,
|
70 |
+
num_inference_steps=num_steps,
|
71 |
+
generator=generator,
|
72 |
+
max_sequence_length=512,
|
73 |
+
guidance_scale=guidance_scale,
|
74 |
+
prompt=prompt,
|
75 |
+
).images[0]
|
76 |
+
|
77 |
+
# Split and save results
|
78 |
+
width = size[0]
|
79 |
+
garment_result = result.crop((0, 0, width, size[1]))
|
80 |
+
tryon_result = result.crop((width, 0, width * 2, size[1]))
|
81 |
+
|
82 |
+
return garment_result, tryon_result
|
83 |
+
|
84 |
+
def main():
|
85 |
+
parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
|
86 |
+
parser.add_argument('--image', required=True, help='Path to the model image')
|
87 |
+
parser.add_argument('--mask', required=True, help='Path to the agnostic mask')
|
88 |
+
parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result')
|
89 |
+
parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result')
|
90 |
+
parser.add_argument('--steps', type=int, default=50, help='Number of inference steps')
|
91 |
+
parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale')
|
92 |
+
parser.add_argument('--seed', type=int, default=0, help='Random seed')
|
93 |
+
parser.add_argument('--width', type=int, default=576, help='Width')
|
94 |
+
parser.add_argument('--height', type=int, default=768, help='Height')
|
95 |
+
|
96 |
+
args = parser.parse_args()
|
97 |
+
|
98 |
+
check_min_version("0.30.2")
|
99 |
+
|
100 |
+
garment_result, tryon_result = run_inference(
|
101 |
+
image_path=args.image,
|
102 |
+
mask_path=args.mask,
|
103 |
+
num_steps=args.steps,
|
104 |
+
guidance_scale=args.guidance_scale,
|
105 |
+
seed=args.seed,
|
106 |
+
size=(args.width, args.height)
|
107 |
+
)
|
108 |
+
output_tryon_path=args.output_tryon
|
109 |
+
output_garment_path=args.output_garment
|
110 |
+
|
111 |
+
tryon_result.save(output_tryon_path)
|
112 |
+
garment_result.save(output_garment_path)
|
113 |
+
|
114 |
+
print("Successfully saved garment and try-on images")
|
115 |
+
|
116 |
+
if __name__ == "__main__":
|
117 |
+
main()
|