Spaces:
Running
on
Zero
Running
on
Zero
wondervictor
commited on
Commit
•
2422035
1
Parent(s):
f6bd4fa
update README
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +163 -0
- app.py +29 -4
- app_canny.py +100 -0
- app_depth.py +92 -0
- autoregressive/models/README.md +6 -0
- autoregressive/models/dinov2_adapter.py +36 -0
- autoregressive/models/generate.py +204 -0
- autoregressive/models/gpt_t2i.py +561 -0
- autoregressive/sample/sample_c2i.py +151 -0
- autoregressive/sample/sample_c2i_ddp.py +188 -0
- autoregressive/sample/sample_t2i.py +215 -0
- autoregressive/sample/sample_t2i_MR.py +237 -0
- autoregressive/sample/sample_t2i_ddp.py +229 -0
- checkpoints/vq_ds16_t2i.pt +3 -0
- condition/README.md +23 -0
- condition/canny.py +25 -0
- condition/depth.py +47 -0
- condition/example/t2i/multi_resolution/bird.jpg +0 -0
- condition/example/t2i/multi_resolution/car.jpg +0 -0
- condition/example/t2i/multigen/doll.jpg +0 -0
- condition/example/t2i/multigen/girl.jpg +0 -0
- condition/example/t2i/multigen/house.jpg +0 -0
- condition/example/t2i/multigen/sofa.png +0 -0
- condition/hed.py +117 -0
- condition/lineart.py +98 -0
- condition/midas/depth.py +223 -0
- condition/midas/midas/__init__.py +0 -0
- condition/midas/midas/base_model.py +16 -0
- condition/midas/midas/blocks.py +341 -0
- condition/midas/midas/dpt_depth.py +108 -0
- condition/midas/midas/midas_net.py +76 -0
- condition/midas/midas/midas_net_custom.py +128 -0
- condition/midas/midas/transforms.py +234 -0
- condition/midas/midas/vit.py +491 -0
- condition/utils.py +38 -0
- language/README.md +14 -0
- language/extract_t5_feature.py +129 -0
- language/t5.py +201 -0
- model.py +242 -0
- style.css +10 -0
- tokenizer/consistencydecoder/README.md +14 -0
- tokenizer/consistencydecoder/cd_demo.py +57 -0
- tokenizer/consistencydecoder/reconstruction_cd_ddp.py +208 -0
- tokenizer/tokenizer_image/cache/vgg.pth +3 -0
- tokenizer/tokenizer_image/discriminator.py +255 -0
- tokenizer/tokenizer_image/discriminator_patchgan.py +152 -0
- tokenizer/tokenizer_image/discriminator_stylegan.py +101 -0
- tokenizer/tokenizer_image/lpips.py +164 -0
- tokenizer/tokenizer_image/reconstruction_vq_ddp.py +207 -0
- tokenizer/tokenizer_image/vq_demo.py +84 -0
.gitignore
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
110 |
+
.pdm.toml
|
111 |
+
.pdm-python
|
112 |
+
.pdm-build/
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
|
133 |
+
# Spyder project settings
|
134 |
+
.spyderproject
|
135 |
+
.spyproject
|
136 |
+
|
137 |
+
# Rope project settings
|
138 |
+
.ropeproject
|
139 |
+
|
140 |
+
# mkdocs documentation
|
141 |
+
/site
|
142 |
+
|
143 |
+
# mypy
|
144 |
+
.mypy_cache/
|
145 |
+
.dmypy.json
|
146 |
+
dmypy.json
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
150 |
+
|
151 |
+
# pytype static type analyzer
|
152 |
+
.pytype/
|
153 |
+
|
154 |
+
# Cython debug symbols
|
155 |
+
cython_debug/
|
156 |
+
|
157 |
+
# PyCharm
|
158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
#.idea/
|
163 |
+
|
app.py
CHANGED
@@ -1,7 +1,32 @@
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
def greet(name):
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
import gradio as gr
|
3 |
+
from huggingface_hub import hf_hub_download
|
4 |
+
from model import Model
|
5 |
+
from app_canny import create_demo as create_demo_canny
|
6 |
+
from app_depth import create_demo as create_demo_depth
|
7 |
+
import os
|
8 |
|
|
|
|
|
9 |
|
10 |
+
hf_hub_download('wondervictor/ControlAR', filename='canny_MR.safetensors', cache_dir='./checkpoints/')
|
11 |
+
hf_hub_download('wondervictor/ControlAR', filename='depth_MR.safetensors', cache_dir='./checkpoints/')
|
12 |
+
|
13 |
+
|
14 |
+
DESCRIPTION = "# [ControlAR: Controllable Image Generation with Autoregressive Models](https://arxiv.org/abs/2410.02705) \n ### The first row in outputs is the input image and condition. The second row is the images generated by ControlAR. \n ### You can run locally by following the instruction on our [Github Repo](https://github.com/hustvl/ControlAR)."
|
15 |
+
SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
|
16 |
+
model = Model()
|
17 |
+
device = "cuda"
|
18 |
+
with gr.Blocks(css="style.css") as demo:
|
19 |
+
gr.Markdown(DESCRIPTION)
|
20 |
+
gr.DuplicateButton(
|
21 |
+
value="Duplicate Space for private use",
|
22 |
+
elem_id="duplicate-button",
|
23 |
+
visible=SHOW_DUPLICATE_BUTTON,
|
24 |
+
)
|
25 |
+
with gr.Tabs():
|
26 |
+
with gr.TabItem("Depth"):
|
27 |
+
create_demo_depth(model.process_depth)
|
28 |
+
with gr.TabItem("Canny"):
|
29 |
+
create_demo_canny(model.process_canny)
|
30 |
+
|
31 |
+
if __name__ == "__main__":
|
32 |
+
demo.queue().launch(share=False, server_name="0.0.0.0")
|
app_canny.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import random
|
3 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
4 |
+
if randomize_seed:
|
5 |
+
seed = random.randint(0, 100000000)
|
6 |
+
return seed
|
7 |
+
examples = [
|
8 |
+
[
|
9 |
+
"condition/example/t2i/multigen/doll.png",
|
10 |
+
"A stuffed animal wearing a mask and a leash, sitting on a blanket",
|
11 |
+
"(512, 512)"
|
12 |
+
],
|
13 |
+
[
|
14 |
+
"condition/example/t2i/multigen/girl.png",
|
15 |
+
"An anime style girl with blue hair",
|
16 |
+
"(512, 512)"
|
17 |
+
],
|
18 |
+
[
|
19 |
+
"condition/example/t2i/multi_resolution/bird.jpg",
|
20 |
+
"colorful bird",
|
21 |
+
"(921, 564)"
|
22 |
+
],
|
23 |
+
]
|
24 |
+
def create_demo(process):
|
25 |
+
with gr.Blocks() as demo:
|
26 |
+
with gr.Row():
|
27 |
+
with gr.Column():
|
28 |
+
image = gr.Image()
|
29 |
+
prompt = gr.Textbox(label="Prompt")
|
30 |
+
run_button = gr.Button("Run")
|
31 |
+
with gr.Accordion("Advanced options", open=False):
|
32 |
+
canny_low_threshold = gr.Slider(
|
33 |
+
label="Canny low threshold", minimum=0, maximum=1000, value=100, step=50
|
34 |
+
)
|
35 |
+
canny_high_threshold = gr.Slider(
|
36 |
+
label="Canny high threshold", minimum=0, maximum=1000, value=200, step=50
|
37 |
+
)
|
38 |
+
cfg_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=4, step=0.1)
|
39 |
+
relolution = gr.Slider(label="(H, W)", minimum=384, maximum=768, value=512, step=16)
|
40 |
+
top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=2000, label='Top-K')
|
41 |
+
top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
|
42 |
+
temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
|
43 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=100000000, step=1, value=0)
|
44 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
45 |
+
with gr.Column():
|
46 |
+
result = gr.Gallery(label="Output", show_label=False, height='800px', columns=2, object_fit="scale-down")
|
47 |
+
gr.Examples(
|
48 |
+
examples=examples,
|
49 |
+
inputs=[
|
50 |
+
image,
|
51 |
+
prompt,
|
52 |
+
relolution,
|
53 |
+
],
|
54 |
+
outputs=result,
|
55 |
+
fn=process,
|
56 |
+
)
|
57 |
+
inputs = [
|
58 |
+
image,
|
59 |
+
prompt,
|
60 |
+
cfg_scale,
|
61 |
+
temperature,
|
62 |
+
top_k,
|
63 |
+
top_p,
|
64 |
+
seed,
|
65 |
+
canny_low_threshold,
|
66 |
+
canny_high_threshold,
|
67 |
+
]
|
68 |
+
prompt.submit(
|
69 |
+
fn=randomize_seed_fn,
|
70 |
+
inputs=[seed, randomize_seed],
|
71 |
+
outputs=seed,
|
72 |
+
queue=False,
|
73 |
+
api_name=False,
|
74 |
+
).then(
|
75 |
+
fn=process,
|
76 |
+
inputs=inputs,
|
77 |
+
outputs=result,
|
78 |
+
api_name=False,
|
79 |
+
)
|
80 |
+
run_button.click(
|
81 |
+
fn=randomize_seed_fn,
|
82 |
+
inputs=[seed, randomize_seed],
|
83 |
+
outputs=seed,
|
84 |
+
queue=False,
|
85 |
+
api_name=False,
|
86 |
+
).then(
|
87 |
+
fn=process,
|
88 |
+
inputs=inputs,
|
89 |
+
outputs=result,
|
90 |
+
api_name="canny",
|
91 |
+
)
|
92 |
+
return demo
|
93 |
+
if __name__ == "__main__":
|
94 |
+
from model import Model
|
95 |
+
model = Model()
|
96 |
+
demo = create_demo(model.process_canny)
|
97 |
+
demo.queue().launch(
|
98 |
+
share=False,
|
99 |
+
server_name="0.0.0.0"
|
100 |
+
)
|
app_depth.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import random
|
3 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
4 |
+
if randomize_seed:
|
5 |
+
seed = random.randint(0, 100000000)
|
6 |
+
return seed
|
7 |
+
examples = [
|
8 |
+
[
|
9 |
+
"condition/example/t2i/multigen/sofa.png",
|
10 |
+
"The red sofa in the living room has several pillows on it",
|
11 |
+
"(512, 512)"
|
12 |
+
],
|
13 |
+
[
|
14 |
+
"condition/example/t2i/multigen/house.png",
|
15 |
+
"A brick house with a chimney under a starry sky.",
|
16 |
+
"(512, 512)"
|
17 |
+
],
|
18 |
+
[
|
19 |
+
"condition/example/t2i/multi_resolution/car.jpg",
|
20 |
+
"a sport car",
|
21 |
+
"(448, 768)"
|
22 |
+
]
|
23 |
+
]
|
24 |
+
def create_demo(process):
|
25 |
+
with gr.Blocks() as demo:
|
26 |
+
with gr.Row():
|
27 |
+
with gr.Column():
|
28 |
+
image = gr.Image()
|
29 |
+
prompt = gr.Textbox(label="Prompt")
|
30 |
+
run_button = gr.Button("Run")
|
31 |
+
with gr.Accordion("Advanced options", open=False):
|
32 |
+
cfg_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=4, step=0.1)
|
33 |
+
resolution = gr.Slider(label="(H, W)", minimum=384, maximum=768, value=512, step=16)
|
34 |
+
top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=2000, label='Top-K')
|
35 |
+
top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
|
36 |
+
temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
|
37 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=100000000, step=1, value=0)
|
38 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
39 |
+
with gr.Column():
|
40 |
+
result = gr.Gallery(label="Output", show_label=False, height='800px', columns=2, object_fit="scale-down")
|
41 |
+
gr.Examples(
|
42 |
+
examples=examples,
|
43 |
+
inputs=[
|
44 |
+
image,
|
45 |
+
prompt,
|
46 |
+
resolution,
|
47 |
+
],
|
48 |
+
outputs=result,
|
49 |
+
fn=process,
|
50 |
+
)
|
51 |
+
inputs = [
|
52 |
+
image,
|
53 |
+
prompt,
|
54 |
+
cfg_scale,
|
55 |
+
temperature,
|
56 |
+
top_k,
|
57 |
+
top_p,
|
58 |
+
seed,
|
59 |
+
]
|
60 |
+
prompt.submit(
|
61 |
+
fn=randomize_seed_fn,
|
62 |
+
inputs=[seed, randomize_seed],
|
63 |
+
outputs=seed,
|
64 |
+
queue=False,
|
65 |
+
api_name=False,
|
66 |
+
).then(
|
67 |
+
fn=process,
|
68 |
+
inputs=inputs,
|
69 |
+
outputs=result,
|
70 |
+
api_name=False,
|
71 |
+
)
|
72 |
+
run_button.click(
|
73 |
+
fn=randomize_seed_fn,
|
74 |
+
inputs=[seed, randomize_seed],
|
75 |
+
outputs=seed,
|
76 |
+
queue=False,
|
77 |
+
api_name=False,
|
78 |
+
).then(
|
79 |
+
fn=process,
|
80 |
+
inputs=inputs,
|
81 |
+
outputs=result,
|
82 |
+
api_name="canny",
|
83 |
+
)
|
84 |
+
return demo
|
85 |
+
if __name__ == "__main__":
|
86 |
+
from model import Model
|
87 |
+
model = Model()
|
88 |
+
demo = create_demo(model.process_depth)
|
89 |
+
demo.queue().launch(
|
90 |
+
share=False,
|
91 |
+
server_name="0.0.0.0"
|
92 |
+
)
|
autoregressive/models/README.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Download the vit weight first
|
2 |
+
|
3 |
+
ViT-small: https://huggingface.co/WinKawaks/vit-small-patch16-224 \
|
4 |
+
Dinov2-small: https://huggingface.co/facebook/dinov2-small
|
5 |
+
|
6 |
+
Put them here
|
autoregressive/models/dinov2_adapter.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoImageProcessor, AutoModel
|
2 |
+
from PIL import Image
|
3 |
+
import requests
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
class Dinov2_Adapter(nn.Module):
|
9 |
+
def __init__(self, input_dim=1, output_dim=768, attention=False, pool=False, nheads=8, dropout=0.1, adapter_size='small', condition_type='canny'):
|
10 |
+
super(Dinov2_Adapter, self).__init__()
|
11 |
+
print(f"Choose adapter size: {adapter_size}")
|
12 |
+
print(f"condition type: {condition_type}")
|
13 |
+
self.model = AutoModel.from_pretrained(f'autoregressive/models/dinov2-{adapter_size}')
|
14 |
+
self.condition_type = condition_type
|
15 |
+
|
16 |
+
def to_patch14(self, input):
|
17 |
+
H, W = input.shape[2:]
|
18 |
+
new_H = (H // 16) * 14
|
19 |
+
new_W = (W // 16) * 14
|
20 |
+
if self.condition_type in ['canny', 'seg']:
|
21 |
+
output = torch.nn.functional.interpolate(input, size=(new_H, new_W), mode='nearest')#, align_corners=True) canny, seg
|
22 |
+
else:
|
23 |
+
output = torch.nn.functional.interpolate(input, size=(new_H, new_W), mode='bicubic', align_corners=True) # depth, lineart, hed
|
24 |
+
return output
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.to_patch14(x)
|
28 |
+
x = self.model(x)
|
29 |
+
return x.last_hidden_state[:, 1:]
|
30 |
+
|
31 |
+
|
32 |
+
if __name__ == '__main__':
|
33 |
+
model = Dinov2_Adapter().cuda()
|
34 |
+
inputs = torch.randn(4,3,512,512).cuda()
|
35 |
+
outputs = model(inputs)
|
36 |
+
print(outputs.shape)
|
autoregressive/models/generate.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
|
3 |
+
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
import torch._dynamo.config
|
8 |
+
import torch._inductor.config
|
9 |
+
import copy
|
10 |
+
import time
|
11 |
+
# torch._inductor.config.coordinate_descent_tuning = True
|
12 |
+
# torch._inductor.config.triton.unique_kernel_names = True
|
13 |
+
# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
|
14 |
+
|
15 |
+
|
16 |
+
### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
|
17 |
+
def top_k_top_p_filtering(
|
18 |
+
logits,
|
19 |
+
top_k: int = 0,
|
20 |
+
top_p: float = 1.0,
|
21 |
+
filter_value: float = -float("Inf"),
|
22 |
+
min_tokens_to_keep: int = 1,
|
23 |
+
):
|
24 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
25 |
+
Args:
|
26 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
27 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
28 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
29 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
30 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
31 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
32 |
+
"""
|
33 |
+
if top_k > 0:
|
34 |
+
# import pdb;pdb.set_trace()
|
35 |
+
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
|
36 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
37 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
38 |
+
logits[indices_to_remove] = filter_value
|
39 |
+
|
40 |
+
if top_p < 1.0:
|
41 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
42 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
43 |
+
|
44 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
45 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
46 |
+
if min_tokens_to_keep > 1:
|
47 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
48 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
49 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
50 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
51 |
+
sorted_indices_to_remove[..., 0] = 0
|
52 |
+
|
53 |
+
# scatter sorted tensors to original indexing
|
54 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
55 |
+
logits[indices_to_remove] = filter_value
|
56 |
+
return logits
|
57 |
+
|
58 |
+
|
59 |
+
def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sample_logits=True):
|
60 |
+
logits = logits[:, -1, :] / max(temperature, 1e-5)
|
61 |
+
if top_k > 0 or top_p < 1.0:
|
62 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
63 |
+
probs = F.softmax(logits, dim=-1)
|
64 |
+
# values, indices = torch.max(probs, dim=1, keepdim=True)
|
65 |
+
# mask = (probs == values).float()
|
66 |
+
# probs = probs * (1 - mask)
|
67 |
+
# values, indices = torch.max(probs, dim=1, keepdim=True)
|
68 |
+
# mask = (probs == values).float()
|
69 |
+
# probs = probs * (1 - mask)
|
70 |
+
if sample_logits:
|
71 |
+
idx = torch.multinomial(probs, num_samples=1)
|
72 |
+
else:
|
73 |
+
_, idx = torch.topk(probs, k=1, dim=-1)
|
74 |
+
return idx, probs
|
75 |
+
|
76 |
+
|
77 |
+
def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: int = None, **kwargs):
|
78 |
+
logits = logits / max(temperature, 1e-5)
|
79 |
+
if top_k > 0 or top_p < 1.0:
|
80 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
81 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
82 |
+
return probs
|
83 |
+
|
84 |
+
|
85 |
+
def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
|
86 |
+
if cfg_scale > 1.0:
|
87 |
+
logits, _ = model(None, cond_idx, input_pos, condition=condition)
|
88 |
+
logits_combined = logits
|
89 |
+
cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
|
90 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
|
91 |
+
else:
|
92 |
+
logits, _ = model(None, cond_idx, input_pos, condition=condition)
|
93 |
+
|
94 |
+
return sample(logits, **sampling_kwargs)[0]
|
95 |
+
|
96 |
+
|
97 |
+
def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, condition: torch.Tensor, **sampling_kwargs):
|
98 |
+
assert input_pos.shape[-1] == 1
|
99 |
+
if cfg_scale > 1.0:
|
100 |
+
x_combined = torch.cat([x, x])
|
101 |
+
logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos, condition=condition)
|
102 |
+
logits_combined = logits
|
103 |
+
cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
|
104 |
+
if cfg_flag:
|
105 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
|
106 |
+
else:
|
107 |
+
logits = cond_logits
|
108 |
+
else:
|
109 |
+
logits, _ = model(x, cond_idx=None, input_pos=input_pos, condition=None)
|
110 |
+
return sample(logits, **sampling_kwargs)
|
111 |
+
|
112 |
+
|
113 |
+
def decode_n_tokens(
|
114 |
+
model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
|
115 |
+
cfg_scale: float, cfg_interval: int, condition: torch.Tensor,
|
116 |
+
**sampling_kwargs):
|
117 |
+
new_tokens, new_probs = [], []
|
118 |
+
cfg_flag = True
|
119 |
+
for i in range(num_new_tokens):
|
120 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
|
121 |
+
if cfg_interval > -1 and i > cfg_interval:
|
122 |
+
cfg_flag = False
|
123 |
+
next_token, next_prob = decode_one_token(
|
124 |
+
model, cur_token, input_pos, cfg_scale, cfg_flag, condition=condition, **sampling_kwargs
|
125 |
+
)
|
126 |
+
input_pos += 1
|
127 |
+
new_tokens.append(next_token.clone())
|
128 |
+
new_probs.append(next_prob.clone())
|
129 |
+
cur_token = next_token.view(-1, 1)
|
130 |
+
|
131 |
+
return new_tokens, new_probs
|
132 |
+
|
133 |
+
|
134 |
+
@torch.no_grad()
|
135 |
+
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
|
136 |
+
if condition is not None:
|
137 |
+
condition = model.adapter(condition)
|
138 |
+
condition = model.adapter_mlp(condition)
|
139 |
+
if model.model_type == 'c2i':
|
140 |
+
if cfg_scale > 1.0:
|
141 |
+
cond_null = torch.ones_like(cond) * model.num_classes
|
142 |
+
cond_combined = torch.cat([cond, cond_null])
|
143 |
+
if condition is not None:
|
144 |
+
condition_null = torch.zeros_like(condition)
|
145 |
+
condition_combined = torch.cat((condition, condition_null), dim=0)
|
146 |
+
else:
|
147 |
+
condition_combined = None
|
148 |
+
else:
|
149 |
+
cond_combined = cond
|
150 |
+
if condition is not None:
|
151 |
+
condition_combined = condition
|
152 |
+
else:
|
153 |
+
condition_combined = None
|
154 |
+
T = 1+condition_token_nums
|
155 |
+
elif model.model_type == 't2i':
|
156 |
+
if cfg_scale > 1.0:
|
157 |
+
cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
|
158 |
+
cond_combined = torch.cat([cond, cond_null])
|
159 |
+
|
160 |
+
if condition is not None:
|
161 |
+
condition_null = torch.zeros_like(condition)
|
162 |
+
condition_combined = torch.cat((condition, condition_null), dim=0)
|
163 |
+
else:
|
164 |
+
condition_combined = None
|
165 |
+
else:
|
166 |
+
cond_combined = cond
|
167 |
+
if condition is not None:
|
168 |
+
condition_combined = condition
|
169 |
+
else:
|
170 |
+
condition_combined = None
|
171 |
+
T = cond.shape[1]
|
172 |
+
else:
|
173 |
+
raise Exception("please check model type")
|
174 |
+
|
175 |
+
T_new = T + max_new_tokens
|
176 |
+
max_seq_length = T_new
|
177 |
+
max_batch_size = cond.shape[0]
|
178 |
+
|
179 |
+
device = cond.device
|
180 |
+
with torch.device(device):
|
181 |
+
max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
|
182 |
+
model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype)
|
183 |
+
|
184 |
+
if emb_masks is not None:
|
185 |
+
assert emb_masks.shape[0] == max_batch_size
|
186 |
+
assert emb_masks.shape[-1] == T
|
187 |
+
if cfg_scale > 1.0:
|
188 |
+
model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
|
189 |
+
else:
|
190 |
+
model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)
|
191 |
+
|
192 |
+
eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
|
193 |
+
model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
|
194 |
+
|
195 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
196 |
+
seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
|
197 |
+
input_pos = torch.arange(0, T, device=device)
|
198 |
+
next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, **sampling_kwargs)
|
199 |
+
seq[:, T:T+1] = next_token
|
200 |
+
|
201 |
+
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
202 |
+
generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs)
|
203 |
+
seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
|
204 |
+
return seq[:, T:]
|
autoregressive/models/gpt_t2i.py
ADDED
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
|
3 |
+
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
|
4 |
+
# nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
|
5 |
+
# llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py
|
6 |
+
# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
|
7 |
+
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from typing import Optional, List
|
10 |
+
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch.nn import functional as F
|
15 |
+
from utils.drop_path import DropPath
|
16 |
+
# from autoregressive.models.vit_adapter import ViT_Adapter
|
17 |
+
from autoregressive.models.dinov2_adapter import Dinov2_Adapter
|
18 |
+
|
19 |
+
|
20 |
+
def get_causal_mask(seq_length):
|
21 |
+
mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).type(torch.bool)
|
22 |
+
mask = mask.masked_fill(mask, float('-inf'))
|
23 |
+
mask = mask.masked_fill(~mask, float(0.0))
|
24 |
+
return mask
|
25 |
+
|
26 |
+
def find_multiple(n: int, k: int):
|
27 |
+
if n % k == 0:
|
28 |
+
return n
|
29 |
+
return n + k - (n % k)
|
30 |
+
|
31 |
+
@dataclass
|
32 |
+
class ModelArgs:
|
33 |
+
dim: int = 4096
|
34 |
+
n_layer: int = 32
|
35 |
+
n_head: int = 32
|
36 |
+
n_kv_head: Optional[int] = None
|
37 |
+
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
38 |
+
ffn_dim_multiplier: Optional[float] = None
|
39 |
+
rope_base: float = 10000
|
40 |
+
norm_eps: float = 1e-5
|
41 |
+
initializer_range: float = 0.02
|
42 |
+
|
43 |
+
token_dropout_p: float = 0.1
|
44 |
+
attn_dropout_p: float = 0.0
|
45 |
+
resid_dropout_p: float = 0.1
|
46 |
+
ffn_dropout_p: float = 0.1
|
47 |
+
drop_path_rate: float = 0.0
|
48 |
+
|
49 |
+
num_classes: int = 1000
|
50 |
+
caption_dim: int = 2048
|
51 |
+
class_dropout_prob: float = 0.1
|
52 |
+
model_type: str = 'c2i'
|
53 |
+
|
54 |
+
vocab_size: int = 16384
|
55 |
+
cls_token_num: int = 1
|
56 |
+
block_size: int = 256
|
57 |
+
max_batch_size: int = 32
|
58 |
+
max_seq_len: int = 2048
|
59 |
+
adapter_size: str = 'small'
|
60 |
+
condition_type: str = 'canny'
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
#################################################################################
|
65 |
+
# Embedding Layers for Class Labels #
|
66 |
+
#################################################################################
|
67 |
+
class LabelEmbedder(nn.Module):
|
68 |
+
"""
|
69 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
70 |
+
"""
|
71 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
72 |
+
super().__init__()
|
73 |
+
use_cfg_embedding = dropout_prob > 0
|
74 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
75 |
+
self.num_classes = num_classes
|
76 |
+
self.dropout_prob = dropout_prob
|
77 |
+
|
78 |
+
def token_drop(self, labels, force_drop_ids=None):
|
79 |
+
"""
|
80 |
+
Drops labels to enable classifier-free guidance.
|
81 |
+
"""
|
82 |
+
if force_drop_ids is None:
|
83 |
+
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
84 |
+
else:
|
85 |
+
drop_ids = force_drop_ids == 1
|
86 |
+
labels = torch.where(drop_ids, self.num_classes, labels)
|
87 |
+
return labels, drop_ids
|
88 |
+
|
89 |
+
def forward(self, labels, train, force_drop_ids=None):
|
90 |
+
use_dropout = self.dropout_prob > 0
|
91 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
92 |
+
labels,drop_ids = self.token_drop(labels, force_drop_ids)
|
93 |
+
embeddings = self.embedding_table(labels).unsqueeze(1)
|
94 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
95 |
+
return embeddings,drop_ids
|
96 |
+
else:
|
97 |
+
return embeddings
|
98 |
+
|
99 |
+
|
100 |
+
class ConditionEmbedder(nn.Module):
|
101 |
+
"""
|
102 |
+
Embeds Condition into vector representations. Also handles label dropout for classifier-free guidance.
|
103 |
+
"""
|
104 |
+
def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120, vocab_size=16384):
|
105 |
+
super().__init__()
|
106 |
+
self.cap_proj = MLP(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size)
|
107 |
+
self.register_buffer("uncond_embedding", torch.zeros(token_num, hidden_size) / hidden_size ** 0.5)
|
108 |
+
self.uncond_prob = uncond_prob
|
109 |
+
|
110 |
+
def token_drop(self, caption, force_drop_ids=None, drop_ids=None):
|
111 |
+
"""
|
112 |
+
Drops labels to enable classifier-free guidance.
|
113 |
+
"""
|
114 |
+
if force_drop_ids is None:
|
115 |
+
if drop_ids is None:
|
116 |
+
drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
|
117 |
+
else:
|
118 |
+
drop_ids = force_drop_ids == 1
|
119 |
+
|
120 |
+
caption = torch.where(drop_ids[:, None, None], self.uncond_embedding[:caption.shape[1]], caption)
|
121 |
+
return caption
|
122 |
+
|
123 |
+
def forward(self, caption, train, force_drop_ids=None, drop_ids=None):
|
124 |
+
use_dropout = self.uncond_prob > 0
|
125 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
126 |
+
caption = self.token_drop(caption, force_drop_ids, drop_ids)
|
127 |
+
embeddings = self.cap_proj(caption)
|
128 |
+
return embeddings
|
129 |
+
|
130 |
+
#################################################################################
|
131 |
+
# Embedding Layers for Text Feature #
|
132 |
+
#################################################################################
|
133 |
+
class CaptionEmbedder(nn.Module):
|
134 |
+
"""
|
135 |
+
Embeds text caption into vector representations. Also handles label dropout for classifier-free guidance.
|
136 |
+
"""
|
137 |
+
def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
|
138 |
+
super().__init__()
|
139 |
+
self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size)
|
140 |
+
self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
|
141 |
+
self.uncond_prob = uncond_prob
|
142 |
+
|
143 |
+
def token_drop(self, caption, force_drop_ids=None):
|
144 |
+
"""
|
145 |
+
Drops labels to enable classifier-free guidance.
|
146 |
+
"""
|
147 |
+
if force_drop_ids is None:
|
148 |
+
drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
|
149 |
+
else:
|
150 |
+
drop_ids = force_drop_ids == 1
|
151 |
+
caption = torch.where(drop_ids[:, None, None], self.uncond_embedding, caption)
|
152 |
+
return caption, drop_ids
|
153 |
+
|
154 |
+
def forward(self, caption, train, force_drop_ids=None):
|
155 |
+
use_dropout = self.uncond_prob > 0
|
156 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
157 |
+
caption, drop_ids = self.token_drop(caption, force_drop_ids)
|
158 |
+
embeddings = self.cap_proj(caption)
|
159 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
160 |
+
return embeddings,drop_ids
|
161 |
+
else:
|
162 |
+
return embeddings
|
163 |
+
|
164 |
+
|
165 |
+
class MLP(nn.Module):
|
166 |
+
def __init__(self, in_features, hidden_features, out_features):
|
167 |
+
super().__init__()
|
168 |
+
out_features = out_features or in_features
|
169 |
+
hidden_features = hidden_features or in_features
|
170 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
|
171 |
+
self.act = nn.GELU(approximate='tanh')
|
172 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
|
173 |
+
|
174 |
+
nn.init.zeros_(self.fc1.weight)
|
175 |
+
nn.init.zeros_(self.fc2.weight)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
x = self.fc1(x)
|
179 |
+
x = self.act(x)
|
180 |
+
x = self.fc2(x)
|
181 |
+
return x
|
182 |
+
|
183 |
+
|
184 |
+
#################################################################################
|
185 |
+
# GPT Model #
|
186 |
+
#################################################################################
|
187 |
+
class RMSNorm(torch.nn.Module):
|
188 |
+
def __init__(self, dim: int, eps: float = 1e-5):
|
189 |
+
super().__init__()
|
190 |
+
self.eps = eps
|
191 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
192 |
+
|
193 |
+
def _norm(self, x):
|
194 |
+
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
|
195 |
+
|
196 |
+
def forward(self, x):
|
197 |
+
output = self._norm(x.float()).type_as(x)
|
198 |
+
return output * self.weight
|
199 |
+
|
200 |
+
|
201 |
+
class FeedForward(nn.Module):
|
202 |
+
def __init__(self, config: ModelArgs):
|
203 |
+
super().__init__()
|
204 |
+
hidden_dim = 4 * config.dim
|
205 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
206 |
+
# custom dim factor multiplier
|
207 |
+
if config.ffn_dim_multiplier is not None:
|
208 |
+
hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
|
209 |
+
hidden_dim = find_multiple(hidden_dim, config.multiple_of)
|
210 |
+
|
211 |
+
self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
|
212 |
+
self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
|
213 |
+
self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
|
214 |
+
self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
|
215 |
+
|
216 |
+
def forward(self, x):
|
217 |
+
return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
218 |
+
|
219 |
+
|
220 |
+
class KVCache(nn.Module):
|
221 |
+
def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
|
222 |
+
super().__init__()
|
223 |
+
cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
|
224 |
+
self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
|
225 |
+
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
|
226 |
+
|
227 |
+
def update(self, input_pos, k_val, v_val):
|
228 |
+
# input_pos: [S], k_val: [B, H, S, D]
|
229 |
+
assert input_pos.shape[0] == k_val.shape[2]
|
230 |
+
k_out = self.k_cache
|
231 |
+
v_out = self.v_cache
|
232 |
+
k_out[:, :, input_pos] = k_val
|
233 |
+
v_out[:, :, input_pos] = v_val
|
234 |
+
|
235 |
+
return k_out, v_out
|
236 |
+
|
237 |
+
|
238 |
+
class Attention(nn.Module):
|
239 |
+
def __init__(self, config: ModelArgs):
|
240 |
+
super().__init__()
|
241 |
+
assert config.dim % config.n_head == 0
|
242 |
+
self.dim = config.dim
|
243 |
+
self.head_dim = config.dim // config.n_head
|
244 |
+
self.n_head = config.n_head
|
245 |
+
self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
|
246 |
+
total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim
|
247 |
+
|
248 |
+
# key, query, value projections for all heads, but in a batch
|
249 |
+
self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
|
250 |
+
self.wo = nn.Linear(config.dim, config.dim, bias=False)
|
251 |
+
self.kv_cache = None
|
252 |
+
|
253 |
+
# regularization
|
254 |
+
self.attn_dropout_p = config.attn_dropout_p
|
255 |
+
self.resid_dropout = nn.Dropout(config.resid_dropout_p)
|
256 |
+
|
257 |
+
def forward(
|
258 |
+
self, x: torch.Tensor, freqs_cis: torch.Tensor = None,
|
259 |
+
input_pos: Optional[torch.Tensor] = None,
|
260 |
+
mask: Optional[torch.Tensor] = None
|
261 |
+
):
|
262 |
+
bsz, seqlen, _ = x.shape
|
263 |
+
kv_size = self.n_kv_head * self.head_dim
|
264 |
+
xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
|
265 |
+
|
266 |
+
xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
|
267 |
+
xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
|
268 |
+
xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)
|
269 |
+
|
270 |
+
xq = apply_rotary_emb(xq, freqs_cis)
|
271 |
+
xk = apply_rotary_emb(xk, freqs_cis)
|
272 |
+
|
273 |
+
xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
|
274 |
+
|
275 |
+
if self.kv_cache is not None:
|
276 |
+
keys, values = self.kv_cache.update(input_pos, xk, xv)
|
277 |
+
else:
|
278 |
+
keys, values = xk, xv
|
279 |
+
keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
|
280 |
+
values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
|
281 |
+
|
282 |
+
output = F.scaled_dot_product_attention(
|
283 |
+
xq, keys, values,
|
284 |
+
attn_mask=mask,
|
285 |
+
is_causal=True if mask is None else False, # is_causal=False is for KV cache
|
286 |
+
dropout_p=self.attn_dropout_p if self.training else 0)
|
287 |
+
|
288 |
+
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
|
289 |
+
|
290 |
+
output = self.resid_dropout(self.wo(output))
|
291 |
+
return output
|
292 |
+
|
293 |
+
|
294 |
+
class TransformerBlock(nn.Module):
|
295 |
+
def __init__(self, config: ModelArgs, drop_path: float):
|
296 |
+
super().__init__()
|
297 |
+
self.attention = Attention(config)
|
298 |
+
self.feed_forward = FeedForward(config)
|
299 |
+
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
300 |
+
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
301 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
302 |
+
|
303 |
+
def forward(
|
304 |
+
self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
|
305 |
+
h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis, start_pos, mask))
|
306 |
+
out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
|
307 |
+
return out
|
308 |
+
|
309 |
+
|
310 |
+
class Transformer(nn.Module):
|
311 |
+
def __init__(self, config: ModelArgs):
|
312 |
+
super().__init__()
|
313 |
+
self.config = config
|
314 |
+
self.vocab_size = config.vocab_size
|
315 |
+
self.n_layer = config.n_layer
|
316 |
+
self.block_size = config.block_size
|
317 |
+
self.num_classes = config.num_classes
|
318 |
+
self.model_type = config.model_type
|
319 |
+
self.cls_token_num = config.cls_token_num
|
320 |
+
self.layer_internal = config.n_layer // 3
|
321 |
+
# self.adapter = Adapter(output_dim=768)
|
322 |
+
# self.adapter = ViT_Adapter()
|
323 |
+
# self.adapter = DeiT_Adapter()
|
324 |
+
self.adapter = Dinov2_Adapter(adapter_size=config.adapter_size, condition_type=config.condition_type)
|
325 |
+
# self.adapter = EVA_Adapter()
|
326 |
+
if config.adapter_size == "small":
|
327 |
+
self.adapter_mlp = MLP(384, config.dim, config.dim)
|
328 |
+
elif config.adapter_size == 'base':
|
329 |
+
self.adapter_mlp = MLP(768, config.dim, config.dim)
|
330 |
+
|
331 |
+
if self.model_type == 'c2i':
|
332 |
+
self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
|
333 |
+
elif self.model_type == 't2i':
|
334 |
+
self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob)
|
335 |
+
else:
|
336 |
+
raise Exception("please check model type")
|
337 |
+
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
338 |
+
self.tok_dropout = nn.Dropout(config.token_dropout_p)
|
339 |
+
|
340 |
+
self.condition_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
341 |
+
self.condition_mlp = ConditionEmbedder(self.block_size, config.dim, config.class_dropout_prob, self.block_size, config.vocab_size)
|
342 |
+
self.condition_layers = torch.nn.ModuleList()
|
343 |
+
for layer_id in range(3):
|
344 |
+
self.condition_layers.append(MLP(config.dim,config.dim,config.dim))
|
345 |
+
|
346 |
+
# transformer blocks
|
347 |
+
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
|
348 |
+
self.layers = torch.nn.ModuleList()
|
349 |
+
for layer_id in range(config.n_layer):
|
350 |
+
self.layers.append(TransformerBlock(config, dpr[layer_id]))
|
351 |
+
|
352 |
+
# output layer
|
353 |
+
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
354 |
+
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
|
355 |
+
|
356 |
+
# 2d rotary pos embedding
|
357 |
+
grid_size = int(self.block_size ** 0.5)
|
358 |
+
assert grid_size * grid_size == self.block_size
|
359 |
+
self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)
|
360 |
+
|
361 |
+
# KVCache
|
362 |
+
self.max_batch_size = -1
|
363 |
+
self.max_seq_length = -1
|
364 |
+
|
365 |
+
self.initialize_weights()
|
366 |
+
self.condition_token = None
|
367 |
+
self.mask = get_causal_mask(256)
|
368 |
+
self.global_token = None
|
369 |
+
|
370 |
+
|
371 |
+
def initialize_weights(self):
|
372 |
+
# Initialize nn.Linear and nn.Embedding
|
373 |
+
self.apply(self._init_weights)
|
374 |
+
|
375 |
+
# Zero-out output layers:
|
376 |
+
nn.init.constant_(self.output.weight, 0)
|
377 |
+
|
378 |
+
|
379 |
+
|
380 |
+
def _init_weights(self, module):
|
381 |
+
std = self.config.initializer_range
|
382 |
+
if isinstance(module, nn.Linear):
|
383 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
384 |
+
if module.bias is not None:
|
385 |
+
module.bias.data.zero_()
|
386 |
+
elif isinstance(module, nn.Embedding):
|
387 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
388 |
+
|
389 |
+
|
390 |
+
def setup_caches(self, max_batch_size, max_seq_length, dtype):
|
391 |
+
# if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
|
392 |
+
# return
|
393 |
+
head_dim = self.config.dim // self.config.n_head
|
394 |
+
max_seq_length = find_multiple(max_seq_length, 8) #
|
395 |
+
self.max_seq_length = max_seq_length
|
396 |
+
self.max_batch_size = max_batch_size
|
397 |
+
for b in self.layers:
|
398 |
+
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)
|
399 |
+
|
400 |
+
causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
|
401 |
+
self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
|
402 |
+
grid_size = int(self.config.block_size ** 0.5)
|
403 |
+
assert grid_size * grid_size == self.block_size
|
404 |
+
self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)
|
405 |
+
|
406 |
+
|
407 |
+
|
408 |
+
def forward(
|
409 |
+
self,
|
410 |
+
idx: torch.Tensor,
|
411 |
+
cond_idx: torch.Tensor, # cond_idx_or_embed
|
412 |
+
input_pos: Optional[torch.Tensor] = None,
|
413 |
+
targets: Optional[torch.Tensor] = None,
|
414 |
+
mask: Optional[torch.Tensor] = None,
|
415 |
+
valid: Optional[torch.Tensor] = None,
|
416 |
+
condition: Optional[torch.Tensor] = None
|
417 |
+
):
|
418 |
+
if idx is not None and cond_idx is not None: # training or naive inference
|
419 |
+
cond_embeddings,drop_ids = self.cls_embedding(cond_idx, train=self.training)
|
420 |
+
cond_embeddings = cond_embeddings[:,:self.cls_token_num]
|
421 |
+
token_embeddings = self.tok_embeddings(idx)
|
422 |
+
if condition is not None:
|
423 |
+
condition_embeddings = self.adapter(condition)
|
424 |
+
condition_embeddings = self.adapter_mlp(condition_embeddings)
|
425 |
+
self.condition_token = self.condition_mlp(condition_embeddings,train=self.training, drop_ids=drop_ids)
|
426 |
+
token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
|
427 |
+
|
428 |
+
h = self.tok_dropout(token_embeddings)
|
429 |
+
self.freqs_cis = self.freqs_cis.to(h.device)
|
430 |
+
else:
|
431 |
+
if cond_idx is not None: # prefill in inference
|
432 |
+
token_embeddings = self.cls_embedding(cond_idx, train=self.training)
|
433 |
+
token_embeddings = token_embeddings[:,:self.cls_token_num]
|
434 |
+
if condition is not None:
|
435 |
+
condition_embeddings = self.condition_mlp(condition.to(torch.bfloat16),train=self.training)
|
436 |
+
self.condition_token = condition_embeddings
|
437 |
+
|
438 |
+
else: # decode_n_tokens(kv cache) in inference
|
439 |
+
token_embeddings = self.tok_embeddings(idx)
|
440 |
+
bs = token_embeddings.shape[0]
|
441 |
+
mask = self.causal_mask[:bs, None, input_pos]
|
442 |
+
h = self.tok_dropout(token_embeddings)
|
443 |
+
self.freqs_cis = self.freqs_cis
|
444 |
+
|
445 |
+
if self.training:
|
446 |
+
freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]
|
447 |
+
else:
|
448 |
+
freqs_cis = self.freqs_cis[input_pos]
|
449 |
+
# transformer blocks
|
450 |
+
for i, layer in enumerate(self.layers):
|
451 |
+
if i%self.layer_internal == 0:
|
452 |
+
if self.training:
|
453 |
+
h[:, self.cls_token_num-1:] = h[:, self.cls_token_num-1:] + self.condition_layers[i//self.layer_internal](self.condition_token)
|
454 |
+
else:
|
455 |
+
if len(input_pos)>1:
|
456 |
+
h[:, -1:] = h[:, -1:] + self.condition_layers[i//self.layer_internal](self.condition_token[:,0:1])
|
457 |
+
else:
|
458 |
+
h = h + self.condition_layers[i//self.layer_internal](self.condition_token[:,input_pos-self.cls_token_num+1])
|
459 |
+
h = layer(h, freqs_cis, input_pos, mask)
|
460 |
+
# output layers
|
461 |
+
h = self.norm(h)
|
462 |
+
logits = self.output(h).float()
|
463 |
+
|
464 |
+
if self.training:
|
465 |
+
logits = logits[:, self.cls_token_num - 1:].contiguous()
|
466 |
+
# if we are given some desired targets also calculate the loss
|
467 |
+
loss = None
|
468 |
+
if valid is not None:
|
469 |
+
loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
|
470 |
+
valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1)
|
471 |
+
loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
|
472 |
+
elif targets is not None:
|
473 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
474 |
+
|
475 |
+
|
476 |
+
return logits, loss
|
477 |
+
|
478 |
+
|
479 |
+
def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
|
480 |
+
return list(self.layers)
|
481 |
+
|
482 |
+
|
483 |
+
|
484 |
+
#################################################################################
|
485 |
+
# Rotary Positional Embedding Functions #
|
486 |
+
#################################################################################
|
487 |
+
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
|
488 |
+
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
|
489 |
+
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
|
490 |
+
t = torch.arange(seq_len, device=freqs.device)
|
491 |
+
freqs = torch.outer(t, freqs) # (seq_len, head_dim // 2)
|
492 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
493 |
+
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) # (cls_token_num+seq_len, head_dim // 2, 2)
|
494 |
+
cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+seq_len, head_dim // 2, 2)
|
495 |
+
return cond_cache
|
496 |
+
|
497 |
+
|
498 |
+
def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
|
499 |
+
# split the dimension into half, one for x and one for y
|
500 |
+
half_dim = n_elem // 2
|
501 |
+
freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
|
502 |
+
t = torch.arange(grid_size, device=freqs.device)
|
503 |
+
freqs = torch.outer(t, freqs) # (grid_size, head_dim // 2)
|
504 |
+
freqs_grid = torch.concat([
|
505 |
+
freqs[:, None, :].expand(-1, grid_size, -1),
|
506 |
+
freqs[None, :, :].expand(grid_size, -1, -1),
|
507 |
+
], dim=-1) # (grid_size, grid_size, head_dim // 2)
|
508 |
+
cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1) # (grid_size, grid_size, head_dim // 2, 2)
|
509 |
+
cache = cache_grid.flatten(0, 1)
|
510 |
+
cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+grid_size**2, head_dim // 2, 2)
|
511 |
+
return cond_cache
|
512 |
+
|
513 |
+
|
514 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
|
515 |
+
# x: (bs, seq_len, n_head, head_dim)
|
516 |
+
# freqs_cis (seq_len, head_dim // 2, 2)
|
517 |
+
xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2)
|
518 |
+
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2)
|
519 |
+
x_out2 = torch.stack([
|
520 |
+
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
|
521 |
+
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
|
522 |
+
], dim=-1)
|
523 |
+
x_out2 = x_out2.flatten(3)
|
524 |
+
return x_out2.type_as(x)
|
525 |
+
|
526 |
+
|
527 |
+
|
528 |
+
#################################################################################
|
529 |
+
# GPT Configs #
|
530 |
+
#################################################################################
|
531 |
+
### text-conditional
|
532 |
+
def GPT_7B(**kwargs):
|
533 |
+
return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B
|
534 |
+
|
535 |
+
def GPT_3B(**kwargs):
|
536 |
+
return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B
|
537 |
+
|
538 |
+
def GPT_1B(**kwargs):
|
539 |
+
return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B
|
540 |
+
|
541 |
+
### class-conditional
|
542 |
+
def GPT_XXXL(**kwargs):
|
543 |
+
return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B
|
544 |
+
|
545 |
+
def GPT_XXL(**kwargs):
|
546 |
+
return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B
|
547 |
+
|
548 |
+
def GPT_XL(**kwargs):
|
549 |
+
return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M
|
550 |
+
|
551 |
+
def GPT_L(**kwargs):
|
552 |
+
return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M
|
553 |
+
|
554 |
+
def GPT_B(**kwargs):
|
555 |
+
return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M
|
556 |
+
|
557 |
+
|
558 |
+
GPT_models = {
|
559 |
+
'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
|
560 |
+
'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
|
561 |
+
}
|
autoregressive/sample/sample_c2i.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# DiT: https://github.com/facebookresearch/DiT/blob/main/sample.py
|
3 |
+
import torch
|
4 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
5 |
+
torch.backends.cudnn.allow_tf32 = True
|
6 |
+
torch.set_float32_matmul_precision('high')
|
7 |
+
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
|
8 |
+
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
|
9 |
+
from torchvision.utils import save_image
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
current_directory = os.getcwd()
|
13 |
+
sys.path.append(current_directory)
|
14 |
+
|
15 |
+
from PIL import Image
|
16 |
+
import time
|
17 |
+
import argparse
|
18 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
19 |
+
from autoregressive.models.gpt import GPT_models
|
20 |
+
from autoregressive.models.generate import generate
|
21 |
+
from functools import partial
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import numpy as np
|
24 |
+
import cv2
|
25 |
+
|
26 |
+
|
27 |
+
def main(args):
|
28 |
+
# Setup PyTorch:
|
29 |
+
torch.manual_seed(args.seed)
|
30 |
+
torch.backends.cudnn.deterministic = True
|
31 |
+
torch.backends.cudnn.benchmark = False
|
32 |
+
torch.set_grad_enabled(False)
|
33 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
34 |
+
|
35 |
+
# create and load model
|
36 |
+
vq_model = VQ_models[args.vq_model](
|
37 |
+
codebook_size=args.codebook_size,
|
38 |
+
codebook_embed_dim=args.codebook_embed_dim)
|
39 |
+
vq_model.to(device)
|
40 |
+
vq_model.eval()
|
41 |
+
checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
|
42 |
+
vq_model.load_state_dict(checkpoint["model"])
|
43 |
+
del checkpoint
|
44 |
+
print(f"image tokenizer is loaded")
|
45 |
+
|
46 |
+
# create and load gpt model
|
47 |
+
precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
|
48 |
+
latent_size = args.image_size // args.downsample_size
|
49 |
+
gpt_model = GPT_models[args.gpt_model](
|
50 |
+
vocab_size=args.codebook_size,
|
51 |
+
block_size=latent_size ** 2,
|
52 |
+
num_classes=args.num_classes,
|
53 |
+
cls_token_num=args.cls_token_num,
|
54 |
+
model_type=args.gpt_type,
|
55 |
+
condition_token_num=args.condition_token_nums,
|
56 |
+
image_size=args.image_size
|
57 |
+
).to(device=device, dtype=precision)
|
58 |
+
|
59 |
+
_, file_extension = os.path.splitext(args.gpt_ckpt)
|
60 |
+
if file_extension.lower() == '.safetensors':
|
61 |
+
from safetensors.torch import load_file
|
62 |
+
model_weight = load_file(args.gpt_ckpt)
|
63 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
64 |
+
gpt_model.eval()
|
65 |
+
else:
|
66 |
+
checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
|
67 |
+
if "model" in checkpoint: # ddp
|
68 |
+
model_weight = checkpoint["model"]
|
69 |
+
elif "module" in checkpoint: # deepspeed
|
70 |
+
model_weight = checkpoint["module"]
|
71 |
+
elif "state_dict" in checkpoint:
|
72 |
+
model_weight = checkpoint["state_dict"]
|
73 |
+
else:
|
74 |
+
raise Exception("please check model weight")
|
75 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
76 |
+
gpt_model.eval()
|
77 |
+
del checkpoint
|
78 |
+
print(f"gpt model is loaded")
|
79 |
+
|
80 |
+
if args.compile:
|
81 |
+
print(f"compiling the model...")
|
82 |
+
gpt_model = torch.compile(
|
83 |
+
gpt_model,
|
84 |
+
mode="reduce-overhead",
|
85 |
+
fullgraph=True
|
86 |
+
) # requires PyTorch 2.0 (optional)
|
87 |
+
else:
|
88 |
+
print(f"no need to compile model in demo")
|
89 |
+
|
90 |
+
condition_null = None
|
91 |
+
if args.condition_type == 'canny':
|
92 |
+
sample_list = [650, 2312, 15000, 48850] # canny
|
93 |
+
elif args.condition_type == 'depth':
|
94 |
+
sample_list = [101, 4351, 10601, 48901]
|
95 |
+
|
96 |
+
class_labels = [np.load(f"condition/example/c2i/{args.condition_type}/{i}.npy")[0] for i in sample_list]
|
97 |
+
condition_imgs = [np.array(Image.open((f"condition/example/c2i/{args.condition_type}/{i}.png")))[None,None,...] for i in sample_list]
|
98 |
+
condition_imgs = torch.from_numpy(np.concatenate(condition_imgs, axis=0)).to(device).to(torch.float32)/255
|
99 |
+
condition_imgs = 2*(condition_imgs-0.5)
|
100 |
+
print(condition_imgs.shape)
|
101 |
+
c_indices = torch.tensor(class_labels, device=device)
|
102 |
+
qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
|
103 |
+
t1 = time.time()
|
104 |
+
|
105 |
+
index_sample = generate(
|
106 |
+
gpt_model, c_indices, latent_size ** 2, condition=condition_imgs.repeat(1,3,1,1).to(precision), condition_null=condition_null, condition_token_nums=args.condition_token_nums,
|
107 |
+
cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
|
108 |
+
temperature=args.temperature, top_k=args.top_k,
|
109 |
+
top_p=args.top_p, sample_logits=True,
|
110 |
+
)
|
111 |
+
|
112 |
+
sampling_time = time.time() - t1
|
113 |
+
print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
|
114 |
+
|
115 |
+
t2 = time.time()
|
116 |
+
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
117 |
+
decoder_time = time.time() - t2
|
118 |
+
print(f"decoder takes about {decoder_time:.2f} seconds.")
|
119 |
+
# Save and display images:
|
120 |
+
condition_imgs = condition_imgs.repeat(1,3,1,1)
|
121 |
+
samples = torch.cat((condition_imgs[:4], samples[:4]),dim=0)
|
122 |
+
save_image(samples, f"sample/example/sample_{args.gpt_type}_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1))
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
if __name__ == "__main__":
|
127 |
+
parser = argparse.ArgumentParser()
|
128 |
+
parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
|
129 |
+
parser.add_argument("--gpt-ckpt", type=str, default=None)
|
130 |
+
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
|
131 |
+
parser.add_argument("--from-fsdp", action='store_true')
|
132 |
+
parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
|
133 |
+
parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
134 |
+
parser.add_argument("--compile", action='store_true', default=False)
|
135 |
+
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
|
136 |
+
parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
|
137 |
+
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
|
138 |
+
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
|
139 |
+
parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256)
|
140 |
+
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
|
141 |
+
parser.add_argument("--num-classes", type=int, default=1000)
|
142 |
+
parser.add_argument("--cfg-scale", type=float, default=4.0)
|
143 |
+
parser.add_argument("--cfg-interval", type=float, default=-1)
|
144 |
+
parser.add_argument("--seed", type=int, default=0)
|
145 |
+
parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
|
146 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
|
147 |
+
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
|
148 |
+
parser.add_argument("--condition-token-nums", type=int, default=0)
|
149 |
+
parser.add_argument("--condition-type", type=str, default='canny', choices=['canny', 'depth'])
|
150 |
+
args = parser.parse_args()
|
151 |
+
main(args)
|
autoregressive/sample/sample_c2i_ddp.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py
|
3 |
+
import torch
|
4 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
5 |
+
torch.backends.cudnn.allow_tf32 = True
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.distributed as dist
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
import os
|
11 |
+
from PIL import Image
|
12 |
+
import numpy as np
|
13 |
+
import math
|
14 |
+
import argparse
|
15 |
+
|
16 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
17 |
+
from autoregressive.models.gpt import GPT_models
|
18 |
+
from autoregressive.models.generate import generate
|
19 |
+
|
20 |
+
|
21 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
|
22 |
+
"""
|
23 |
+
Builds a single .npz file from a folder of .png samples.
|
24 |
+
"""
|
25 |
+
samples = []
|
26 |
+
for i in tqdm(range(num), desc="Building .npz file from samples"):
|
27 |
+
sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
|
28 |
+
sample_np = np.asarray(sample_pil).astype(np.uint8)
|
29 |
+
samples.append(sample_np)
|
30 |
+
samples = np.stack(samples)
|
31 |
+
assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
|
32 |
+
npz_path = f"{sample_dir}.npz"
|
33 |
+
np.savez(npz_path, arr_0=samples)
|
34 |
+
print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
|
35 |
+
return npz_path
|
36 |
+
|
37 |
+
|
38 |
+
def main(args):
|
39 |
+
# Setup PyTorch:
|
40 |
+
assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
|
41 |
+
torch.set_grad_enabled(False)
|
42 |
+
|
43 |
+
# Setup DDP:
|
44 |
+
dist.init_process_group("nccl")
|
45 |
+
rank = dist.get_rank()
|
46 |
+
device = rank % torch.cuda.device_count()
|
47 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
48 |
+
torch.manual_seed(seed)
|
49 |
+
torch.cuda.set_device(device)
|
50 |
+
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
51 |
+
|
52 |
+
# create and load model
|
53 |
+
vq_model = VQ_models[args.vq_model](
|
54 |
+
codebook_size=args.codebook_size,
|
55 |
+
codebook_embed_dim=args.codebook_embed_dim)
|
56 |
+
vq_model.to(device)
|
57 |
+
vq_model.eval()
|
58 |
+
checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
|
59 |
+
vq_model.load_state_dict(checkpoint["model"])
|
60 |
+
del checkpoint
|
61 |
+
|
62 |
+
# create and load gpt model
|
63 |
+
precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
|
64 |
+
latent_size = args.image_size // args.downsample_size
|
65 |
+
gpt_model = GPT_models[args.gpt_model](
|
66 |
+
vocab_size=args.codebook_size,
|
67 |
+
block_size=latent_size ** 2,
|
68 |
+
num_classes=args.num_classes,
|
69 |
+
cls_token_num=args.cls_token_num,
|
70 |
+
model_type=args.gpt_type,
|
71 |
+
).to(device=device, dtype=precision)
|
72 |
+
checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
|
73 |
+
if args.from_fsdp: # fsdp
|
74 |
+
model_weight = checkpoint
|
75 |
+
elif "model" in checkpoint: # ddp
|
76 |
+
model_weight = checkpoint["model"]
|
77 |
+
elif "module" in checkpoint: # deepspeed
|
78 |
+
model_weight = checkpoint["module"]
|
79 |
+
elif "state_dict" in checkpoint:
|
80 |
+
model_weight = checkpoint["state_dict"]
|
81 |
+
else:
|
82 |
+
raise Exception("please check model weight, maybe add --from-fsdp to run command")
|
83 |
+
# if 'freqs_cis' in model_weight:
|
84 |
+
# model_weight.pop('freqs_cis')
|
85 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
86 |
+
gpt_model.eval()
|
87 |
+
del checkpoint
|
88 |
+
|
89 |
+
if args.compile:
|
90 |
+
print(f"compiling the model...")
|
91 |
+
gpt_model = torch.compile(
|
92 |
+
gpt_model,
|
93 |
+
mode="reduce-overhead",
|
94 |
+
fullgraph=True
|
95 |
+
) # requires PyTorch 2.0 (optional)
|
96 |
+
else:
|
97 |
+
print(f"no model compile")
|
98 |
+
|
99 |
+
# Create folder to save samples:
|
100 |
+
model_string_name = args.gpt_model.replace("/", "-")
|
101 |
+
if args.from_fsdp:
|
102 |
+
ckpt_string_name = args.gpt_ckpt.split('/')[-2]
|
103 |
+
else:
|
104 |
+
ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
|
105 |
+
folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-size-{args.image_size_eval}-{args.vq_model}-" \
|
106 |
+
f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
|
107 |
+
f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
|
108 |
+
sample_folder_dir = f"{args.sample_dir}/{folder_name}"
|
109 |
+
if rank == 0:
|
110 |
+
os.makedirs(sample_folder_dir, exist_ok=True)
|
111 |
+
print(f"Saving .png samples at {sample_folder_dir}")
|
112 |
+
dist.barrier()
|
113 |
+
|
114 |
+
# Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
|
115 |
+
n = args.per_proc_batch_size
|
116 |
+
global_batch_size = n * dist.get_world_size()
|
117 |
+
# To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
|
118 |
+
total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
|
119 |
+
if rank == 0:
|
120 |
+
print(f"Total number of images that will be sampled: {total_samples}")
|
121 |
+
assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
|
122 |
+
samples_needed_this_gpu = int(total_samples // dist.get_world_size())
|
123 |
+
assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
|
124 |
+
iterations = int(samples_needed_this_gpu // n)
|
125 |
+
pbar = range(iterations)
|
126 |
+
pbar = tqdm(pbar) if rank == 0 else pbar
|
127 |
+
total = 0
|
128 |
+
for _ in pbar:
|
129 |
+
# Sample inputs:
|
130 |
+
c_indices = torch.randint(0, args.num_classes, (n,), device=device)
|
131 |
+
qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size]
|
132 |
+
|
133 |
+
index_sample = generate(
|
134 |
+
gpt_model, c_indices, latent_size ** 2,
|
135 |
+
cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
|
136 |
+
temperature=args.temperature, top_k=args.top_k,
|
137 |
+
top_p=args.top_p, sample_logits=True,
|
138 |
+
)
|
139 |
+
|
140 |
+
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
141 |
+
if args.image_size_eval != args.image_size:
|
142 |
+
samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
|
143 |
+
samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
144 |
+
|
145 |
+
# Save samples to disk as individual .png files
|
146 |
+
for i, sample in enumerate(samples):
|
147 |
+
index = i * dist.get_world_size() + rank + total
|
148 |
+
Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
|
149 |
+
total += global_batch_size
|
150 |
+
|
151 |
+
# Make sure all processes have finished saving their samples before attempting to convert to .npz
|
152 |
+
dist.barrier()
|
153 |
+
if rank == 0:
|
154 |
+
create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
|
155 |
+
print("Done.")
|
156 |
+
dist.barrier()
|
157 |
+
dist.destroy_process_group()
|
158 |
+
|
159 |
+
|
160 |
+
|
161 |
+
if __name__ == "__main__":
|
162 |
+
parser = argparse.ArgumentParser()
|
163 |
+
parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
|
164 |
+
parser.add_argument("--gpt-ckpt", type=str, default=None)
|
165 |
+
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
|
166 |
+
parser.add_argument("--from-fsdp", action='store_true')
|
167 |
+
parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
|
168 |
+
parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
169 |
+
parser.add_argument("--compile", action='store_true', default=True)
|
170 |
+
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
|
171 |
+
parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
|
172 |
+
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
|
173 |
+
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
|
174 |
+
parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
|
175 |
+
parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
|
176 |
+
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
|
177 |
+
parser.add_argument("--num-classes", type=int, default=1000)
|
178 |
+
parser.add_argument("--cfg-scale", type=float, default=1.5)
|
179 |
+
parser.add_argument("--cfg-interval", type=float, default=-1)
|
180 |
+
parser.add_argument("--sample-dir", type=str, default="samples")
|
181 |
+
parser.add_argument("--per-proc-batch-size", type=int, default=32)
|
182 |
+
parser.add_argument("--num-fid-samples", type=int, default=5000)
|
183 |
+
parser.add_argument("--global-seed", type=int, default=0)
|
184 |
+
parser.add_argument("--top-k", type=int, default=0,help="top-k value to sample with")
|
185 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
|
186 |
+
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
|
187 |
+
args = parser.parse_args()
|
188 |
+
main(args)
|
autoregressive/sample/sample_t2i.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
3 |
+
torch.backends.cudnn.allow_tf32 = True
|
4 |
+
torch.set_float32_matmul_precision('high')
|
5 |
+
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
|
6 |
+
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
|
7 |
+
from torchvision.utils import save_image
|
8 |
+
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
current_directory = os.getcwd()
|
12 |
+
sys.path.append(current_directory)
|
13 |
+
import time
|
14 |
+
import argparse
|
15 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
16 |
+
from language.t5 import T5Embedder
|
17 |
+
from autoregressive.models.gpt import GPT_models
|
18 |
+
from autoregressive.models.gpt_t2i import GPT_models
|
19 |
+
from autoregressive.models.generate import generate
|
20 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
21 |
+
from dataset.t2i_control import build_t2i_control_code
|
22 |
+
from accelerate import Accelerator
|
23 |
+
from dataset.build import build_dataset
|
24 |
+
from pathlib import Path
|
25 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
26 |
+
import torch.nn.functional as F
|
27 |
+
from condition.canny import CannyDetector
|
28 |
+
from condition.hed import HEDdetector
|
29 |
+
import numpy as np
|
30 |
+
from PIL import Image
|
31 |
+
from condition.lineart import LineArt
|
32 |
+
import cv2
|
33 |
+
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
34 |
+
def main(args):
|
35 |
+
# Setup PyTorch:
|
36 |
+
torch.manual_seed(args.seed)
|
37 |
+
torch.backends.cudnn.deterministic = True
|
38 |
+
torch.backends.cudnn.benchmark = False
|
39 |
+
torch.set_grad_enabled(False)
|
40 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
41 |
+
|
42 |
+
# create and load model
|
43 |
+
vq_model = VQ_models[args.vq_model](
|
44 |
+
codebook_size=args.codebook_size,
|
45 |
+
codebook_embed_dim=args.codebook_embed_dim)
|
46 |
+
vq_model.to(device)
|
47 |
+
vq_model.eval()
|
48 |
+
checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
|
49 |
+
vq_model.load_state_dict(checkpoint["model"])
|
50 |
+
del checkpoint
|
51 |
+
print(f"image tokenizer is loaded")
|
52 |
+
|
53 |
+
# create and load gpt model
|
54 |
+
precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
|
55 |
+
latent_size = args.image_size // args.downsample_size
|
56 |
+
gpt_model = GPT_models[args.gpt_model](
|
57 |
+
block_size=latent_size ** 2,
|
58 |
+
cls_token_num=args.cls_token_num,
|
59 |
+
model_type=args.gpt_type,
|
60 |
+
condition_type=args.condition_type,
|
61 |
+
).to(device=device, dtype=precision)
|
62 |
+
|
63 |
+
_, file_extension = os.path.splitext(args.gpt_ckpt)
|
64 |
+
if file_extension.lower() == '.safetensors':
|
65 |
+
from safetensors.torch import load_file
|
66 |
+
model_weight = load_file(args.gpt_ckpt)
|
67 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
68 |
+
gpt_model.eval()
|
69 |
+
else:
|
70 |
+
checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
|
71 |
+
if "model" in checkpoint: # ddp
|
72 |
+
model_weight = checkpoint["model"]
|
73 |
+
elif "module" in checkpoint: # deepspeed
|
74 |
+
model_weight = checkpoint["module"]
|
75 |
+
elif "state_dict" in checkpoint:
|
76 |
+
model_weight = checkpoint["state_dict"]
|
77 |
+
else:
|
78 |
+
raise Exception("please check model weight")
|
79 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
80 |
+
gpt_model.eval()
|
81 |
+
del checkpoint
|
82 |
+
print(f"gpt model is loaded")
|
83 |
+
|
84 |
+
if args.compile:
|
85 |
+
print(f"compiling the model...")
|
86 |
+
gpt_model = torch.compile(
|
87 |
+
gpt_model,
|
88 |
+
mode="reduce-overhead",
|
89 |
+
fullgraph=True
|
90 |
+
) # requires PyTorch 2.0 (optional)
|
91 |
+
else:
|
92 |
+
print(f"no need to compile model in demo")
|
93 |
+
|
94 |
+
assert os.path.exists(args.t5_path)
|
95 |
+
t5_model = T5Embedder(
|
96 |
+
device=device,
|
97 |
+
local_cache=True,
|
98 |
+
cache_dir=args.t5_path,
|
99 |
+
dir_or_name=args.t5_model_type,
|
100 |
+
torch_dtype=precision,
|
101 |
+
model_max_length=args.t5_feature_max_len,
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
if args.condition_type == 'canny':
|
106 |
+
get_control = CannyDetector()
|
107 |
+
elif args.condition_type == 'hed':
|
108 |
+
get_control = HEDdetector().to(device).eval()
|
109 |
+
elif args.condition_type == 'lineart':
|
110 |
+
get_control = LineArt()
|
111 |
+
get_control.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
|
112 |
+
get_control.to(device)
|
113 |
+
elif args.condition_type == 'depth':
|
114 |
+
processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
|
115 |
+
model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device)
|
116 |
+
with torch.no_grad():
|
117 |
+
|
118 |
+
condition_path = args.condition_path
|
119 |
+
if args.condition_type == 'seg':
|
120 |
+
condition_img = torch.from_numpy(np.array(Image.open(condition_path)))
|
121 |
+
condition_img = condition_img.permute(2,0,1).unsqueeze(0).repeat(2,1,1,1)
|
122 |
+
elif args.condition_type == 'canny':
|
123 |
+
condition_img = get_control(np.array(Image.open(condition_path)))
|
124 |
+
condition_img = torch.from_numpy(condition_img[None,None,...]).repeat(2,3,1,1)
|
125 |
+
elif args.condition_type == 'hed':
|
126 |
+
condition_img = get_control(torch.from_numpy(np.array(Image.open(condition_path))).permute(2,0,1).unsqueeze(0).to(device))
|
127 |
+
condition_img = condition_img.unsqueeze(1).repeat(2,3,1,1)
|
128 |
+
elif args.condition_type == 'lineart':
|
129 |
+
condition_img = get_control(torch.from_numpy(np.array(Image.open(condition_path))).permute(2,0,1).unsqueeze(0).to(device).float())
|
130 |
+
condition_img = condition_img.repeat(2,3,1,1) * 255
|
131 |
+
elif args.condition_type == 'depth':
|
132 |
+
images = Image.open(condition_path)
|
133 |
+
inputs = processor(images=images, return_tensors="pt", size=(512,512)).to(device)
|
134 |
+
outputs = model(**inputs)
|
135 |
+
condition_img = outputs.predicted_depth
|
136 |
+
condition_img = condition_img.unsqueeze(0).repeat(2,3,1,1)
|
137 |
+
condition_img = (condition_img * 255 / condition_img.max())
|
138 |
+
condition_img = condition_img.to(device)
|
139 |
+
condition_img = 2*(condition_img/255 - 0.5)
|
140 |
+
prompts = [args.prompt if args.prompt is not None else "a high-quality image"]
|
141 |
+
prompts = prompts * 2
|
142 |
+
caption_embs, emb_masks = t5_model.get_text_embeddings(prompts)
|
143 |
+
|
144 |
+
if not args.no_left_padding:
|
145 |
+
print(f"processing left-padding...")
|
146 |
+
# a naive way to implement left-padding
|
147 |
+
new_emb_masks = torch.flip(emb_masks, dims=[-1])
|
148 |
+
new_caption_embs = []
|
149 |
+
for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
|
150 |
+
valid_num = int(emb_mask.sum().item())
|
151 |
+
print(f' prompt {idx} token len: {valid_num}')
|
152 |
+
new_caption_emb = torch.cat([caption_emb[valid_num:],caption_emb[:valid_num]])
|
153 |
+
new_caption_embs.append(new_caption_emb)
|
154 |
+
new_caption_embs = torch.stack(new_caption_embs)
|
155 |
+
else:
|
156 |
+
new_caption_embs, new_emb_masks = caption_embs, emb_masks
|
157 |
+
c_indices = new_caption_embs * new_emb_masks[:,:, None]
|
158 |
+
c_emb_masks = new_emb_masks
|
159 |
+
qzshape = [len(c_indices), args.codebook_embed_dim, args.image_H//args.downsample_size, args.image_W//args.downsample_size]
|
160 |
+
t1 = time.time()
|
161 |
+
index_sample = generate(
|
162 |
+
gpt_model, c_indices, (args.image_H//args.downsample_size)*(args.image_W//args.downsample_size),#latent_size ** 2,
|
163 |
+
c_emb_masks, condition=condition_img.to(precision),
|
164 |
+
cfg_scale=args.cfg_scale,
|
165 |
+
temperature=args.temperature, top_k=args.top_k,
|
166 |
+
top_p=args.top_p, sample_logits=True,
|
167 |
+
)
|
168 |
+
sampling_time = time.time() - t1
|
169 |
+
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
170 |
+
|
171 |
+
t2 = time.time()
|
172 |
+
print(index_sample.shape)
|
173 |
+
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
174 |
+
decoder_time = time.time() - t2
|
175 |
+
print(f"decoder takes about {decoder_time:.2f} seconds.")
|
176 |
+
|
177 |
+
samples = torch.cat((condition_img[0:1], samples), dim=0)
|
178 |
+
save_image(samples, f"sample/example/sample_t2i_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1))
|
179 |
+
print(f"image is saved to sample/example/sample_t2i_{args.condition_type}.png")
|
180 |
+
print(prompts)
|
181 |
+
|
182 |
+
|
183 |
+
if __name__ == "__main__":
|
184 |
+
parser = argparse.ArgumentParser()
|
185 |
+
parser.add_argument("--t5-path", type=str, default='checkpoints/t5-ckpt')
|
186 |
+
parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
|
187 |
+
parser.add_argument("--t5-feature-max-len", type=int, default=120)
|
188 |
+
parser.add_argument("--t5-feature-dim", type=int, default=2048)
|
189 |
+
parser.add_argument("--no-left-padding", action='store_true', default=False)
|
190 |
+
parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
|
191 |
+
parser.add_argument("--gpt-ckpt", type=str, default=None)
|
192 |
+
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")
|
193 |
+
parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
|
194 |
+
parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
195 |
+
parser.add_argument("--compile", action='store_true', default=False)
|
196 |
+
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
|
197 |
+
parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
|
198 |
+
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
|
199 |
+
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
|
200 |
+
parser.add_argument("--image-size", type=int, choices=[256, 320, 384, 400, 448, 512, 576, 640, 704, 768], default=768)
|
201 |
+
parser.add_argument("--image-H", type=int, default=512)
|
202 |
+
parser.add_argument("--image-W", type=int, default=512)
|
203 |
+
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
|
204 |
+
parser.add_argument("--cfg-scale", type=float, default=4)
|
205 |
+
parser.add_argument("--seed", type=int, default=0)
|
206 |
+
parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
|
207 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
|
208 |
+
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
|
209 |
+
|
210 |
+
parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
211 |
+
parser.add_argument("--condition-type", type=str, choices=['seg', 'canny', 'hed', 'lineart', 'depth'], default="canny")
|
212 |
+
parser.add_argument("--prompt", type=str, default='a high-quality image')
|
213 |
+
parser.add_argument("--condition-path", type=str, default='condition/example/t2i/multigen/landscape.png')
|
214 |
+
args = parser.parse_args()
|
215 |
+
main(args)
|
autoregressive/sample/sample_t2i_MR.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
3 |
+
torch.backends.cudnn.allow_tf32 = True
|
4 |
+
torch.set_float32_matmul_precision('high')
|
5 |
+
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
|
6 |
+
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
|
7 |
+
from torchvision.utils import save_image
|
8 |
+
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
current_directory = os.getcwd()
|
12 |
+
sys.path.append(current_directory)
|
13 |
+
import time
|
14 |
+
import argparse
|
15 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
16 |
+
from language.t5 import T5Embedder
|
17 |
+
from autoregressive.models.gpt_t2i import GPT_models
|
18 |
+
from autoregressive.models.generate import generate
|
19 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
20 |
+
from dataset.t2i_control import build_t2i_control_code
|
21 |
+
from accelerate import Accelerator
|
22 |
+
from dataset.build import build_dataset
|
23 |
+
from pathlib import Path
|
24 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
25 |
+
import torch.nn.functional as F
|
26 |
+
from condition.canny import CannyDetector
|
27 |
+
from condition.hed import HEDdetector
|
28 |
+
import numpy as np
|
29 |
+
from PIL import Image
|
30 |
+
from condition.lineart import LineArt
|
31 |
+
import cv2
|
32 |
+
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
33 |
+
from condition.midas.depth import MidasDetector
|
34 |
+
|
35 |
+
|
36 |
+
def resize_image_to_16_multiple(image_path, condition_type='seg'):
|
37 |
+
image = Image.open(image_path)
|
38 |
+
width, height = image.size
|
39 |
+
|
40 |
+
if condition_type == 'depth': # The depth model requires a side length that is a multiple of 32
|
41 |
+
new_width = (width + 31) // 32 * 32
|
42 |
+
new_height = (height + 31) // 32 * 32
|
43 |
+
else:
|
44 |
+
new_width = (width + 15) // 16 * 16
|
45 |
+
new_height = (height + 15) // 16 * 16
|
46 |
+
|
47 |
+
resized_image = image.resize((new_width, new_height))
|
48 |
+
return resized_image
|
49 |
+
|
50 |
+
def main(args):
|
51 |
+
# Setup PyTorch:
|
52 |
+
torch.manual_seed(args.seed)
|
53 |
+
torch.backends.cudnn.deterministic = True
|
54 |
+
torch.backends.cudnn.benchmark = False
|
55 |
+
torch.set_grad_enabled(False)
|
56 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
57 |
+
|
58 |
+
# create and load model
|
59 |
+
vq_model = VQ_models[args.vq_model](
|
60 |
+
codebook_size=args.codebook_size,
|
61 |
+
codebook_embed_dim=args.codebook_embed_dim)
|
62 |
+
vq_model.to(device)
|
63 |
+
vq_model.eval()
|
64 |
+
checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
|
65 |
+
vq_model.load_state_dict(checkpoint["model"])
|
66 |
+
del checkpoint
|
67 |
+
print(f"image tokenizer is loaded")
|
68 |
+
|
69 |
+
# create and load gpt model
|
70 |
+
precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
|
71 |
+
latent_size = args.image_size // args.downsample_size
|
72 |
+
gpt_model = GPT_models[args.gpt_model](
|
73 |
+
block_size=latent_size ** 2,
|
74 |
+
cls_token_num=args.cls_token_num,
|
75 |
+
model_type=args.gpt_type,
|
76 |
+
condition_type=args.condition_type,
|
77 |
+
).to(device=device, dtype=precision)
|
78 |
+
|
79 |
+
_, file_extension = os.path.splitext(args.gpt_ckpt)
|
80 |
+
if file_extension.lower() == '.safetensors':
|
81 |
+
from safetensors.torch import load_file
|
82 |
+
model_weight = load_file(args.gpt_ckpt)
|
83 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
84 |
+
gpt_model.eval()
|
85 |
+
else:
|
86 |
+
checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
|
87 |
+
if "model" in checkpoint: # ddp
|
88 |
+
model_weight = checkpoint["model"]
|
89 |
+
elif "module" in checkpoint: # deepspeed
|
90 |
+
model_weight = checkpoint["module"]
|
91 |
+
elif "state_dict" in checkpoint:
|
92 |
+
model_weight = checkpoint["state_dict"]
|
93 |
+
else:
|
94 |
+
raise Exception("please check model weight")
|
95 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
96 |
+
gpt_model.eval()
|
97 |
+
del checkpoint
|
98 |
+
print(f"gpt model is loaded")
|
99 |
+
|
100 |
+
if args.compile:
|
101 |
+
print(f"compiling the model...")
|
102 |
+
gpt_model = torch.compile(
|
103 |
+
gpt_model,
|
104 |
+
mode="reduce-overhead",
|
105 |
+
fullgraph=True
|
106 |
+
) # requires PyTorch 2.0 (optional)
|
107 |
+
else:
|
108 |
+
print(f"no need to compile model in demo")
|
109 |
+
|
110 |
+
assert os.path.exists(args.t5_path)
|
111 |
+
t5_model = T5Embedder(
|
112 |
+
device=device,
|
113 |
+
local_cache=True,
|
114 |
+
cache_dir=args.t5_path,
|
115 |
+
dir_or_name=args.t5_model_type,
|
116 |
+
torch_dtype=precision,
|
117 |
+
model_max_length=args.t5_feature_max_len,
|
118 |
+
)
|
119 |
+
|
120 |
+
|
121 |
+
if args.condition_type == 'canny':
|
122 |
+
get_control = CannyDetector()
|
123 |
+
elif args.condition_type == 'hed':
|
124 |
+
get_control = HEDdetector().to(device).eval()
|
125 |
+
elif args.condition_type == 'lineart':
|
126 |
+
get_control = LineArt()
|
127 |
+
get_control.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
|
128 |
+
get_control.to(device)
|
129 |
+
elif args.condition_type == 'depth':
|
130 |
+
processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
|
131 |
+
model_large = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device)
|
132 |
+
model = MidasDetector(device=device)
|
133 |
+
with torch.no_grad():
|
134 |
+
|
135 |
+
condition_img = resize_image_to_16_multiple(args.condition_path, args.condition_type)
|
136 |
+
W, H = condition_img.size
|
137 |
+
print(H,W)
|
138 |
+
if args.condition_type == 'seg':
|
139 |
+
condition_img = torch.from_numpy(np.array(condition_img))
|
140 |
+
condition_img = condition_img.permute(2,0,1).unsqueeze(0).repeat(2,1,1,1)
|
141 |
+
elif args.condition_type == 'canny':
|
142 |
+
condition_img = get_control(np.array(condition_img))
|
143 |
+
condition_img = torch.from_numpy(condition_img[None,None,...]).repeat(2,3,1,1)
|
144 |
+
elif args.condition_type == 'hed':
|
145 |
+
condition_img = get_control(torch.from_numpy(np.array(condition_img)).permute(2,0,1).unsqueeze(0).to(device))
|
146 |
+
condition_img = condition_img.unsqueeze(1).repeat(2,3,1,1)
|
147 |
+
elif args.condition_type == 'lineart':
|
148 |
+
condition_img = get_control(torch.from_numpy(np.array(condition_img)).permute(2,0,1).unsqueeze(0).to(device).float())
|
149 |
+
condition_img = condition_img.repeat(2,3,1,1) * 255
|
150 |
+
elif args.condition_type == 'depth':
|
151 |
+
images = condition_img
|
152 |
+
if H == W:
|
153 |
+
inputs = processor(images=images, return_tensors="pt", size=(H,W)).to(device)
|
154 |
+
outputs = model_large(**inputs)
|
155 |
+
condition_img = outputs.predicted_depth
|
156 |
+
condition_img = (condition_img * 255 / condition_img.max())
|
157 |
+
else:
|
158 |
+
condition_img = torch.from_numpy(model(torch.from_numpy(np.array(condition_img)).to(device))).unsqueeze(0)
|
159 |
+
condition_img = condition_img.unsqueeze(0).repeat(2,3,1,1)
|
160 |
+
condition_img = condition_img.to(device)
|
161 |
+
condition_img = 2*(condition_img/255 - 0.5)
|
162 |
+
prompts = [args.prompt if args.prompt is not None else "a high-quality image"]
|
163 |
+
prompts = prompts * 2
|
164 |
+
caption_embs, emb_masks = t5_model.get_text_embeddings(prompts)
|
165 |
+
|
166 |
+
if not args.no_left_padding:
|
167 |
+
print(f"processing left-padding...")
|
168 |
+
# a naive way to implement left-padding
|
169 |
+
new_emb_masks = torch.flip(emb_masks, dims=[-1])
|
170 |
+
new_caption_embs = []
|
171 |
+
for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
|
172 |
+
valid_num = int(emb_mask.sum().item())
|
173 |
+
print(f' prompt {idx} token len: {valid_num}')
|
174 |
+
new_caption_emb = torch.cat([caption_emb[valid_num:],caption_emb[:valid_num]])
|
175 |
+
new_caption_embs.append(new_caption_emb)
|
176 |
+
new_caption_embs = torch.stack(new_caption_embs)
|
177 |
+
else:
|
178 |
+
new_caption_embs, new_emb_masks = caption_embs, emb_masks
|
179 |
+
c_indices = new_caption_embs * new_emb_masks[:,:, None]
|
180 |
+
c_emb_masks = new_emb_masks
|
181 |
+
qzshape = [len(c_indices), args.codebook_embed_dim, H//args.downsample_size, W//args.downsample_size]
|
182 |
+
t1 = time.time()
|
183 |
+
index_sample = generate(
|
184 |
+
gpt_model, c_indices, (H//args.downsample_size)*(W//args.downsample_size),#latent_size ** 2,
|
185 |
+
c_emb_masks, condition=condition_img.to(precision),
|
186 |
+
cfg_scale=args.cfg_scale,
|
187 |
+
temperature=args.temperature, top_k=args.top_k,
|
188 |
+
top_p=args.top_p, sample_logits=True,
|
189 |
+
)
|
190 |
+
sampling_time = time.time() - t1
|
191 |
+
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
192 |
+
|
193 |
+
t2 = time.time()
|
194 |
+
print(index_sample.shape)
|
195 |
+
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
196 |
+
decoder_time = time.time() - t2
|
197 |
+
print(f"decoder takes about {decoder_time:.2f} seconds.")
|
198 |
+
|
199 |
+
samples = torch.cat((condition_img[0:1], samples), dim=0)
|
200 |
+
save_image(samples, f"sample/example/sample_t2i_MR_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1))
|
201 |
+
print(f"image is saved to sample/example/sample_t2i_MR_{args.condition_type}.png")
|
202 |
+
print(prompts)
|
203 |
+
|
204 |
+
|
205 |
+
if __name__ == "__main__":
|
206 |
+
parser = argparse.ArgumentParser()
|
207 |
+
parser.add_argument("--t5-path", type=str, default='checkpoints/t5-ckpt')
|
208 |
+
parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
|
209 |
+
parser.add_argument("--t5-feature-max-len", type=int, default=120)
|
210 |
+
parser.add_argument("--t5-feature-dim", type=int, default=2048)
|
211 |
+
parser.add_argument("--no-left-padding", action='store_true', default=False)
|
212 |
+
parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
|
213 |
+
parser.add_argument("--gpt-ckpt", type=str, default=None)
|
214 |
+
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")
|
215 |
+
parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
|
216 |
+
parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
217 |
+
parser.add_argument("--compile", action='store_true', default=False)
|
218 |
+
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
|
219 |
+
parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
|
220 |
+
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
|
221 |
+
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
|
222 |
+
parser.add_argument("--image-size", type=int, choices=[256, 320, 384, 400, 448, 512, 576, 640, 704, 768], default=768)
|
223 |
+
parser.add_argument("--image-H", type=int, default=512)
|
224 |
+
parser.add_argument("--image-W", type=int, default=512)
|
225 |
+
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
|
226 |
+
parser.add_argument("--cfg-scale", type=float, default=4)
|
227 |
+
parser.add_argument("--seed", type=int, default=0)
|
228 |
+
parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
|
229 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
|
230 |
+
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
|
231 |
+
|
232 |
+
parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
233 |
+
parser.add_argument("--condition-type", type=str, choices=['seg', 'canny', 'hed', 'lineart', 'depth'], default="canny")
|
234 |
+
parser.add_argument("--prompt", type=str, default='a high-quality image')
|
235 |
+
parser.add_argument("--condition-path", type=str, default='condition/example/t2i/multigen/landscape.png')
|
236 |
+
args = parser.parse_args()
|
237 |
+
main(args)
|
autoregressive/sample/sample_t2i_ddp.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
3 |
+
torch.backends.cudnn.allow_tf32 = True
|
4 |
+
torch.set_float32_matmul_precision('high')
|
5 |
+
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
|
6 |
+
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torch.distributed as dist
|
9 |
+
|
10 |
+
import os
|
11 |
+
import math
|
12 |
+
import json
|
13 |
+
import argparse
|
14 |
+
import pandas as pd
|
15 |
+
from tqdm import tqdm
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
19 |
+
from language.t5 import T5Embedder
|
20 |
+
from autoregressive.models.gpt import GPT_models
|
21 |
+
from autoregressive.models.generate import generate
|
22 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
def main(args):
|
27 |
+
# Setup PyTorch:
|
28 |
+
assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
|
29 |
+
torch.set_grad_enabled(False)
|
30 |
+
|
31 |
+
# Setup DDP:
|
32 |
+
dist.init_process_group("nccl")
|
33 |
+
rank = dist.get_rank()
|
34 |
+
device = rank % torch.cuda.device_count()
|
35 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
36 |
+
torch.manual_seed(seed)
|
37 |
+
torch.cuda.set_device(device)
|
38 |
+
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
39 |
+
|
40 |
+
# create and load model
|
41 |
+
vq_model = VQ_models[args.vq_model](
|
42 |
+
codebook_size=args.codebook_size,
|
43 |
+
codebook_embed_dim=args.codebook_embed_dim)
|
44 |
+
vq_model.to(device)
|
45 |
+
vq_model.eval()
|
46 |
+
checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
|
47 |
+
vq_model.load_state_dict(checkpoint["model"])
|
48 |
+
del checkpoint
|
49 |
+
print(f"image tokenizer is loaded")
|
50 |
+
|
51 |
+
# create and load gpt model
|
52 |
+
precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
|
53 |
+
latent_size = args.image_size // args.downsample_size
|
54 |
+
gpt_model = GPT_models[args.gpt_model](
|
55 |
+
block_size=latent_size ** 2,
|
56 |
+
cls_token_num=args.cls_token_num,
|
57 |
+
model_type=args.gpt_type,
|
58 |
+
).to(device=device, dtype=precision)
|
59 |
+
|
60 |
+
checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
|
61 |
+
|
62 |
+
if "model" in checkpoint: # ddp
|
63 |
+
model_weight = checkpoint["model"]
|
64 |
+
elif "module" in checkpoint: # deepspeed
|
65 |
+
model_weight = checkpoint["module"]
|
66 |
+
elif "state_dict" in checkpoint:
|
67 |
+
model_weight = checkpoint["state_dict"]
|
68 |
+
else:
|
69 |
+
raise Exception("please check model weight")
|
70 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
71 |
+
gpt_model.eval()
|
72 |
+
del checkpoint
|
73 |
+
print(f"gpt model is loaded")
|
74 |
+
|
75 |
+
if args.compile:
|
76 |
+
print(f"compiling the model...")
|
77 |
+
gpt_model = torch.compile(
|
78 |
+
gpt_model,
|
79 |
+
mode="reduce-overhead",
|
80 |
+
fullgraph=True
|
81 |
+
) # requires PyTorch 2.0 (optional)
|
82 |
+
else:
|
83 |
+
print(f"no need to compile model in demo")
|
84 |
+
|
85 |
+
assert os.path.exists(args.t5_path)
|
86 |
+
t5_model = T5Embedder(
|
87 |
+
device=device,
|
88 |
+
local_cache=True,
|
89 |
+
cache_dir=args.t5_path,
|
90 |
+
dir_or_name=args.t5_model_type,
|
91 |
+
torch_dtype=precision,
|
92 |
+
model_max_length=args.t5_feature_max_len,
|
93 |
+
)
|
94 |
+
print(f"t5 model is loaded")
|
95 |
+
|
96 |
+
# Create folder to save samples:
|
97 |
+
model_string_name = args.gpt_model.replace("/", "-")
|
98 |
+
ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
|
99 |
+
prompt_name = args.prompt_csv.split('/')[-1].split('.')[0].lower()
|
100 |
+
folder_name = f"{model_string_name}-{ckpt_string_name}-{prompt_name}-size-{args.image_size}-size-{args.image_size}-{args.vq_model}-" \
|
101 |
+
f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
|
102 |
+
f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
|
103 |
+
sample_folder_dir = f"{args.sample_dir}/{folder_name}"
|
104 |
+
if rank == 0:
|
105 |
+
os.makedirs(f"{sample_folder_dir}/images", exist_ok=True)
|
106 |
+
print(f"Saving .png samples at {sample_folder_dir}/images")
|
107 |
+
dist.barrier()
|
108 |
+
|
109 |
+
df = pd.read_csv(args.prompt_csv, delimiter='\t')
|
110 |
+
prompt_list = df['Prompt'].tolist()
|
111 |
+
|
112 |
+
# Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
|
113 |
+
n = args.per_proc_batch_size
|
114 |
+
global_batch_size = n * dist.get_world_size()
|
115 |
+
num_fid_samples = min(args.num_fid_samples, len(prompt_list))
|
116 |
+
# To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
|
117 |
+
total_samples = int(math.ceil(num_fid_samples / global_batch_size) * global_batch_size)
|
118 |
+
if rank == 0:
|
119 |
+
print(f"Total number of images that will be sampled: {total_samples}")
|
120 |
+
assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
|
121 |
+
samples_needed_this_gpu = int(total_samples // dist.get_world_size())
|
122 |
+
assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
|
123 |
+
iterations = int(samples_needed_this_gpu // n)
|
124 |
+
pbar = range(iterations)
|
125 |
+
pbar = tqdm(pbar) if rank == 0 else pbar
|
126 |
+
total = 0
|
127 |
+
for _ in pbar:
|
128 |
+
# Select text prompt
|
129 |
+
prompt_batch = []
|
130 |
+
for i in range(n):
|
131 |
+
index = i * dist.get_world_size() + rank + total
|
132 |
+
prompt_batch.append(prompt_list[index] if index < len(prompt_list) else "a cute dog")
|
133 |
+
|
134 |
+
# Sample inputs:
|
135 |
+
caption_embs, emb_masks = t5_model.get_text_embeddings(prompt_batch)
|
136 |
+
|
137 |
+
if not args.no_left_padding:
|
138 |
+
new_emb_masks = torch.flip(emb_masks, dims=[-1])
|
139 |
+
new_caption_embs = []
|
140 |
+
for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
|
141 |
+
valid_num = int(emb_mask.sum().item())
|
142 |
+
# prompt_cur = prompt_batch[idx]
|
143 |
+
# print(f' prompt {idx} token len: {valid_num} : {prompt_cur}')
|
144 |
+
new_caption_emb = torch.cat([caption_emb[valid_num:], caption_emb[:valid_num]])
|
145 |
+
new_caption_embs.append(new_caption_emb)
|
146 |
+
new_caption_embs = torch.stack(new_caption_embs)
|
147 |
+
|
148 |
+
else:
|
149 |
+
new_caption_embs, new_emb_masks = caption_embs, emb_masks
|
150 |
+
|
151 |
+
c_indices = new_caption_embs * new_emb_masks[:,:, None]
|
152 |
+
c_emb_masks = new_emb_masks
|
153 |
+
|
154 |
+
qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size]
|
155 |
+
index_sample = generate(
|
156 |
+
gpt_model, c_indices, latent_size ** 2,
|
157 |
+
c_emb_masks,
|
158 |
+
cfg_scale=args.cfg_scale,
|
159 |
+
temperature=args.temperature, top_k=args.top_k,
|
160 |
+
top_p=args.top_p, sample_logits=True,
|
161 |
+
)
|
162 |
+
|
163 |
+
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
164 |
+
samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
165 |
+
|
166 |
+
# Save samples to disk as individual .png files
|
167 |
+
for i, sample in enumerate(samples):
|
168 |
+
index = i * dist.get_world_size() + rank + total
|
169 |
+
Image.fromarray(sample).save(f"{sample_folder_dir}/images/{index:06d}.png")
|
170 |
+
total += global_batch_size
|
171 |
+
|
172 |
+
# Make sure all processes have finished saving their samples before attempting to convert to .npz
|
173 |
+
dist.barrier()
|
174 |
+
if rank == 0:
|
175 |
+
# Save infer result in a jsonl file
|
176 |
+
json_items = []
|
177 |
+
for idx, prompt in enumerate(prompt_list):
|
178 |
+
image_path = os.path.join(sample_folder_dir, "images", f"{idx:06d}.png")
|
179 |
+
json_items.append({"text": prompt, "image_path": image_path})
|
180 |
+
res_jsonl_path = os.path.join(sample_folder_dir, "result.jsonl")
|
181 |
+
print(f"Save jsonl to {res_jsonl_path}...")
|
182 |
+
with open(res_jsonl_path, "w") as f:
|
183 |
+
for item in json_items:
|
184 |
+
f.write(json.dumps(item) + "\n")
|
185 |
+
|
186 |
+
# Save captions to txt
|
187 |
+
caption_path = os.path.join(sample_folder_dir, "captions.txt")
|
188 |
+
print(f"Save captions to {caption_path}...")
|
189 |
+
with open(caption_path, "w") as f:
|
190 |
+
for item in prompt_list:
|
191 |
+
f.write(f"{item}\n")
|
192 |
+
print("Done.")
|
193 |
+
|
194 |
+
dist.barrier()
|
195 |
+
dist.destroy_process_group()
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
if __name__ == "__main__":
|
200 |
+
parser = argparse.ArgumentParser()
|
201 |
+
parser.add_argument("--prompt-csv", type=str, default='evaluations/t2i/PartiPrompts.tsv')
|
202 |
+
parser.add_argument("--t5-path", type=str, default='pretrained_models/t5-ckpt')
|
203 |
+
parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
|
204 |
+
parser.add_argument("--t5-feature-max-len", type=int, default=120)
|
205 |
+
parser.add_argument("--t5-feature-dim", type=int, default=2048)
|
206 |
+
parser.add_argument("--no-left-padding", action='store_true', default=False)
|
207 |
+
parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
|
208 |
+
parser.add_argument("--gpt-ckpt", type=str, default=None)
|
209 |
+
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")
|
210 |
+
parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
|
211 |
+
parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
212 |
+
parser.add_argument("--compile", action='store_true', default=False)
|
213 |
+
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
|
214 |
+
parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
|
215 |
+
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
|
216 |
+
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
|
217 |
+
parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=512)
|
218 |
+
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
|
219 |
+
parser.add_argument("--num-classes", type=int, default=1000)
|
220 |
+
parser.add_argument("--cfg-scale", type=float, default=7.5)
|
221 |
+
parser.add_argument("--sample-dir", type=str, default="samples_parti", help="samples_coco or samples_parti")
|
222 |
+
parser.add_argument("--per-proc-batch-size", type=int, default=32)
|
223 |
+
parser.add_argument("--num-fid-samples", type=int, default=30000)
|
224 |
+
parser.add_argument("--global-seed", type=int, default=0)
|
225 |
+
parser.add_argument("--top-k", type=int, default=1000, help="top-k value to sample with")
|
226 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
|
227 |
+
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
|
228 |
+
args = parser.parse_args()
|
229 |
+
main(args)
|
checkpoints/vq_ds16_t2i.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0e21fc1318e2e9ee641a07bdad0e20675e9ec35e6e3eb911d58b5d7a2cd8d4cb
|
3 |
+
size 287920306
|
condition/README.md
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Prepare the preprocessing model
|
2 |
+
|
3 |
+
Hed: https://huggingface.co/lllyasviel/Annotators/blob/main/ControlNetHED.pth\
|
4 |
+
Lineart: https://huggingface.co/spaces/awacke1/Image-to-Line-Drawings/resolve/main/model.pth\
|
5 |
+
depth: https://huggingface.co/lllyasviel/Annotators/blob/main/dpt_hybrid-midas-501f0c75.pt (hybrid for inference)\
|
6 |
+
https://huggingface.co/Intel/dpt-large (large for test conditional consistency and fid)\
|
7 |
+
|
8 |
+
We recommend storing them in the following paths
|
9 |
+
|
10 |
+
|---condition
|
11 |
+
|---ckpts
|
12 |
+
|---dpt_large
|
13 |
+
|---config.json
|
14 |
+
|---preprocessor_config.json
|
15 |
+
|---pytorch_model.bin
|
16 |
+
|---ControlNetHED.pth
|
17 |
+
|---dpt_hybrid-midas-501f0c75.pt
|
18 |
+
|---model.pth
|
19 |
+
|---example
|
20 |
+
|---midas
|
21 |
+
.
|
22 |
+
.
|
23 |
+
.
|
condition/canny.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
class CannyDetector:
|
7 |
+
def __call__(self, img, low_threshold=100, high_threshold=200):
|
8 |
+
"""
|
9 |
+
input: array or tensor (H,W,3)
|
10 |
+
output: array (H,W)
|
11 |
+
"""
|
12 |
+
if torch.is_tensor(img):
|
13 |
+
img = img.cpu().detach().numpy().astype(np.uint8)
|
14 |
+
return cv2.Canny(img, low_threshold, high_threshold)
|
15 |
+
|
16 |
+
|
17 |
+
if __name__ == '__main__':
|
18 |
+
apply_canny = CannyDetector()
|
19 |
+
img = cv2.imread('condition/dragon_resize.png')
|
20 |
+
import numpy as np
|
21 |
+
print(img.max())
|
22 |
+
detected_map = apply_canny(img, 100, 200)
|
23 |
+
print(detected_map.shape, detected_map.max(), detected_map.min())
|
24 |
+
cv2.imwrite('condition/example_canny.jpg', detected_map)
|
25 |
+
np.save('condition/example_canny.npy', detected_map[None,None])
|
condition/depth.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from controlnet_aux import LineartDetector
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
6 |
+
class Depth:
|
7 |
+
def __init__(self, device):
|
8 |
+
self.model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large")
|
9 |
+
|
10 |
+
def __call__(self, input_image):
|
11 |
+
"""
|
12 |
+
input: tensor()
|
13 |
+
"""
|
14 |
+
control_image = self.model(input_image)
|
15 |
+
return np.array(control_image)
|
16 |
+
|
17 |
+
if __name__ == '__main__':
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
from tqdm import tqdm
|
20 |
+
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
21 |
+
from PIL import Image
|
22 |
+
|
23 |
+
image = Image.open('condition/example/t2i/depth/depth.png')
|
24 |
+
img = cv2.imread('condition/example/t2i/depth/depth.png')
|
25 |
+
processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
|
26 |
+
model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large")
|
27 |
+
|
28 |
+
inputs = torch.from_numpy(np.array(img)).permute(2,0,1).unsqueeze(0).float()#
|
29 |
+
inputs = 2*(inputs/255 - 0.5)
|
30 |
+
inputs = processor(images=image, return_tensors="pt", size=(512,512))
|
31 |
+
print(inputs)
|
32 |
+
with torch.no_grad():
|
33 |
+
outputs = model(**inputs)
|
34 |
+
predicted_depth = outputs.predicted_depth
|
35 |
+
print(predicted_depth.shape)
|
36 |
+
prediction = torch.nn.functional.interpolate(
|
37 |
+
predicted_depth.unsqueeze(1),
|
38 |
+
size=image.size[::-1],
|
39 |
+
mode="bicubic",
|
40 |
+
align_corners=False,
|
41 |
+
)
|
42 |
+
|
43 |
+
output = prediction.squeeze().cpu().numpy()
|
44 |
+
formatted = (output * 255 / np.max(output)).astype("uint8")
|
45 |
+
|
46 |
+
depth = Image.fromarray(formatted)
|
47 |
+
depth.save('condition/example/t2i/depth/example_depth.jpg')
|
condition/example/t2i/multi_resolution/bird.jpg
ADDED
condition/example/t2i/multi_resolution/car.jpg
ADDED
condition/example/t2i/multigen/doll.jpg
ADDED
condition/example/t2i/multigen/girl.jpg
ADDED
condition/example/t2i/multigen/house.jpg
ADDED
condition/example/t2i/multigen/sofa.png
ADDED
condition/hed.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
|
2 |
+
# Please use this implementation in your products
|
3 |
+
# This implementation may produce slightly different results from Saining Xie's official implementations,
|
4 |
+
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
|
5 |
+
# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
|
6 |
+
# and in this way it works better for gradio's RGB protocol
|
7 |
+
|
8 |
+
import os
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
from torch.nn.parallel import DataParallel
|
13 |
+
from einops import rearrange
|
14 |
+
from condition.utils import annotator_ckpts_path
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
class DoubleConvBlock(torch.nn.Module):
|
18 |
+
def __init__(self, input_channel, output_channel, layer_number):
|
19 |
+
super().__init__()
|
20 |
+
self.convs = torch.nn.Sequential()
|
21 |
+
self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
|
22 |
+
for i in range(1, layer_number):
|
23 |
+
self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
|
24 |
+
self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
|
25 |
+
|
26 |
+
def __call__(self, x, down_sampling=False):
|
27 |
+
h = x
|
28 |
+
if down_sampling:
|
29 |
+
h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
|
30 |
+
for conv in self.convs:
|
31 |
+
h = conv(h)
|
32 |
+
h = torch.nn.functional.relu(h)
|
33 |
+
return h, self.projection(h)
|
34 |
+
|
35 |
+
|
36 |
+
class ControlNetHED_Apache2(torch.nn.Module):
|
37 |
+
def __init__(self):
|
38 |
+
super().__init__()
|
39 |
+
self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
|
40 |
+
self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
|
41 |
+
self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
|
42 |
+
self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
|
43 |
+
self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
|
44 |
+
self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
|
45 |
+
|
46 |
+
def __call__(self, x):
|
47 |
+
h = x - self.norm
|
48 |
+
h, projection1 = self.block1(h)
|
49 |
+
h, projection2 = self.block2(h, down_sampling=True)
|
50 |
+
h, projection3 = self.block3(h, down_sampling=True)
|
51 |
+
h, projection4 = self.block4(h, down_sampling=True)
|
52 |
+
h, projection5 = self.block5(h, down_sampling=True)
|
53 |
+
return projection1, projection2, projection3, projection4, projection5
|
54 |
+
|
55 |
+
|
56 |
+
class HEDdetector(torch.nn.Module):
|
57 |
+
def __init__(self):
|
58 |
+
super().__init__()
|
59 |
+
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
|
60 |
+
modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
|
61 |
+
if not os.path.exists(modelpath):
|
62 |
+
from basicsr.utils.download_util import load_file_from_url
|
63 |
+
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
64 |
+
self.netNetwork = ControlNetHED_Apache2().float()#.to(self.device).eval()
|
65 |
+
self.netNetwork.load_state_dict(torch.load(modelpath))
|
66 |
+
|
67 |
+
def __call__(self, input_image):
|
68 |
+
"""
|
69 |
+
input: tensor (B,C,H,W)
|
70 |
+
output: tensor (B,H,W)
|
71 |
+
"""
|
72 |
+
B, C, H, W = input_image.shape
|
73 |
+
image_hed = input_image
|
74 |
+
|
75 |
+
edges = self.netNetwork(image_hed)
|
76 |
+
edges = [F.interpolate(e, size=(H, W), mode='bilinear', align_corners=False).squeeze(1) for e in edges]
|
77 |
+
edges = torch.stack(edges, dim=1)
|
78 |
+
edge = 1 / (1 + torch.exp(-torch.mean(edges, dim=1)))
|
79 |
+
edge = (edge * 255.0).clamp(0, 255)
|
80 |
+
|
81 |
+
return edge
|
82 |
+
|
83 |
+
|
84 |
+
def nms(x, t, s):
|
85 |
+
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
|
86 |
+
|
87 |
+
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
88 |
+
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
89 |
+
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
90 |
+
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
91 |
+
|
92 |
+
y = np.zeros_like(x)
|
93 |
+
|
94 |
+
for f in [f1, f2, f3, f4]:
|
95 |
+
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
96 |
+
|
97 |
+
z = np.zeros_like(y, dtype=np.uint8)
|
98 |
+
z[y > t] = 255
|
99 |
+
return z
|
100 |
+
|
101 |
+
if __name__ == '__main__':
|
102 |
+
import matplotlib.pyplot as plt
|
103 |
+
from tqdm import tqdm
|
104 |
+
import torch.nn.functional as F
|
105 |
+
device = torch.device('cuda')
|
106 |
+
apply_hed = HEDdetector().to(device).eval()
|
107 |
+
img = cv2.imread('condition/dragon_1024_512.jpg')
|
108 |
+
H,W = img.shape[:2]
|
109 |
+
resize_img = cv2.resize(img,(512,1024))
|
110 |
+
detected_map = apply_hed(torch.from_numpy(img).permute(2,0,1).unsqueeze(0).cuda())
|
111 |
+
resize_detected_map = apply_hed(torch.from_numpy(resize_img).permute(2,0,1).unsqueeze(0).cuda())
|
112 |
+
cv2.imwrite('condition/example_hed_resize.jpg', resize_detected_map[0].cpu().detach().numpy())
|
113 |
+
resize_detected_map = F.interpolate(resize_detected_map.unsqueeze(0).to(torch.float32), size=(H,W), mode='bilinear', align_corners=False, antialias=True)
|
114 |
+
print(abs(detected_map - resize_detected_map).sum())
|
115 |
+
print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min())
|
116 |
+
cv2.imwrite('condition/example_hed.jpg', detected_map[0].cpu().detach().numpy())
|
117 |
+
cv2.imwrite('condition/example_hed_resized.jpg', resize_detected_map[0,0].cpu().detach().numpy())
|
condition/lineart.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from controlnet_aux import LineartDetector
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
norm_layer = nn.InstanceNorm2d
|
9 |
+
class ResidualBlock(nn.Module):
|
10 |
+
def __init__(self, in_features):
|
11 |
+
super(ResidualBlock, self).__init__()
|
12 |
+
|
13 |
+
conv_block = [ nn.ReflectionPad2d(1),
|
14 |
+
nn.Conv2d(in_features, in_features, 3),
|
15 |
+
norm_layer(in_features),
|
16 |
+
nn.ReLU(inplace=True),
|
17 |
+
nn.ReflectionPad2d(1),
|
18 |
+
nn.Conv2d(in_features, in_features, 3),
|
19 |
+
norm_layer(in_features)
|
20 |
+
]
|
21 |
+
|
22 |
+
self.conv_block = nn.Sequential(*conv_block)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
return x + self.conv_block(x)
|
26 |
+
class LineArt(nn.Module):
|
27 |
+
def __init__(self, input_nc=3, output_nc=1, n_residual_blocks=3, sigmoid=True):
|
28 |
+
super(LineArt, self).__init__()
|
29 |
+
|
30 |
+
# Initial convolution block
|
31 |
+
model0 = [ nn.ReflectionPad2d(3),
|
32 |
+
nn.Conv2d(input_nc, 64, 7),
|
33 |
+
norm_layer(64),
|
34 |
+
nn.ReLU(inplace=True) ]
|
35 |
+
self.model0 = nn.Sequential(*model0)
|
36 |
+
|
37 |
+
# Downsampling
|
38 |
+
model1 = []
|
39 |
+
in_features = 64
|
40 |
+
out_features = in_features*2
|
41 |
+
for _ in range(2):
|
42 |
+
model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
|
43 |
+
norm_layer(out_features),
|
44 |
+
nn.ReLU(inplace=True) ]
|
45 |
+
in_features = out_features
|
46 |
+
out_features = in_features*2
|
47 |
+
self.model1 = nn.Sequential(*model1)
|
48 |
+
|
49 |
+
model2 = []
|
50 |
+
# Residual blocks
|
51 |
+
for _ in range(n_residual_blocks):
|
52 |
+
model2 += [ResidualBlock(in_features)]
|
53 |
+
self.model2 = nn.Sequential(*model2)
|
54 |
+
|
55 |
+
# Upsampling
|
56 |
+
model3 = []
|
57 |
+
out_features = in_features//2
|
58 |
+
for _ in range(2):
|
59 |
+
model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
|
60 |
+
norm_layer(out_features),
|
61 |
+
nn.ReLU(inplace=True) ]
|
62 |
+
in_features = out_features
|
63 |
+
out_features = in_features//2
|
64 |
+
self.model3 = nn.Sequential(*model3)
|
65 |
+
|
66 |
+
# Output layer
|
67 |
+
model4 = [ nn.ReflectionPad2d(3),
|
68 |
+
nn.Conv2d(64, output_nc, 7)]
|
69 |
+
if sigmoid:
|
70 |
+
model4 += [nn.Sigmoid()]
|
71 |
+
|
72 |
+
self.model4 = nn.Sequential(*model4)
|
73 |
+
|
74 |
+
def forward(self, x, cond=None):
|
75 |
+
"""
|
76 |
+
input: tensor (B,C,H,W)
|
77 |
+
output: tensor (B,1,H,W) 0~1
|
78 |
+
"""
|
79 |
+
|
80 |
+
out = self.model0(x)
|
81 |
+
out = self.model1(out)
|
82 |
+
out = self.model2(out)
|
83 |
+
out = self.model3(out)
|
84 |
+
out = self.model4(out)
|
85 |
+
|
86 |
+
return out
|
87 |
+
|
88 |
+
|
89 |
+
if __name__ == '__main__':
|
90 |
+
import matplotlib.pyplot as plt
|
91 |
+
from tqdm import tqdm
|
92 |
+
apply_lineart = LineArt()
|
93 |
+
apply_lineart.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
|
94 |
+
img = cv2.imread('condition/car_448_768.jpg')
|
95 |
+
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).repeat(8,1,1,1).float()
|
96 |
+
detected_map = apply_lineart(img)
|
97 |
+
print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min())
|
98 |
+
cv2.imwrite('condition/example_lineart.jpg', 255*detected_map[0,0].cpu().detach().numpy())
|
condition/midas/depth.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Midas Depth Estimation
|
2 |
+
# From https://github.com/isl-org/MiDaS
|
3 |
+
# MIT LICENSE
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
current_directory = os.getcwd()
|
11 |
+
sys.path.append(current_directory)
|
12 |
+
from einops import rearrange
|
13 |
+
# from .api import MiDaSInference
|
14 |
+
from condition.utils import annotator_ckpts_path
|
15 |
+
from condition.midas.midas.dpt_depth import DPTDepthModel
|
16 |
+
from condition.midas.midas.midas_net import MidasNet
|
17 |
+
from condition.midas.midas.midas_net_custom import MidasNet_small
|
18 |
+
from condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
|
19 |
+
import os
|
20 |
+
import torch.nn as nn
|
21 |
+
from torchvision.transforms import Compose
|
22 |
+
|
23 |
+
ISL_PATHS = {
|
24 |
+
"dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
|
25 |
+
"dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
|
26 |
+
"midas_v21": "",
|
27 |
+
"midas_v21_small": "",
|
28 |
+
}
|
29 |
+
|
30 |
+
remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
|
31 |
+
|
32 |
+
|
33 |
+
def disabled_train(self, mode=True):
|
34 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
35 |
+
does not change anymore."""
|
36 |
+
return self
|
37 |
+
|
38 |
+
|
39 |
+
def load_midas_transform(model_type):
|
40 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
41 |
+
# load transform only
|
42 |
+
if model_type == "dpt_large": # DPT-Large
|
43 |
+
net_w, net_h = 384, 384
|
44 |
+
resize_mode = "minimal"
|
45 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
46 |
+
|
47 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
48 |
+
net_w, net_h = 384, 384
|
49 |
+
resize_mode = "minimal"
|
50 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
51 |
+
|
52 |
+
elif model_type == "midas_v21":
|
53 |
+
net_w, net_h = 384, 384
|
54 |
+
resize_mode = "upper_bound"
|
55 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
56 |
+
|
57 |
+
elif model_type == "midas_v21_small":
|
58 |
+
net_w, net_h = 256, 256
|
59 |
+
resize_mode = "upper_bound"
|
60 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
61 |
+
|
62 |
+
else:
|
63 |
+
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
64 |
+
|
65 |
+
transform = Compose(
|
66 |
+
[
|
67 |
+
Resize(
|
68 |
+
net_w,
|
69 |
+
net_h,
|
70 |
+
resize_target=None,
|
71 |
+
keep_aspect_ratio=True,
|
72 |
+
ensure_multiple_of=32,
|
73 |
+
resize_method=resize_mode,
|
74 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
75 |
+
),
|
76 |
+
normalization,
|
77 |
+
PrepareForNet(),
|
78 |
+
]
|
79 |
+
)
|
80 |
+
|
81 |
+
return transform
|
82 |
+
|
83 |
+
|
84 |
+
def load_model(model_type):
|
85 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
86 |
+
# load network
|
87 |
+
model_path = ISL_PATHS[model_type]
|
88 |
+
if model_type == "dpt_large": # DPT-Large
|
89 |
+
model = DPTDepthModel(
|
90 |
+
path=model_path,
|
91 |
+
backbone="vitl16_384",
|
92 |
+
non_negative=True,
|
93 |
+
)
|
94 |
+
net_w, net_h = 384, 384
|
95 |
+
resize_mode = "minimal"
|
96 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
97 |
+
|
98 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
99 |
+
if not os.path.exists(model_path):
|
100 |
+
from basicsr.utils.download_util import load_file_from_url
|
101 |
+
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
102 |
+
|
103 |
+
model = DPTDepthModel(
|
104 |
+
path=model_path,
|
105 |
+
backbone="vitb_rn50_384",
|
106 |
+
non_negative=True,
|
107 |
+
)
|
108 |
+
net_w, net_h = 384, 384
|
109 |
+
resize_mode = "minimal"
|
110 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
111 |
+
|
112 |
+
elif model_type == "midas_v21":
|
113 |
+
model = MidasNet(model_path, non_negative=True)
|
114 |
+
net_w, net_h = 384, 384
|
115 |
+
resize_mode = "upper_bound"
|
116 |
+
normalization = NormalizeImage(
|
117 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
118 |
+
)
|
119 |
+
|
120 |
+
elif model_type == "midas_v21_small":
|
121 |
+
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
122 |
+
non_negative=True, blocks={'expand': True})
|
123 |
+
net_w, net_h = 256, 256
|
124 |
+
resize_mode = "upper_bound"
|
125 |
+
normalization = NormalizeImage(
|
126 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
127 |
+
)
|
128 |
+
|
129 |
+
else:
|
130 |
+
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
131 |
+
assert False
|
132 |
+
|
133 |
+
transform = Compose(
|
134 |
+
[
|
135 |
+
Resize(
|
136 |
+
net_w,
|
137 |
+
net_h,
|
138 |
+
resize_target=None,
|
139 |
+
keep_aspect_ratio=True,
|
140 |
+
ensure_multiple_of=32,
|
141 |
+
resize_method=resize_mode,
|
142 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
143 |
+
),
|
144 |
+
normalization,
|
145 |
+
PrepareForNet(),
|
146 |
+
]
|
147 |
+
)
|
148 |
+
|
149 |
+
return model.eval(), transform
|
150 |
+
|
151 |
+
|
152 |
+
class MiDaSInference(nn.Module):
|
153 |
+
MODEL_TYPES_TORCH_HUB = [
|
154 |
+
"DPT_Large",
|
155 |
+
"DPT_Hybrid",
|
156 |
+
"MiDaS_small"
|
157 |
+
]
|
158 |
+
MODEL_TYPES_ISL = [
|
159 |
+
"dpt_large",
|
160 |
+
"dpt_hybrid",
|
161 |
+
"midas_v21",
|
162 |
+
"midas_v21_small",
|
163 |
+
]
|
164 |
+
|
165 |
+
def __init__(self, model_type):
|
166 |
+
super().__init__()
|
167 |
+
assert (model_type in self.MODEL_TYPES_ISL)
|
168 |
+
model, _ = load_model(model_type)
|
169 |
+
self.model = model
|
170 |
+
self.model.train = disabled_train
|
171 |
+
|
172 |
+
def forward(self, x):
|
173 |
+
with torch.no_grad():
|
174 |
+
prediction = self.model(x)
|
175 |
+
return prediction
|
176 |
+
|
177 |
+
|
178 |
+
class MidasDetector:
|
179 |
+
def __init__(self,device=torch.device('cuda:0'), model_type="dpt_hybrid"):
|
180 |
+
self.device = device
|
181 |
+
self.model = MiDaSInference(model_type=model_type).to(device)
|
182 |
+
|
183 |
+
def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
|
184 |
+
assert input_image.ndim == 3
|
185 |
+
image_depth = input_image
|
186 |
+
with torch.no_grad():
|
187 |
+
image_depth = image_depth
|
188 |
+
image_depth = image_depth / 127.5 - 1.0
|
189 |
+
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
|
190 |
+
depth = self.model(image_depth)[0]
|
191 |
+
|
192 |
+
depth_pt = depth.clone()
|
193 |
+
depth_pt -= torch.min(depth_pt)
|
194 |
+
depth_pt /= torch.max(depth_pt)
|
195 |
+
depth_pt = depth_pt.cpu().numpy()
|
196 |
+
depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
|
197 |
+
|
198 |
+
depth_np = depth.cpu().numpy()
|
199 |
+
x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
|
200 |
+
y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
|
201 |
+
z = np.ones_like(x) * a
|
202 |
+
x[depth_pt < bg_th] = 0
|
203 |
+
y[depth_pt < bg_th] = 0
|
204 |
+
# normal = np.stack([x, y, z], axis=2)
|
205 |
+
# normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
|
206 |
+
# normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
|
207 |
+
|
208 |
+
return depth_image#, normal_image
|
209 |
+
|
210 |
+
if __name__ == '__main__':
|
211 |
+
import matplotlib.pyplot as plt
|
212 |
+
from tqdm import tqdm
|
213 |
+
from PIL import Image
|
214 |
+
import torchvision.transforms.functional as F
|
215 |
+
apply_depth = MidasDetector(device=torch.device('cuda:0'))
|
216 |
+
img = cv2.imread('/data/vjuicefs_sz_cv_v2/11171709/ControlAR_github/condition/example/t2i/multi_resolution/car_1_448_768.jpg')
|
217 |
+
img = cv2.resize(img,(768,448))
|
218 |
+
detected_map = apply_depth(torch.from_numpy(img).cuda().float())
|
219 |
+
print(img.shape, img.max(),img.min(),detected_map.shape, detected_map.max(),detected_map.min())
|
220 |
+
plt.imshow(detected_map, cmap='gray')
|
221 |
+
plt.show()
|
222 |
+
cv2.imwrite('condition/example_depth.jpg', detected_map)
|
223 |
+
# cv2.imwrite('condition/example_normal.jpg', normal_map)
|
condition/midas/midas/__init__.py
ADDED
File without changes
|
condition/midas/midas/base_model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class BaseModel(torch.nn.Module):
|
5 |
+
def load(self, path):
|
6 |
+
"""Load model from file.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
path (str): file path
|
10 |
+
"""
|
11 |
+
parameters = torch.load(path, map_location=torch.device('cpu'))
|
12 |
+
|
13 |
+
if "optimizer" in parameters:
|
14 |
+
parameters = parameters["model"]
|
15 |
+
|
16 |
+
self.load_state_dict(parameters)
|
condition/midas/midas/blocks.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .vit import (
|
5 |
+
_make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384,
|
7 |
+
_make_pretrained_vitb16_384,
|
8 |
+
forward_vit,
|
9 |
+
)
|
10 |
+
|
11 |
+
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
12 |
+
if backbone == "vitl16_384":
|
13 |
+
pretrained = _make_pretrained_vitl16_384(
|
14 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
15 |
+
)
|
16 |
+
scratch = _make_scratch(
|
17 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
18 |
+
) # ViT-L/16 - 85.0% Top1 (backbone)
|
19 |
+
elif backbone == "vitb_rn50_384":
|
20 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
21 |
+
use_pretrained,
|
22 |
+
hooks=hooks,
|
23 |
+
use_vit_only=use_vit_only,
|
24 |
+
use_readout=use_readout,
|
25 |
+
)
|
26 |
+
scratch = _make_scratch(
|
27 |
+
[256, 512, 768, 768], features, groups=groups, expand=expand
|
28 |
+
) # ViT-H/16 - 85.0% Top1 (backbone)
|
29 |
+
elif backbone == "vitb16_384":
|
30 |
+
pretrained = _make_pretrained_vitb16_384(
|
31 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
32 |
+
)
|
33 |
+
scratch = _make_scratch(
|
34 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
35 |
+
) # ViT-B/16 - 84.6% Top1 (backbone)
|
36 |
+
elif backbone == "resnext101_wsl":
|
37 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
38 |
+
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
39 |
+
elif backbone == "efficientnet_lite3":
|
40 |
+
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
41 |
+
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
42 |
+
else:
|
43 |
+
print(f"Backbone '{backbone}' not implemented")
|
44 |
+
assert False
|
45 |
+
|
46 |
+
return pretrained, scratch
|
47 |
+
|
48 |
+
|
49 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
50 |
+
scratch = nn.Module()
|
51 |
+
|
52 |
+
out_shape1 = out_shape
|
53 |
+
out_shape2 = out_shape
|
54 |
+
out_shape3 = out_shape
|
55 |
+
out_shape4 = out_shape
|
56 |
+
if expand==True:
|
57 |
+
out_shape1 = out_shape
|
58 |
+
out_shape2 = out_shape*2
|
59 |
+
out_shape3 = out_shape*4
|
60 |
+
out_shape4 = out_shape*8
|
61 |
+
|
62 |
+
scratch.layer1_rn = nn.Conv2d(
|
63 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
64 |
+
)
|
65 |
+
scratch.layer2_rn = nn.Conv2d(
|
66 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
67 |
+
)
|
68 |
+
scratch.layer3_rn = nn.Conv2d(
|
69 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
70 |
+
)
|
71 |
+
scratch.layer4_rn = nn.Conv2d(
|
72 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
73 |
+
)
|
74 |
+
|
75 |
+
return scratch
|
76 |
+
|
77 |
+
|
78 |
+
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
79 |
+
efficientnet = torch.hub.load(
|
80 |
+
"rwightman/gen-efficientnet-pytorch",
|
81 |
+
"tf_efficientnet_lite3",
|
82 |
+
pretrained=use_pretrained,
|
83 |
+
exportable=exportable
|
84 |
+
)
|
85 |
+
return _make_efficientnet_backbone(efficientnet)
|
86 |
+
|
87 |
+
|
88 |
+
def _make_efficientnet_backbone(effnet):
|
89 |
+
pretrained = nn.Module()
|
90 |
+
|
91 |
+
pretrained.layer1 = nn.Sequential(
|
92 |
+
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
93 |
+
)
|
94 |
+
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
95 |
+
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
96 |
+
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
97 |
+
|
98 |
+
return pretrained
|
99 |
+
|
100 |
+
|
101 |
+
def _make_resnet_backbone(resnet):
|
102 |
+
pretrained = nn.Module()
|
103 |
+
pretrained.layer1 = nn.Sequential(
|
104 |
+
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
105 |
+
)
|
106 |
+
|
107 |
+
pretrained.layer2 = resnet.layer2
|
108 |
+
pretrained.layer3 = resnet.layer3
|
109 |
+
pretrained.layer4 = resnet.layer4
|
110 |
+
|
111 |
+
return pretrained
|
112 |
+
|
113 |
+
|
114 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
115 |
+
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
116 |
+
return _make_resnet_backbone(resnet)
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
class Interpolate(nn.Module):
|
121 |
+
"""Interpolation module.
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
125 |
+
"""Init.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scale_factor (float): scaling
|
129 |
+
mode (str): interpolation mode
|
130 |
+
"""
|
131 |
+
super(Interpolate, self).__init__()
|
132 |
+
|
133 |
+
self.interp = nn.functional.interpolate
|
134 |
+
self.scale_factor = scale_factor
|
135 |
+
self.mode = mode
|
136 |
+
self.align_corners = align_corners
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
"""Forward pass.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
x (tensor): input
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
tensor: interpolated data
|
146 |
+
"""
|
147 |
+
|
148 |
+
x = self.interp(
|
149 |
+
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
150 |
+
)
|
151 |
+
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
class ResidualConvUnit(nn.Module):
|
156 |
+
"""Residual convolution module.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, features):
|
160 |
+
"""Init.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
features (int): number of features
|
164 |
+
"""
|
165 |
+
super().__init__()
|
166 |
+
|
167 |
+
self.conv1 = nn.Conv2d(
|
168 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
169 |
+
)
|
170 |
+
|
171 |
+
self.conv2 = nn.Conv2d(
|
172 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
173 |
+
)
|
174 |
+
|
175 |
+
self.relu = nn.ReLU(inplace=True)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
"""Forward pass.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
x (tensor): input
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
tensor: output
|
185 |
+
"""
|
186 |
+
out = self.relu(x)
|
187 |
+
out = self.conv1(out)
|
188 |
+
out = self.relu(out)
|
189 |
+
out = self.conv2(out)
|
190 |
+
|
191 |
+
return out + x
|
192 |
+
|
193 |
+
|
194 |
+
class FeatureFusionBlock(nn.Module):
|
195 |
+
"""Feature fusion block.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self, features):
|
199 |
+
"""Init.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
features (int): number of features
|
203 |
+
"""
|
204 |
+
super(FeatureFusionBlock, self).__init__()
|
205 |
+
|
206 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
207 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
208 |
+
|
209 |
+
def forward(self, *xs):
|
210 |
+
"""Forward pass.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
tensor: output
|
214 |
+
"""
|
215 |
+
output = xs[0]
|
216 |
+
|
217 |
+
if len(xs) == 2:
|
218 |
+
output += self.resConfUnit1(xs[1])
|
219 |
+
|
220 |
+
output = self.resConfUnit2(output)
|
221 |
+
|
222 |
+
output = nn.functional.interpolate(
|
223 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
224 |
+
)
|
225 |
+
|
226 |
+
return output
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
class ResidualConvUnit_custom(nn.Module):
|
232 |
+
"""Residual convolution module.
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(self, features, activation, bn):
|
236 |
+
"""Init.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
features (int): number of features
|
240 |
+
"""
|
241 |
+
super().__init__()
|
242 |
+
|
243 |
+
self.bn = bn
|
244 |
+
|
245 |
+
self.groups=1
|
246 |
+
|
247 |
+
self.conv1 = nn.Conv2d(
|
248 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
249 |
+
)
|
250 |
+
|
251 |
+
self.conv2 = nn.Conv2d(
|
252 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
253 |
+
)
|
254 |
+
|
255 |
+
if self.bn==True:
|
256 |
+
self.bn1 = nn.BatchNorm2d(features)
|
257 |
+
self.bn2 = nn.BatchNorm2d(features)
|
258 |
+
|
259 |
+
self.activation = activation
|
260 |
+
|
261 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
"""Forward pass.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
x (tensor): input
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
tensor: output
|
271 |
+
"""
|
272 |
+
|
273 |
+
out = self.activation(x)
|
274 |
+
out = self.conv1(out)
|
275 |
+
if self.bn==True:
|
276 |
+
out = self.bn1(out)
|
277 |
+
|
278 |
+
out = self.activation(out)
|
279 |
+
out = self.conv2(out)
|
280 |
+
if self.bn==True:
|
281 |
+
out = self.bn2(out)
|
282 |
+
|
283 |
+
if self.groups > 1:
|
284 |
+
out = self.conv_merge(out)
|
285 |
+
|
286 |
+
return self.skip_add.add(out, x)
|
287 |
+
|
288 |
+
# return out + x
|
289 |
+
|
290 |
+
|
291 |
+
class FeatureFusionBlock_custom(nn.Module):
|
292 |
+
"""Feature fusion block.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
296 |
+
"""Init.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
features (int): number of features
|
300 |
+
"""
|
301 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
302 |
+
|
303 |
+
self.deconv = deconv
|
304 |
+
self.align_corners = align_corners
|
305 |
+
|
306 |
+
self.groups=1
|
307 |
+
|
308 |
+
self.expand = expand
|
309 |
+
out_features = features
|
310 |
+
if self.expand==True:
|
311 |
+
out_features = features//2
|
312 |
+
|
313 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
314 |
+
|
315 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
316 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
317 |
+
|
318 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
319 |
+
|
320 |
+
def forward(self, *xs):
|
321 |
+
"""Forward pass.
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
tensor: output
|
325 |
+
"""
|
326 |
+
output = xs[0]
|
327 |
+
|
328 |
+
if len(xs) == 2:
|
329 |
+
res = self.resConfUnit1(xs[1])
|
330 |
+
output = self.skip_add.add(output, res)
|
331 |
+
# output += res
|
332 |
+
|
333 |
+
output = self.resConfUnit2(output)
|
334 |
+
|
335 |
+
output = nn.functional.interpolate(
|
336 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
337 |
+
)
|
338 |
+
|
339 |
+
output = self.out_conv(output)
|
340 |
+
|
341 |
+
return output
|
condition/midas/midas/dpt_depth.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .base_model import BaseModel
|
6 |
+
from .blocks import (
|
7 |
+
FeatureFusionBlock,
|
8 |
+
FeatureFusionBlock_custom,
|
9 |
+
Interpolate,
|
10 |
+
_make_encoder,
|
11 |
+
forward_vit,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def _make_fusion_block(features, use_bn):
|
16 |
+
return FeatureFusionBlock_custom(
|
17 |
+
features,
|
18 |
+
nn.ReLU(False),
|
19 |
+
deconv=False,
|
20 |
+
bn=use_bn,
|
21 |
+
expand=False,
|
22 |
+
align_corners=True,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class DPT(BaseModel):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
head,
|
30 |
+
features=256,
|
31 |
+
backbone="vitb_rn50_384",
|
32 |
+
readout="project",
|
33 |
+
channels_last=False,
|
34 |
+
use_bn=False,
|
35 |
+
):
|
36 |
+
|
37 |
+
super(DPT, self).__init__()
|
38 |
+
|
39 |
+
self.channels_last = channels_last
|
40 |
+
|
41 |
+
hooks = {
|
42 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
43 |
+
"vitb16_384": [2, 5, 8, 11],
|
44 |
+
"vitl16_384": [5, 11, 17, 23],
|
45 |
+
}
|
46 |
+
|
47 |
+
# Instantiate backbone and reassemble blocks
|
48 |
+
self.pretrained, self.scratch = _make_encoder(
|
49 |
+
backbone,
|
50 |
+
features,
|
51 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
52 |
+
groups=1,
|
53 |
+
expand=False,
|
54 |
+
exportable=False,
|
55 |
+
hooks=hooks[backbone],
|
56 |
+
use_readout=readout,
|
57 |
+
)
|
58 |
+
|
59 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
60 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
61 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
62 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
63 |
+
|
64 |
+
self.scratch.output_conv = head
|
65 |
+
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
if self.channels_last == True:
|
69 |
+
x.contiguous(memory_format=torch.channels_last)
|
70 |
+
|
71 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
72 |
+
|
73 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
74 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
75 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
76 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
77 |
+
|
78 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
79 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
80 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
81 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
82 |
+
|
83 |
+
out = self.scratch.output_conv(path_1)
|
84 |
+
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
class DPTDepthModel(DPT):
|
89 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
90 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
91 |
+
|
92 |
+
head = nn.Sequential(
|
93 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
94 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
95 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
96 |
+
nn.ReLU(True),
|
97 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
98 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
99 |
+
nn.Identity(),
|
100 |
+
)
|
101 |
+
|
102 |
+
super().__init__(head, **kwargs)
|
103 |
+
|
104 |
+
if path is not None:
|
105 |
+
self.load(path)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return super().forward(x).squeeze(dim=1)
|
condition/midas/midas/midas_net.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
17 |
+
"""Init.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str, optional): Path to saved model. Defaults to None.
|
21 |
+
features (int, optional): Number of features. Defaults to 256.
|
22 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
23 |
+
"""
|
24 |
+
print("Loading weights: ", path)
|
25 |
+
|
26 |
+
super(MidasNet, self).__init__()
|
27 |
+
|
28 |
+
use_pretrained = False if path is None else True
|
29 |
+
|
30 |
+
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
31 |
+
|
32 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
33 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
34 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
35 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
36 |
+
|
37 |
+
self.scratch.output_conv = nn.Sequential(
|
38 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
39 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
40 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
41 |
+
nn.ReLU(True),
|
42 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
43 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
44 |
+
)
|
45 |
+
|
46 |
+
if path:
|
47 |
+
self.load(path)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
"""Forward pass.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
x (tensor): input data (image)
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tensor: depth
|
57 |
+
"""
|
58 |
+
|
59 |
+
layer_1 = self.pretrained.layer1(x)
|
60 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
61 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
62 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
63 |
+
|
64 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
65 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
66 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
67 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
68 |
+
|
69 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
70 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
71 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
72 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
73 |
+
|
74 |
+
out = self.scratch.output_conv(path_1)
|
75 |
+
|
76 |
+
return torch.squeeze(out, dim=1)
|
condition/midas/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet_small(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
17 |
+
blocks={'expand': True}):
|
18 |
+
"""Init.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
path (str, optional): Path to saved model. Defaults to None.
|
22 |
+
features (int, optional): Number of features. Defaults to 256.
|
23 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
24 |
+
"""
|
25 |
+
print("Loading weights: ", path)
|
26 |
+
|
27 |
+
super(MidasNet_small, self).__init__()
|
28 |
+
|
29 |
+
use_pretrained = False if path else True
|
30 |
+
|
31 |
+
self.channels_last = channels_last
|
32 |
+
self.blocks = blocks
|
33 |
+
self.backbone = backbone
|
34 |
+
|
35 |
+
self.groups = 1
|
36 |
+
|
37 |
+
features1=features
|
38 |
+
features2=features
|
39 |
+
features3=features
|
40 |
+
features4=features
|
41 |
+
self.expand = False
|
42 |
+
if "expand" in self.blocks and self.blocks['expand'] == True:
|
43 |
+
self.expand = True
|
44 |
+
features1=features
|
45 |
+
features2=features*2
|
46 |
+
features3=features*4
|
47 |
+
features4=features*8
|
48 |
+
|
49 |
+
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
50 |
+
|
51 |
+
self.scratch.activation = nn.ReLU(False)
|
52 |
+
|
53 |
+
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
54 |
+
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
55 |
+
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
56 |
+
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
57 |
+
|
58 |
+
|
59 |
+
self.scratch.output_conv = nn.Sequential(
|
60 |
+
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
61 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
62 |
+
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
63 |
+
self.scratch.activation,
|
64 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
65 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
66 |
+
nn.Identity(),
|
67 |
+
)
|
68 |
+
|
69 |
+
if path:
|
70 |
+
self.load(path)
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
"""Forward pass.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
x (tensor): input data (image)
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
tensor: depth
|
81 |
+
"""
|
82 |
+
if self.channels_last==True:
|
83 |
+
print("self.channels_last = ", self.channels_last)
|
84 |
+
x.contiguous(memory_format=torch.channels_last)
|
85 |
+
|
86 |
+
|
87 |
+
layer_1 = self.pretrained.layer1(x)
|
88 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
89 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
90 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
91 |
+
|
92 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
93 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
94 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
95 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
96 |
+
|
97 |
+
|
98 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
99 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
100 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
101 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
102 |
+
|
103 |
+
out = self.scratch.output_conv(path_1)
|
104 |
+
|
105 |
+
return torch.squeeze(out, dim=1)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def fuse_model(m):
|
110 |
+
prev_previous_type = nn.Identity()
|
111 |
+
prev_previous_name = ''
|
112 |
+
previous_type = nn.Identity()
|
113 |
+
previous_name = ''
|
114 |
+
for name, module in m.named_modules():
|
115 |
+
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
116 |
+
# print("FUSED ", prev_previous_name, previous_name, name)
|
117 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
118 |
+
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
119 |
+
# print("FUSED ", prev_previous_name, previous_name)
|
120 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
121 |
+
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
122 |
+
# print("FUSED ", previous_name, name)
|
123 |
+
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
124 |
+
|
125 |
+
prev_previous_type = previous_type
|
126 |
+
prev_previous_name = previous_name
|
127 |
+
previous_type = type(module)
|
128 |
+
previous_name = name
|
condition/midas/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
7 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
sample (dict): sample
|
11 |
+
size (tuple): image size
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
tuple: new size
|
15 |
+
"""
|
16 |
+
shape = list(sample["disparity"].shape)
|
17 |
+
|
18 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
19 |
+
return sample
|
20 |
+
|
21 |
+
scale = [0, 0]
|
22 |
+
scale[0] = size[0] / shape[0]
|
23 |
+
scale[1] = size[1] / shape[1]
|
24 |
+
|
25 |
+
scale = max(scale)
|
26 |
+
|
27 |
+
shape[0] = math.ceil(scale * shape[0])
|
28 |
+
shape[1] = math.ceil(scale * shape[1])
|
29 |
+
|
30 |
+
# resize
|
31 |
+
sample["image"] = cv2.resize(
|
32 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
33 |
+
)
|
34 |
+
|
35 |
+
sample["disparity"] = cv2.resize(
|
36 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
37 |
+
)
|
38 |
+
sample["mask"] = cv2.resize(
|
39 |
+
sample["mask"].astype(np.float32),
|
40 |
+
tuple(shape[::-1]),
|
41 |
+
interpolation=cv2.INTER_NEAREST,
|
42 |
+
)
|
43 |
+
sample["mask"] = sample["mask"].astype(bool)
|
44 |
+
|
45 |
+
return tuple(shape)
|
46 |
+
|
47 |
+
|
48 |
+
class Resize(object):
|
49 |
+
"""Resize sample to given size (width, height).
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
width,
|
55 |
+
height,
|
56 |
+
resize_target=True,
|
57 |
+
keep_aspect_ratio=False,
|
58 |
+
ensure_multiple_of=1,
|
59 |
+
resize_method="lower_bound",
|
60 |
+
image_interpolation_method=cv2.INTER_AREA,
|
61 |
+
):
|
62 |
+
"""Init.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
width (int): desired output width
|
66 |
+
height (int): desired output height
|
67 |
+
resize_target (bool, optional):
|
68 |
+
True: Resize the full sample (image, mask, target).
|
69 |
+
False: Resize image only.
|
70 |
+
Defaults to True.
|
71 |
+
keep_aspect_ratio (bool, optional):
|
72 |
+
True: Keep the aspect ratio of the input sample.
|
73 |
+
Output sample might not have the given width and height, and
|
74 |
+
resize behaviour depends on the parameter 'resize_method'.
|
75 |
+
Defaults to False.
|
76 |
+
ensure_multiple_of (int, optional):
|
77 |
+
Output width and height is constrained to be multiple of this parameter.
|
78 |
+
Defaults to 1.
|
79 |
+
resize_method (str, optional):
|
80 |
+
"lower_bound": Output will be at least as large as the given size.
|
81 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
82 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
83 |
+
Defaults to "lower_bound".
|
84 |
+
"""
|
85 |
+
self.__width = width
|
86 |
+
self.__height = height
|
87 |
+
|
88 |
+
self.__resize_target = resize_target
|
89 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
90 |
+
self.__multiple_of = ensure_multiple_of
|
91 |
+
self.__resize_method = resize_method
|
92 |
+
self.__image_interpolation_method = image_interpolation_method
|
93 |
+
|
94 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
95 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
96 |
+
|
97 |
+
if max_val is not None and y > max_val:
|
98 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
99 |
+
|
100 |
+
if y < min_val:
|
101 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
102 |
+
|
103 |
+
return y
|
104 |
+
|
105 |
+
def get_size(self, width, height):
|
106 |
+
# determine new height and width
|
107 |
+
scale_height = self.__height / height
|
108 |
+
scale_width = self.__width / width
|
109 |
+
|
110 |
+
if self.__keep_aspect_ratio:
|
111 |
+
if self.__resize_method == "lower_bound":
|
112 |
+
# scale such that output size is lower bound
|
113 |
+
if scale_width > scale_height:
|
114 |
+
# fit width
|
115 |
+
scale_height = scale_width
|
116 |
+
else:
|
117 |
+
# fit height
|
118 |
+
scale_width = scale_height
|
119 |
+
elif self.__resize_method == "upper_bound":
|
120 |
+
# scale such that output size is upper bound
|
121 |
+
if scale_width < scale_height:
|
122 |
+
# fit width
|
123 |
+
scale_height = scale_width
|
124 |
+
else:
|
125 |
+
# fit height
|
126 |
+
scale_width = scale_height
|
127 |
+
elif self.__resize_method == "minimal":
|
128 |
+
# scale as least as possbile
|
129 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
130 |
+
# fit width
|
131 |
+
scale_height = scale_width
|
132 |
+
else:
|
133 |
+
# fit height
|
134 |
+
scale_width = scale_height
|
135 |
+
else:
|
136 |
+
raise ValueError(
|
137 |
+
f"resize_method {self.__resize_method} not implemented"
|
138 |
+
)
|
139 |
+
|
140 |
+
if self.__resize_method == "lower_bound":
|
141 |
+
new_height = self.constrain_to_multiple_of(
|
142 |
+
scale_height * height, min_val=self.__height
|
143 |
+
)
|
144 |
+
new_width = self.constrain_to_multiple_of(
|
145 |
+
scale_width * width, min_val=self.__width
|
146 |
+
)
|
147 |
+
elif self.__resize_method == "upper_bound":
|
148 |
+
new_height = self.constrain_to_multiple_of(
|
149 |
+
scale_height * height, max_val=self.__height
|
150 |
+
)
|
151 |
+
new_width = self.constrain_to_multiple_of(
|
152 |
+
scale_width * width, max_val=self.__width
|
153 |
+
)
|
154 |
+
elif self.__resize_method == "minimal":
|
155 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
156 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
157 |
+
else:
|
158 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
159 |
+
|
160 |
+
return (new_width, new_height)
|
161 |
+
|
162 |
+
def __call__(self, sample):
|
163 |
+
width, height = self.get_size(
|
164 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
165 |
+
)
|
166 |
+
|
167 |
+
# resize sample
|
168 |
+
sample["image"] = cv2.resize(
|
169 |
+
sample["image"],
|
170 |
+
(width, height),
|
171 |
+
interpolation=self.__image_interpolation_method,
|
172 |
+
)
|
173 |
+
|
174 |
+
if self.__resize_target:
|
175 |
+
if "disparity" in sample:
|
176 |
+
sample["disparity"] = cv2.resize(
|
177 |
+
sample["disparity"],
|
178 |
+
(width, height),
|
179 |
+
interpolation=cv2.INTER_NEAREST,
|
180 |
+
)
|
181 |
+
|
182 |
+
if "depth" in sample:
|
183 |
+
sample["depth"] = cv2.resize(
|
184 |
+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
185 |
+
)
|
186 |
+
|
187 |
+
sample["mask"] = cv2.resize(
|
188 |
+
sample["mask"].astype(np.float32),
|
189 |
+
(width, height),
|
190 |
+
interpolation=cv2.INTER_NEAREST,
|
191 |
+
)
|
192 |
+
sample["mask"] = sample["mask"].astype(bool)
|
193 |
+
|
194 |
+
return sample
|
195 |
+
|
196 |
+
|
197 |
+
class NormalizeImage(object):
|
198 |
+
"""Normlize image by given mean and std.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(self, mean, std):
|
202 |
+
self.__mean = mean
|
203 |
+
self.__std = std
|
204 |
+
|
205 |
+
def __call__(self, sample):
|
206 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
207 |
+
|
208 |
+
return sample
|
209 |
+
|
210 |
+
|
211 |
+
class PrepareForNet(object):
|
212 |
+
"""Prepare sample for usage as network input.
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self):
|
216 |
+
pass
|
217 |
+
|
218 |
+
def __call__(self, sample):
|
219 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
220 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
221 |
+
|
222 |
+
if "mask" in sample:
|
223 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
224 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
225 |
+
|
226 |
+
if "disparity" in sample:
|
227 |
+
disparity = sample["disparity"].astype(np.float32)
|
228 |
+
sample["disparity"] = np.ascontiguousarray(disparity)
|
229 |
+
|
230 |
+
if "depth" in sample:
|
231 |
+
depth = sample["depth"].astype(np.float32)
|
232 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
233 |
+
|
234 |
+
return sample
|
condition/midas/midas/vit.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import timm
|
4 |
+
import types
|
5 |
+
import math
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class Slice(nn.Module):
|
10 |
+
def __init__(self, start_index=1):
|
11 |
+
super(Slice, self).__init__()
|
12 |
+
self.start_index = start_index
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
return x[:, self.start_index :]
|
16 |
+
|
17 |
+
|
18 |
+
class AddReadout(nn.Module):
|
19 |
+
def __init__(self, start_index=1):
|
20 |
+
super(AddReadout, self).__init__()
|
21 |
+
self.start_index = start_index
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
if self.start_index == 2:
|
25 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
26 |
+
else:
|
27 |
+
readout = x[:, 0]
|
28 |
+
return x[:, self.start_index :] + readout.unsqueeze(1)
|
29 |
+
|
30 |
+
|
31 |
+
class ProjectReadout(nn.Module):
|
32 |
+
def __init__(self, in_features, start_index=1):
|
33 |
+
super(ProjectReadout, self).__init__()
|
34 |
+
self.start_index = start_index
|
35 |
+
|
36 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
40 |
+
features = torch.cat((x[:, self.start_index :], readout), -1)
|
41 |
+
|
42 |
+
return self.project(features)
|
43 |
+
|
44 |
+
|
45 |
+
class Transpose(nn.Module):
|
46 |
+
def __init__(self, dim0, dim1):
|
47 |
+
super(Transpose, self).__init__()
|
48 |
+
self.dim0 = dim0
|
49 |
+
self.dim1 = dim1
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = x.transpose(self.dim0, self.dim1)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
def forward_vit(pretrained, x):
|
57 |
+
b, c, h, w = x.shape
|
58 |
+
|
59 |
+
glob = pretrained.model.forward_flex(x)
|
60 |
+
|
61 |
+
layer_1 = pretrained.activations["1"]
|
62 |
+
layer_2 = pretrained.activations["2"]
|
63 |
+
layer_3 = pretrained.activations["3"]
|
64 |
+
layer_4 = pretrained.activations["4"]
|
65 |
+
|
66 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
67 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
68 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
69 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
70 |
+
|
71 |
+
unflatten = nn.Sequential(
|
72 |
+
nn.Unflatten(
|
73 |
+
2,
|
74 |
+
torch.Size(
|
75 |
+
[
|
76 |
+
h // pretrained.model.patch_size[1],
|
77 |
+
w // pretrained.model.patch_size[0],
|
78 |
+
]
|
79 |
+
),
|
80 |
+
)
|
81 |
+
)
|
82 |
+
|
83 |
+
if layer_1.ndim == 3:
|
84 |
+
layer_1 = unflatten(layer_1)
|
85 |
+
if layer_2.ndim == 3:
|
86 |
+
layer_2 = unflatten(layer_2)
|
87 |
+
if layer_3.ndim == 3:
|
88 |
+
layer_3 = unflatten(layer_3)
|
89 |
+
if layer_4.ndim == 3:
|
90 |
+
layer_4 = unflatten(layer_4)
|
91 |
+
|
92 |
+
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
93 |
+
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
94 |
+
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
95 |
+
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
96 |
+
|
97 |
+
return layer_1, layer_2, layer_3, layer_4
|
98 |
+
|
99 |
+
|
100 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
101 |
+
posemb_tok, posemb_grid = (
|
102 |
+
posemb[:, : self.start_index],
|
103 |
+
posemb[0, self.start_index :],
|
104 |
+
)
|
105 |
+
|
106 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
107 |
+
|
108 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
109 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
110 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
111 |
+
|
112 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
113 |
+
|
114 |
+
return posemb
|
115 |
+
|
116 |
+
|
117 |
+
def forward_flex(self, x):
|
118 |
+
b, c, h, w = x.shape
|
119 |
+
|
120 |
+
pos_embed = self._resize_pos_embed(
|
121 |
+
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
122 |
+
)
|
123 |
+
|
124 |
+
B = x.shape[0]
|
125 |
+
|
126 |
+
if hasattr(self.patch_embed, "backbone"):
|
127 |
+
x = self.patch_embed.backbone(x)
|
128 |
+
if isinstance(x, (list, tuple)):
|
129 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
130 |
+
|
131 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
132 |
+
|
133 |
+
if getattr(self, "dist_token", None) is not None:
|
134 |
+
cls_tokens = self.cls_token.expand(
|
135 |
+
B, -1, -1
|
136 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
137 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
138 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
139 |
+
else:
|
140 |
+
cls_tokens = self.cls_token.expand(
|
141 |
+
B, -1, -1
|
142 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
143 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
144 |
+
|
145 |
+
x = x + pos_embed
|
146 |
+
x = self.pos_drop(x)
|
147 |
+
|
148 |
+
for blk in self.blocks:
|
149 |
+
x = blk(x)
|
150 |
+
|
151 |
+
x = self.norm(x)
|
152 |
+
|
153 |
+
return x
|
154 |
+
|
155 |
+
|
156 |
+
activations = {}
|
157 |
+
|
158 |
+
|
159 |
+
def get_activation(name):
|
160 |
+
def hook(model, input, output):
|
161 |
+
activations[name] = output
|
162 |
+
|
163 |
+
return hook
|
164 |
+
|
165 |
+
|
166 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
167 |
+
if use_readout == "ignore":
|
168 |
+
readout_oper = [Slice(start_index)] * len(features)
|
169 |
+
elif use_readout == "add":
|
170 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
171 |
+
elif use_readout == "project":
|
172 |
+
readout_oper = [
|
173 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
174 |
+
]
|
175 |
+
else:
|
176 |
+
assert (
|
177 |
+
False
|
178 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
179 |
+
|
180 |
+
return readout_oper
|
181 |
+
|
182 |
+
|
183 |
+
def _make_vit_b16_backbone(
|
184 |
+
model,
|
185 |
+
features=[96, 192, 384, 768],
|
186 |
+
size=[384, 384],
|
187 |
+
hooks=[2, 5, 8, 11],
|
188 |
+
vit_features=768,
|
189 |
+
use_readout="ignore",
|
190 |
+
start_index=1,
|
191 |
+
):
|
192 |
+
pretrained = nn.Module()
|
193 |
+
|
194 |
+
pretrained.model = model
|
195 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
196 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
197 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
198 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
199 |
+
|
200 |
+
pretrained.activations = activations
|
201 |
+
|
202 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
203 |
+
|
204 |
+
# 32, 48, 136, 384
|
205 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
206 |
+
readout_oper[0],
|
207 |
+
Transpose(1, 2),
|
208 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
209 |
+
nn.Conv2d(
|
210 |
+
in_channels=vit_features,
|
211 |
+
out_channels=features[0],
|
212 |
+
kernel_size=1,
|
213 |
+
stride=1,
|
214 |
+
padding=0,
|
215 |
+
),
|
216 |
+
nn.ConvTranspose2d(
|
217 |
+
in_channels=features[0],
|
218 |
+
out_channels=features[0],
|
219 |
+
kernel_size=4,
|
220 |
+
stride=4,
|
221 |
+
padding=0,
|
222 |
+
bias=True,
|
223 |
+
dilation=1,
|
224 |
+
groups=1,
|
225 |
+
),
|
226 |
+
)
|
227 |
+
|
228 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
229 |
+
readout_oper[1],
|
230 |
+
Transpose(1, 2),
|
231 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
232 |
+
nn.Conv2d(
|
233 |
+
in_channels=vit_features,
|
234 |
+
out_channels=features[1],
|
235 |
+
kernel_size=1,
|
236 |
+
stride=1,
|
237 |
+
padding=0,
|
238 |
+
),
|
239 |
+
nn.ConvTranspose2d(
|
240 |
+
in_channels=features[1],
|
241 |
+
out_channels=features[1],
|
242 |
+
kernel_size=2,
|
243 |
+
stride=2,
|
244 |
+
padding=0,
|
245 |
+
bias=True,
|
246 |
+
dilation=1,
|
247 |
+
groups=1,
|
248 |
+
),
|
249 |
+
)
|
250 |
+
|
251 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
252 |
+
readout_oper[2],
|
253 |
+
Transpose(1, 2),
|
254 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
255 |
+
nn.Conv2d(
|
256 |
+
in_channels=vit_features,
|
257 |
+
out_channels=features[2],
|
258 |
+
kernel_size=1,
|
259 |
+
stride=1,
|
260 |
+
padding=0,
|
261 |
+
),
|
262 |
+
)
|
263 |
+
|
264 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
265 |
+
readout_oper[3],
|
266 |
+
Transpose(1, 2),
|
267 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
268 |
+
nn.Conv2d(
|
269 |
+
in_channels=vit_features,
|
270 |
+
out_channels=features[3],
|
271 |
+
kernel_size=1,
|
272 |
+
stride=1,
|
273 |
+
padding=0,
|
274 |
+
),
|
275 |
+
nn.Conv2d(
|
276 |
+
in_channels=features[3],
|
277 |
+
out_channels=features[3],
|
278 |
+
kernel_size=3,
|
279 |
+
stride=2,
|
280 |
+
padding=1,
|
281 |
+
),
|
282 |
+
)
|
283 |
+
|
284 |
+
pretrained.model.start_index = start_index
|
285 |
+
pretrained.model.patch_size = [16, 16]
|
286 |
+
|
287 |
+
# We inject this function into the VisionTransformer instances so that
|
288 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
289 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
290 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
291 |
+
_resize_pos_embed, pretrained.model
|
292 |
+
)
|
293 |
+
|
294 |
+
return pretrained
|
295 |
+
|
296 |
+
|
297 |
+
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
298 |
+
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
299 |
+
|
300 |
+
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
301 |
+
return _make_vit_b16_backbone(
|
302 |
+
model,
|
303 |
+
features=[256, 512, 1024, 1024],
|
304 |
+
hooks=hooks,
|
305 |
+
vit_features=1024,
|
306 |
+
use_readout=use_readout,
|
307 |
+
)
|
308 |
+
|
309 |
+
|
310 |
+
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
311 |
+
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
312 |
+
|
313 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
314 |
+
return _make_vit_b16_backbone(
|
315 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
316 |
+
)
|
317 |
+
|
318 |
+
|
319 |
+
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
320 |
+
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
321 |
+
|
322 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
323 |
+
return _make_vit_b16_backbone(
|
324 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
325 |
+
)
|
326 |
+
|
327 |
+
|
328 |
+
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
329 |
+
model = timm.create_model(
|
330 |
+
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
331 |
+
)
|
332 |
+
|
333 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
334 |
+
return _make_vit_b16_backbone(
|
335 |
+
model,
|
336 |
+
features=[96, 192, 384, 768],
|
337 |
+
hooks=hooks,
|
338 |
+
use_readout=use_readout,
|
339 |
+
start_index=2,
|
340 |
+
)
|
341 |
+
|
342 |
+
|
343 |
+
def _make_vit_b_rn50_backbone(
|
344 |
+
model,
|
345 |
+
features=[256, 512, 768, 768],
|
346 |
+
size=[384, 384],
|
347 |
+
hooks=[0, 1, 8, 11],
|
348 |
+
vit_features=768,
|
349 |
+
use_vit_only=False,
|
350 |
+
use_readout="ignore",
|
351 |
+
start_index=1,
|
352 |
+
):
|
353 |
+
pretrained = nn.Module()
|
354 |
+
|
355 |
+
pretrained.model = model
|
356 |
+
|
357 |
+
if use_vit_only == True:
|
358 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
359 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
360 |
+
else:
|
361 |
+
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
362 |
+
get_activation("1")
|
363 |
+
)
|
364 |
+
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
365 |
+
get_activation("2")
|
366 |
+
)
|
367 |
+
|
368 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
369 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
370 |
+
|
371 |
+
pretrained.activations = activations
|
372 |
+
|
373 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
374 |
+
|
375 |
+
if use_vit_only == True:
|
376 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
377 |
+
readout_oper[0],
|
378 |
+
Transpose(1, 2),
|
379 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
380 |
+
nn.Conv2d(
|
381 |
+
in_channels=vit_features,
|
382 |
+
out_channels=features[0],
|
383 |
+
kernel_size=1,
|
384 |
+
stride=1,
|
385 |
+
padding=0,
|
386 |
+
),
|
387 |
+
nn.ConvTranspose2d(
|
388 |
+
in_channels=features[0],
|
389 |
+
out_channels=features[0],
|
390 |
+
kernel_size=4,
|
391 |
+
stride=4,
|
392 |
+
padding=0,
|
393 |
+
bias=True,
|
394 |
+
dilation=1,
|
395 |
+
groups=1,
|
396 |
+
),
|
397 |
+
)
|
398 |
+
|
399 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
400 |
+
readout_oper[1],
|
401 |
+
Transpose(1, 2),
|
402 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
403 |
+
nn.Conv2d(
|
404 |
+
in_channels=vit_features,
|
405 |
+
out_channels=features[1],
|
406 |
+
kernel_size=1,
|
407 |
+
stride=1,
|
408 |
+
padding=0,
|
409 |
+
),
|
410 |
+
nn.ConvTranspose2d(
|
411 |
+
in_channels=features[1],
|
412 |
+
out_channels=features[1],
|
413 |
+
kernel_size=2,
|
414 |
+
stride=2,
|
415 |
+
padding=0,
|
416 |
+
bias=True,
|
417 |
+
dilation=1,
|
418 |
+
groups=1,
|
419 |
+
),
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
423 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
424 |
+
)
|
425 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
426 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
427 |
+
)
|
428 |
+
|
429 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
430 |
+
readout_oper[2],
|
431 |
+
Transpose(1, 2),
|
432 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
433 |
+
nn.Conv2d(
|
434 |
+
in_channels=vit_features,
|
435 |
+
out_channels=features[2],
|
436 |
+
kernel_size=1,
|
437 |
+
stride=1,
|
438 |
+
padding=0,
|
439 |
+
),
|
440 |
+
)
|
441 |
+
|
442 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
443 |
+
readout_oper[3],
|
444 |
+
Transpose(1, 2),
|
445 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
446 |
+
nn.Conv2d(
|
447 |
+
in_channels=vit_features,
|
448 |
+
out_channels=features[3],
|
449 |
+
kernel_size=1,
|
450 |
+
stride=1,
|
451 |
+
padding=0,
|
452 |
+
),
|
453 |
+
nn.Conv2d(
|
454 |
+
in_channels=features[3],
|
455 |
+
out_channels=features[3],
|
456 |
+
kernel_size=3,
|
457 |
+
stride=2,
|
458 |
+
padding=1,
|
459 |
+
),
|
460 |
+
)
|
461 |
+
|
462 |
+
pretrained.model.start_index = start_index
|
463 |
+
pretrained.model.patch_size = [16, 16]
|
464 |
+
|
465 |
+
# We inject this function into the VisionTransformer instances so that
|
466 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
467 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
468 |
+
|
469 |
+
# We inject this function into the VisionTransformer instances so that
|
470 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
471 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
472 |
+
_resize_pos_embed, pretrained.model
|
473 |
+
)
|
474 |
+
|
475 |
+
return pretrained
|
476 |
+
|
477 |
+
|
478 |
+
def _make_pretrained_vitb_rn50_384(
|
479 |
+
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
|
480 |
+
):
|
481 |
+
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
482 |
+
|
483 |
+
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
484 |
+
return _make_vit_b_rn50_backbone(
|
485 |
+
model,
|
486 |
+
features=[256, 512, 768, 768],
|
487 |
+
size=[384, 384],
|
488 |
+
hooks=hooks,
|
489 |
+
use_vit_only=use_vit_only,
|
490 |
+
use_readout=use_readout,
|
491 |
+
)
|
condition/utils.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
|
7 |
+
|
8 |
+
|
9 |
+
def HWC3(x):
|
10 |
+
assert x.dtype == np.uint8
|
11 |
+
if x.ndim == 2:
|
12 |
+
x = x[:, :, None]
|
13 |
+
assert x.ndim == 3
|
14 |
+
H, W, C = x.shape
|
15 |
+
assert C == 1 or C == 3 or C == 4
|
16 |
+
if C == 3:
|
17 |
+
return x
|
18 |
+
if C == 1:
|
19 |
+
return np.concatenate([x, x, x], axis=2)
|
20 |
+
if C == 4:
|
21 |
+
color = x[:, :, 0:3].astype(np.float32)
|
22 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
23 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
24 |
+
y = y.clip(0, 255).astype(np.uint8)
|
25 |
+
return y
|
26 |
+
|
27 |
+
|
28 |
+
def resize_image(input_image, resolution):
|
29 |
+
H, W, C = input_image.shape
|
30 |
+
H = float(H)
|
31 |
+
W = float(W)
|
32 |
+
k = float(resolution) / min(H, W)
|
33 |
+
H *= k
|
34 |
+
W *= k
|
35 |
+
H = int(np.round(H / 64.0)) * 64
|
36 |
+
W = int(np.round(W / 64.0)) * 64
|
37 |
+
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
38 |
+
return img
|
language/README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Language models for text-conditional image generation
|
2 |
+
|
3 |
+
### Requirements
|
4 |
+
```
|
5 |
+
pip install ftfy
|
6 |
+
pip install transformers
|
7 |
+
pip install accelerate
|
8 |
+
pip install sentencepiece
|
9 |
+
pip install pandas
|
10 |
+
pip install bs4
|
11 |
+
```
|
12 |
+
|
13 |
+
### Language Models
|
14 |
+
Download flan-t5-xl models from [flan-t5-xl](https://huggingface.co/google/flan-t5-xl) and put into the folder of `./pretrained_models/t5-ckpt/`
|
language/extract_t5_feature.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
3 |
+
torch.backends.cudnn.allow_tf32 = True
|
4 |
+
import torch.distributed as dist
|
5 |
+
from torch.utils.data import Dataset, DataLoader
|
6 |
+
from torch.utils.data.distributed import DistributedSampler
|
7 |
+
import numpy as np
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import json
|
11 |
+
|
12 |
+
from utils.distributed import init_distributed_mode
|
13 |
+
from language.t5 import T5Embedder
|
14 |
+
|
15 |
+
CAPTION_KEY = {
|
16 |
+
'blip': 0,
|
17 |
+
'llava': 1,
|
18 |
+
'llava_first': 2,
|
19 |
+
}
|
20 |
+
#################################################################################
|
21 |
+
# Training Helper Functions #
|
22 |
+
#################################################################################
|
23 |
+
class CustomDataset(Dataset):
|
24 |
+
def __init__(self, lst_dir, start, end, caption_key, trunc_caption=False):
|
25 |
+
img_path_list = []
|
26 |
+
for lst_name in sorted(os.listdir(lst_dir))[start: end+1]:
|
27 |
+
if not lst_name.endswith('.jsonl'):
|
28 |
+
continue
|
29 |
+
file_path = os.path.join(lst_dir, lst_name)
|
30 |
+
with open(file_path, 'r') as file:
|
31 |
+
for line_idx, line in enumerate(file):
|
32 |
+
data = json.loads(line)
|
33 |
+
# caption = data[caption_key]
|
34 |
+
caption = data['text'][CAPTION_KEY[caption_key]]
|
35 |
+
code_dir = file_path.split('/')[-1].split('.')[0]
|
36 |
+
if trunc_caption:
|
37 |
+
caption = caption.split('.')[0]
|
38 |
+
img_path_list.append((caption, code_dir, line_idx))
|
39 |
+
self.img_path_list = img_path_list
|
40 |
+
|
41 |
+
def __len__(self):
|
42 |
+
return len(self.img_path_list)
|
43 |
+
|
44 |
+
def __getitem__(self, index):
|
45 |
+
caption, code_dir, code_name = self.img_path_list[index]
|
46 |
+
return caption, code_dir, code_name
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
#################################################################################
|
51 |
+
# Training Loop #
|
52 |
+
#################################################################################
|
53 |
+
def main(args):
|
54 |
+
"""
|
55 |
+
Trains a new DiT model.
|
56 |
+
"""
|
57 |
+
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
58 |
+
|
59 |
+
# Setup DDP:
|
60 |
+
# dist.init_process_group("nccl")
|
61 |
+
init_distributed_mode(args)
|
62 |
+
rank = dist.get_rank()
|
63 |
+
device = rank % torch.cuda.device_count()
|
64 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
65 |
+
torch.manual_seed(seed)
|
66 |
+
torch.cuda.set_device(device)
|
67 |
+
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
68 |
+
|
69 |
+
# Setup a feature folder:
|
70 |
+
if rank == 0:
|
71 |
+
os.makedirs(args.t5_path, exist_ok=True)
|
72 |
+
|
73 |
+
# Setup data:
|
74 |
+
print(f"Dataset is preparing...")
|
75 |
+
dataset = CustomDataset(args.data_path, args.data_start, args.data_end, args.caption_key, args.trunc_caption)
|
76 |
+
sampler = DistributedSampler(
|
77 |
+
dataset,
|
78 |
+
num_replicas=dist.get_world_size(),
|
79 |
+
rank=rank,
|
80 |
+
shuffle=False,
|
81 |
+
seed=args.global_seed
|
82 |
+
)
|
83 |
+
loader = DataLoader(
|
84 |
+
dataset,
|
85 |
+
batch_size=1, # important!
|
86 |
+
shuffle=False,
|
87 |
+
sampler=sampler,
|
88 |
+
num_workers=args.num_workers,
|
89 |
+
pin_memory=True,
|
90 |
+
drop_last=False
|
91 |
+
)
|
92 |
+
print(f"Dataset contains {len(dataset):,} images")
|
93 |
+
|
94 |
+
precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
|
95 |
+
assert os.path.exists(args.t5_model_path)
|
96 |
+
t5_xxl = T5Embedder(
|
97 |
+
device=device,
|
98 |
+
local_cache=True,
|
99 |
+
cache_dir=args.t5_model_path,
|
100 |
+
dir_or_name=args.t5_model_type,
|
101 |
+
torch_dtype=precision
|
102 |
+
)
|
103 |
+
|
104 |
+
for caption, code_dir, code_name in loader:
|
105 |
+
caption_embs, emb_masks = t5_xxl.get_text_embeddings(caption)
|
106 |
+
valid_caption_embs = caption_embs[:, :emb_masks.sum()]
|
107 |
+
x = valid_caption_embs.to(torch.float32).detach().cpu().numpy()
|
108 |
+
os.makedirs(os.path.join(args.t5_path, code_dir[0]), exist_ok=True)
|
109 |
+
np.save(os.path.join(args.t5_path, code_dir[0], '{}.npy'.format(code_name.item())), x)
|
110 |
+
print(code_name.item())
|
111 |
+
|
112 |
+
dist.destroy_process_group()
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
parser = argparse.ArgumentParser()
|
117 |
+
parser.add_argument("--data-path", type=str, required=True)
|
118 |
+
parser.add_argument("--t5-path", type=str, required=True)
|
119 |
+
parser.add_argument("--data-start", type=int, required=True)
|
120 |
+
parser.add_argument("--data-end", type=int, required=True)
|
121 |
+
parser.add_argument("--caption-key", type=str, default='blip', choices=list(CAPTION_KEY.keys()))
|
122 |
+
parser.add_argument("--trunc-caption", action='store_true', default=False)
|
123 |
+
parser.add_argument("--t5-model-path", type=str, default='./pretrained_models/t5-ckpt')
|
124 |
+
parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
|
125 |
+
parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
126 |
+
parser.add_argument("--global-seed", type=int, default=0)
|
127 |
+
parser.add_argument("--num-workers", type=int, default=24)
|
128 |
+
args = parser.parse_args()
|
129 |
+
main(args)
|
language/t5.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/t5.py
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import html
|
6 |
+
import urllib.parse as ul
|
7 |
+
|
8 |
+
import ftfy
|
9 |
+
import torch
|
10 |
+
from bs4 import BeautifulSoup
|
11 |
+
from transformers import T5EncoderModel, AutoTokenizer
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
|
14 |
+
|
15 |
+
class T5Embedder:
|
16 |
+
available_models = ['t5-v1_1-xxl', 't5-v1_1-xl', 'flan-t5-xl']
|
17 |
+
bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
|
18 |
+
|
19 |
+
def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
|
20 |
+
t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None, model_max_length=120):
|
21 |
+
self.device = torch.device(device)
|
22 |
+
self.torch_dtype = torch_dtype or torch.bfloat16
|
23 |
+
if t5_model_kwargs is None:
|
24 |
+
t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
|
25 |
+
t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}
|
26 |
+
|
27 |
+
self.use_text_preprocessing = use_text_preprocessing
|
28 |
+
self.hf_token = hf_token
|
29 |
+
self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
|
30 |
+
self.dir_or_name = dir_or_name
|
31 |
+
tokenizer_path, path = dir_or_name, dir_or_name
|
32 |
+
if local_cache:
|
33 |
+
cache_dir = os.path.join(self.cache_dir, dir_or_name)
|
34 |
+
tokenizer_path, path = cache_dir, cache_dir
|
35 |
+
elif dir_or_name in self.available_models:
|
36 |
+
cache_dir = os.path.join(self.cache_dir, dir_or_name)
|
37 |
+
for filename in [
|
38 |
+
'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
|
39 |
+
'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
|
40 |
+
]:
|
41 |
+
hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
|
42 |
+
force_filename=filename, token=self.hf_token)
|
43 |
+
tokenizer_path, path = cache_dir, cache_dir
|
44 |
+
else:
|
45 |
+
cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
|
46 |
+
for filename in [
|
47 |
+
'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
|
48 |
+
]:
|
49 |
+
hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
|
50 |
+
force_filename=filename, token=self.hf_token)
|
51 |
+
tokenizer_path = cache_dir
|
52 |
+
|
53 |
+
print(tokenizer_path)
|
54 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
55 |
+
self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
|
56 |
+
self.model_max_length = model_max_length
|
57 |
+
|
58 |
+
def get_text_embeddings(self, texts):
|
59 |
+
texts = [self.text_preprocessing(text) for text in texts]
|
60 |
+
|
61 |
+
text_tokens_and_mask = self.tokenizer(
|
62 |
+
texts,
|
63 |
+
max_length=self.model_max_length,
|
64 |
+
padding='max_length',
|
65 |
+
truncation=True,
|
66 |
+
return_attention_mask=True,
|
67 |
+
add_special_tokens=True,
|
68 |
+
return_tensors='pt'
|
69 |
+
)
|
70 |
+
|
71 |
+
text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
|
72 |
+
text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
|
73 |
+
|
74 |
+
with torch.no_grad():
|
75 |
+
text_encoder_embs = self.model(
|
76 |
+
input_ids=text_tokens_and_mask['input_ids'].to(self.device),
|
77 |
+
attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
|
78 |
+
)['last_hidden_state'].detach()
|
79 |
+
return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)
|
80 |
+
|
81 |
+
def text_preprocessing(self, text):
|
82 |
+
if self.use_text_preprocessing:
|
83 |
+
# The exact text cleaning as was in the training stage:
|
84 |
+
text = self.clean_caption(text)
|
85 |
+
text = self.clean_caption(text)
|
86 |
+
return text
|
87 |
+
else:
|
88 |
+
return text.lower().strip()
|
89 |
+
|
90 |
+
@staticmethod
|
91 |
+
def basic_clean(text):
|
92 |
+
text = ftfy.fix_text(text)
|
93 |
+
text = html.unescape(html.unescape(text))
|
94 |
+
return text.strip()
|
95 |
+
|
96 |
+
def clean_caption(self, caption):
|
97 |
+
caption = str(caption)
|
98 |
+
caption = ul.unquote_plus(caption)
|
99 |
+
caption = caption.strip().lower()
|
100 |
+
caption = re.sub('<person>', 'person', caption)
|
101 |
+
# urls:
|
102 |
+
caption = re.sub(
|
103 |
+
r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
|
104 |
+
'', caption) # regex for urls
|
105 |
+
caption = re.sub(
|
106 |
+
r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
|
107 |
+
'', caption) # regex for urls
|
108 |
+
# html:
|
109 |
+
caption = BeautifulSoup(caption, features='html.parser').text
|
110 |
+
|
111 |
+
# @<nickname>
|
112 |
+
caption = re.sub(r'@[\w\d]+\b', '', caption)
|
113 |
+
|
114 |
+
# 31C0—31EF CJK Strokes
|
115 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
116 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
117 |
+
# 3300—33FF CJK Compatibility
|
118 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
119 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
120 |
+
# 4E00—9FFF CJK Unified Ideographs
|
121 |
+
caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
|
122 |
+
caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
|
123 |
+
caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
|
124 |
+
caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
|
125 |
+
caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
|
126 |
+
caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
|
127 |
+
caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
|
128 |
+
#######################################################
|
129 |
+
|
130 |
+
# все виды тире / all types of dash --> "-"
|
131 |
+
caption = re.sub(
|
132 |
+
r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
|
133 |
+
'-', caption)
|
134 |
+
|
135 |
+
# кавычки к одному стандарту
|
136 |
+
caption = re.sub(r'[`´«»“”¨]', '"', caption)
|
137 |
+
caption = re.sub(r'[‘’]', "'", caption)
|
138 |
+
|
139 |
+
# "
|
140 |
+
caption = re.sub(r'"?', '', caption)
|
141 |
+
# &
|
142 |
+
caption = re.sub(r'&', '', caption)
|
143 |
+
|
144 |
+
# ip adresses:
|
145 |
+
caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
|
146 |
+
|
147 |
+
# article ids:
|
148 |
+
caption = re.sub(r'\d:\d\d\s+$', '', caption)
|
149 |
+
|
150 |
+
# \n
|
151 |
+
caption = re.sub(r'\\n', ' ', caption)
|
152 |
+
|
153 |
+
# "#123"
|
154 |
+
caption = re.sub(r'#\d{1,3}\b', '', caption)
|
155 |
+
# "#12345.."
|
156 |
+
caption = re.sub(r'#\d{5,}\b', '', caption)
|
157 |
+
# "123456.."
|
158 |
+
caption = re.sub(r'\b\d{6,}\b', '', caption)
|
159 |
+
# filenames:
|
160 |
+
caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
|
161 |
+
|
162 |
+
#
|
163 |
+
caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT"""
|
164 |
+
caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT"""
|
165 |
+
|
166 |
+
caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
167 |
+
caption = re.sub(r'\s+\.\s+', r' ', caption) # " . "
|
168 |
+
|
169 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
170 |
+
regex2 = re.compile(r'(?:\-|\_)')
|
171 |
+
if len(re.findall(regex2, caption)) > 3:
|
172 |
+
caption = re.sub(regex2, ' ', caption)
|
173 |
+
|
174 |
+
caption = self.basic_clean(caption)
|
175 |
+
|
176 |
+
caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640
|
177 |
+
caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc
|
178 |
+
caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231
|
179 |
+
|
180 |
+
caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
|
181 |
+
caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
|
182 |
+
caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
|
183 |
+
caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
|
184 |
+
caption = re.sub(r'\bpage\s+\d+\b', '', caption)
|
185 |
+
|
186 |
+
caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a...
|
187 |
+
|
188 |
+
caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
|
189 |
+
|
190 |
+
caption = re.sub(r'\b\s+\:\s+', r': ', caption)
|
191 |
+
caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
|
192 |
+
caption = re.sub(r'\s+', ' ', caption)
|
193 |
+
|
194 |
+
caption.strip()
|
195 |
+
|
196 |
+
caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
|
197 |
+
caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
|
198 |
+
caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
|
199 |
+
caption = re.sub(r'^\.\S+$', '', caption)
|
200 |
+
|
201 |
+
return caption.strip()
|
model.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import spaces
|
3 |
+
from safetensors.torch import load_file
|
4 |
+
from autoregressive.models.gpt_t2i import GPT_models
|
5 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
6 |
+
from language.t5 import T5Embedder
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
import PIL
|
10 |
+
from PIL import Image
|
11 |
+
from condition.canny import CannyDetector
|
12 |
+
import time
|
13 |
+
from autoregressive.models.generate import generate
|
14 |
+
from condition.midas.depth import MidasDetector
|
15 |
+
|
16 |
+
models = {
|
17 |
+
"canny": "checkpoints/t2i/canny_MR.safetensors",
|
18 |
+
"depth": "checkpoints/t2i/depth_MR.safetensors",
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def resize_image_to_16_multiple(image, condition_type='canny'):
|
23 |
+
if isinstance(image, np.ndarray):
|
24 |
+
image = Image.fromarray(image)
|
25 |
+
# image = Image.open(image_path)
|
26 |
+
width, height = image.size
|
27 |
+
|
28 |
+
if condition_type == 'depth': # The depth model requires a side length that is a multiple of 32
|
29 |
+
new_width = (width + 31) // 32 * 32
|
30 |
+
new_height = (height + 31) // 32 * 32
|
31 |
+
else:
|
32 |
+
new_width = (width + 15) // 16 * 16
|
33 |
+
new_height = (height + 15) // 16 * 16
|
34 |
+
|
35 |
+
resized_image = image.resize((new_width, new_height))
|
36 |
+
return resized_image
|
37 |
+
|
38 |
+
|
39 |
+
class Model:
|
40 |
+
|
41 |
+
def __init__(self):
|
42 |
+
self.device = torch.device(
|
43 |
+
"cuda:0" if torch.cuda.is_available() else "cpu")
|
44 |
+
self.base_model_id = ""
|
45 |
+
self.task_name = ""
|
46 |
+
self.vq_model = self.load_vq()
|
47 |
+
self.t5_model = self.load_t5()
|
48 |
+
self.gpt_model_canny = self.load_gpt(condition_type='canny')
|
49 |
+
self.gpt_model_depth = self.load_gpt(condition_type='depth')
|
50 |
+
self.get_control_canny = CannyDetector()
|
51 |
+
self.get_control_depth = MidasDetector(device=self.device)
|
52 |
+
|
53 |
+
def load_vq(self):
|
54 |
+
vq_model = VQ_models["VQ-16"](codebook_size=16384,
|
55 |
+
codebook_embed_dim=8)
|
56 |
+
vq_model.to(self.device)
|
57 |
+
vq_model.eval()
|
58 |
+
checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
|
59 |
+
map_location="cpu")
|
60 |
+
vq_model.load_state_dict(checkpoint["model"])
|
61 |
+
del checkpoint
|
62 |
+
print(f"image tokenizer is loaded")
|
63 |
+
return vq_model
|
64 |
+
|
65 |
+
def load_gpt(self, condition_type='canny'):
|
66 |
+
gpt_ckpt = models[condition_type]
|
67 |
+
precision = torch.bfloat16
|
68 |
+
latent_size = 768 // 16
|
69 |
+
gpt_model = GPT_models["GPT-XL"](
|
70 |
+
block_size=latent_size**2,
|
71 |
+
cls_token_num=120,
|
72 |
+
model_type='t2i',
|
73 |
+
condition_type=condition_type,
|
74 |
+
).to(device=self.device, dtype=precision)
|
75 |
+
|
76 |
+
model_weight = load_file(gpt_ckpt)
|
77 |
+
gpt_model.load_state_dict(model_weight, strict=False)
|
78 |
+
gpt_model.eval()
|
79 |
+
print(f"gpt model is loaded")
|
80 |
+
return gpt_model
|
81 |
+
|
82 |
+
def load_t5(self):
|
83 |
+
precision = torch.bfloat16
|
84 |
+
t5_model = T5Embedder(
|
85 |
+
device=self.device,
|
86 |
+
local_cache=True,
|
87 |
+
# cache_dir='checkpoints/t5-ckpt',
|
88 |
+
dir_or_name='flan-t5-xl',
|
89 |
+
torch_dtype=precision,
|
90 |
+
model_max_length=120,
|
91 |
+
)
|
92 |
+
return t5_model
|
93 |
+
|
94 |
+
@torch.no_grad()
|
95 |
+
@spaces.GPU(enable_queue=True)
|
96 |
+
def process_canny(
|
97 |
+
self,
|
98 |
+
image: np.ndarray,
|
99 |
+
prompt: str,
|
100 |
+
cfg_scale: float,
|
101 |
+
temperature: float,
|
102 |
+
top_k: int,
|
103 |
+
top_p: int,
|
104 |
+
seed: int,
|
105 |
+
low_threshold: int,
|
106 |
+
high_threshold: int,
|
107 |
+
) -> list[PIL.Image.Image]:
|
108 |
+
|
109 |
+
image = resize_image_to_16_multiple(image, 'canny')
|
110 |
+
W, H = image.size
|
111 |
+
print(W, H)
|
112 |
+
condition_img = self.get_control_canny(np.array(image), low_threshold,
|
113 |
+
high_threshold)
|
114 |
+
condition_img = torch.from_numpy(condition_img[None, None,
|
115 |
+
...]).repeat(
|
116 |
+
2, 3, 1, 1)
|
117 |
+
condition_img = condition_img.to(self.device)
|
118 |
+
condition_img = 2 * (condition_img / 255 - 0.5)
|
119 |
+
prompts = [prompt] * 2
|
120 |
+
caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
|
121 |
+
|
122 |
+
print(f"processing left-padding...")
|
123 |
+
new_emb_masks = torch.flip(emb_masks, dims=[-1])
|
124 |
+
new_caption_embs = []
|
125 |
+
for idx, (caption_emb,
|
126 |
+
emb_mask) in enumerate(zip(caption_embs, emb_masks)):
|
127 |
+
valid_num = int(emb_mask.sum().item())
|
128 |
+
print(f' prompt {idx} token len: {valid_num}')
|
129 |
+
new_caption_emb = torch.cat(
|
130 |
+
[caption_emb[valid_num:], caption_emb[:valid_num]])
|
131 |
+
new_caption_embs.append(new_caption_emb)
|
132 |
+
new_caption_embs = torch.stack(new_caption_embs)
|
133 |
+
c_indices = new_caption_embs * new_emb_masks[:, :, None]
|
134 |
+
c_emb_masks = new_emb_masks
|
135 |
+
qzshape = [len(c_indices), 8, H // 16, W // 16]
|
136 |
+
t1 = time.time()
|
137 |
+
index_sample = generate(
|
138 |
+
self.gpt_model_canny,
|
139 |
+
c_indices,
|
140 |
+
(H // 16) * (W // 16),
|
141 |
+
c_emb_masks,
|
142 |
+
condition=condition_img,
|
143 |
+
cfg_scale=cfg_scale,
|
144 |
+
temperature=temperature,
|
145 |
+
top_k=top_k,
|
146 |
+
top_p=top_p,
|
147 |
+
sample_logits=True,
|
148 |
+
)
|
149 |
+
sampling_time = time.time() - t1
|
150 |
+
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
151 |
+
|
152 |
+
t2 = time.time()
|
153 |
+
print(index_sample.shape)
|
154 |
+
samples = self.vq_model.decode_code(
|
155 |
+
index_sample, qzshape) # output value is between [-1, 1]
|
156 |
+
decoder_time = time.time() - t2
|
157 |
+
print(f"decoder takes about {decoder_time:.2f} seconds.")
|
158 |
+
|
159 |
+
samples = torch.cat((condition_img[0:1], samples), dim=0)
|
160 |
+
samples = 255 * (samples * 0.5 + 0.5)
|
161 |
+
samples = [image] + [
|
162 |
+
Image.fromarray(
|
163 |
+
sample.permute(1, 2, 0).cpu().detach().numpy().clip(
|
164 |
+
0, 255).astype(np.uint8)) for sample in samples
|
165 |
+
]
|
166 |
+
del condition_img
|
167 |
+
torch.cuda.empty_cache()
|
168 |
+
return samples
|
169 |
+
|
170 |
+
@torch.no_grad()
|
171 |
+
@spaces.GPU(enable_queue=True)
|
172 |
+
def process_depth(
|
173 |
+
self,
|
174 |
+
image: np.ndarray,
|
175 |
+
prompt: str,
|
176 |
+
cfg_scale: float,
|
177 |
+
temperature: float,
|
178 |
+
top_k: int,
|
179 |
+
top_p: int,
|
180 |
+
seed: int,
|
181 |
+
) -> list[PIL.Image.Image]:
|
182 |
+
image = resize_image_to_16_multiple(image, 'depth')
|
183 |
+
W, H = image.size
|
184 |
+
print(W, H)
|
185 |
+
image_tensor = torch.from_numpy(np.array(image)).to(self.device)
|
186 |
+
condition_img = torch.from_numpy(
|
187 |
+
self.get_control_depth(image_tensor)).unsqueeze(0)
|
188 |
+
condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1)
|
189 |
+
condition_img = condition_img.to(self.device)
|
190 |
+
condition_img = 2 * (condition_img / 255 - 0.5)
|
191 |
+
prompts = [prompt] * 2
|
192 |
+
caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
|
193 |
+
|
194 |
+
print(f"processing left-padding...")
|
195 |
+
new_emb_masks = torch.flip(emb_masks, dims=[-1])
|
196 |
+
new_caption_embs = []
|
197 |
+
for idx, (caption_emb,
|
198 |
+
emb_mask) in enumerate(zip(caption_embs, emb_masks)):
|
199 |
+
valid_num = int(emb_mask.sum().item())
|
200 |
+
print(f' prompt {idx} token len: {valid_num}')
|
201 |
+
new_caption_emb = torch.cat(
|
202 |
+
[caption_emb[valid_num:], caption_emb[:valid_num]])
|
203 |
+
new_caption_embs.append(new_caption_emb)
|
204 |
+
new_caption_embs = torch.stack(new_caption_embs)
|
205 |
+
|
206 |
+
c_indices = new_caption_embs * new_emb_masks[:, :, None]
|
207 |
+
c_emb_masks = new_emb_masks
|
208 |
+
qzshape = [len(c_indices), 8, H // 16, W // 16]
|
209 |
+
t1 = time.time()
|
210 |
+
index_sample = generate(
|
211 |
+
self.gpt_model_depth,
|
212 |
+
c_indices,
|
213 |
+
(H // 16) * (W // 16),
|
214 |
+
c_emb_masks,
|
215 |
+
condition=condition_img,
|
216 |
+
cfg_scale=cfg_scale,
|
217 |
+
temperature=temperature,
|
218 |
+
top_k=top_k,
|
219 |
+
top_p=top_p,
|
220 |
+
sample_logits=True,
|
221 |
+
)
|
222 |
+
sampling_time = time.time() - t1
|
223 |
+
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
224 |
+
|
225 |
+
t2 = time.time()
|
226 |
+
print(index_sample.shape)
|
227 |
+
samples = self.vq_model.decode_code(index_sample, qzshape)
|
228 |
+
decoder_time = time.time() - t2
|
229 |
+
print(f"decoder takes about {decoder_time:.2f} seconds.")
|
230 |
+
condition_img = condition_img.cpu()
|
231 |
+
samples = samples.cpu()
|
232 |
+
samples = torch.cat((condition_img[0:1], samples), dim=0)
|
233 |
+
samples = 255 * (samples * 0.5 + 0.5)
|
234 |
+
samples = [image] + [
|
235 |
+
Image.fromarray(
|
236 |
+
sample.permute(1, 2, 0).numpy().clip(0, 255).astype(np.uint8))
|
237 |
+
for sample in samples
|
238 |
+
]
|
239 |
+
del image_tensor
|
240 |
+
del condition_img
|
241 |
+
torch.cuda.empty_cache()
|
242 |
+
return samples
|
style.css
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
h1 {
|
2 |
+
text-align: center;
|
3 |
+
}
|
4 |
+
|
5 |
+
#duplicate-button {
|
6 |
+
margin: auto;
|
7 |
+
color: #fff;
|
8 |
+
background: #1565c0;
|
9 |
+
border-radius: 100vh;
|
10 |
+
}
|
tokenizer/consistencydecoder/README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Consistency Decoder from OpenAI
|
2 |
+
|
3 |
+
### install
|
4 |
+
```
|
5 |
+
pip install diffusers
|
6 |
+
pip install accelerate
|
7 |
+
```
|
8 |
+
|
9 |
+
### demo
|
10 |
+
```
|
11 |
+
cd ${THIS_REPO_ROOT}
|
12 |
+
python3 tokenizer/consistencydecoder/cd_demo.py
|
13 |
+
```
|
14 |
+
|
tokenizer/consistencydecoder/cd_demo.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from diffusers import ConsistencyDecoderVAE
|
7 |
+
|
8 |
+
|
9 |
+
def main(args):
|
10 |
+
# Setup PyTorch:
|
11 |
+
torch.manual_seed(args.seed)
|
12 |
+
torch.set_grad_enabled(False)
|
13 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
+
|
15 |
+
# create and load model
|
16 |
+
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to(device)
|
17 |
+
|
18 |
+
# load image
|
19 |
+
img_path = args.image_path
|
20 |
+
out_path = args.image_path.replace('.jpg', '_cd.jpg').replace('.jpeg', '_cd.jpeg').replace('.png', '_cd.png')
|
21 |
+
input_size = args.image_size
|
22 |
+
img = Image.open(img_path).convert("RGB")
|
23 |
+
|
24 |
+
# preprocess
|
25 |
+
size_org = img.size
|
26 |
+
img = img.resize((input_size, input_size))
|
27 |
+
img = np.array(img) / 255.
|
28 |
+
x = 2.0 * img - 1.0 # x value is between [-1, 1]
|
29 |
+
x = torch.tensor(x)
|
30 |
+
x = x.unsqueeze(dim=0)
|
31 |
+
x = torch.einsum('nhwc->nchw', x)
|
32 |
+
x_input = x.half().to(device)
|
33 |
+
|
34 |
+
# inference
|
35 |
+
with torch.no_grad():
|
36 |
+
# Map input images to latent space + normalize latents:
|
37 |
+
latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215)
|
38 |
+
# reconstruct:
|
39 |
+
output = vae.decode(latent / 0.18215).sample # output value is between [-1, 1]
|
40 |
+
|
41 |
+
# postprocess
|
42 |
+
output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0]
|
43 |
+
sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
|
44 |
+
|
45 |
+
# save
|
46 |
+
Image.fromarray(sample).save(out_path)
|
47 |
+
print("Reconstructed image is saved to {}".format(out_path))
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
parser = argparse.ArgumentParser()
|
53 |
+
parser.add_argument("--image-path", type=str, default="assets/example.jpg")
|
54 |
+
parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512)
|
55 |
+
parser.add_argument("--seed", type=int, default=0)
|
56 |
+
args = parser.parse_args()
|
57 |
+
main(args)
|
tokenizer/consistencydecoder/reconstruction_cd_ddp.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
3 |
+
torch.backends.cudnn.allow_tf32 = True
|
4 |
+
import torch.distributed as dist
|
5 |
+
from torch.utils.data import Dataset, DataLoader
|
6 |
+
from torch.utils.data.distributed import DistributedSampler
|
7 |
+
from torchvision.datasets import ImageFolder
|
8 |
+
from torchvision import transforms
|
9 |
+
from tqdm import tqdm
|
10 |
+
import os
|
11 |
+
import itertools
|
12 |
+
from PIL import Image
|
13 |
+
import numpy as np
|
14 |
+
import argparse
|
15 |
+
import random
|
16 |
+
|
17 |
+
from skimage.metrics import peak_signal_noise_ratio as psnr_loss
|
18 |
+
from skimage.metrics import structural_similarity as ssim_loss
|
19 |
+
from diffusers.models import ConsistencyDecoderVAE
|
20 |
+
|
21 |
+
|
22 |
+
class SingleFolderDataset(Dataset):
|
23 |
+
def __init__(self, directory, transform=None):
|
24 |
+
super().__init__()
|
25 |
+
self.directory = directory
|
26 |
+
self.transform = transform
|
27 |
+
self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory)
|
28 |
+
if os.path.isfile(os.path.join(directory, file_name))]
|
29 |
+
|
30 |
+
def __len__(self):
|
31 |
+
return len(self.image_paths)
|
32 |
+
|
33 |
+
def __getitem__(self, idx):
|
34 |
+
image_path = self.image_paths[idx]
|
35 |
+
image = Image.open(image_path).convert('RGB')
|
36 |
+
if self.transform:
|
37 |
+
image = self.transform(image)
|
38 |
+
return image, torch.tensor(0)
|
39 |
+
|
40 |
+
|
41 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
|
42 |
+
"""
|
43 |
+
Builds a single .npz file from a folder of .png samples.
|
44 |
+
"""
|
45 |
+
samples = []
|
46 |
+
for i in tqdm(range(num), desc="Building .npz file from samples"):
|
47 |
+
sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
|
48 |
+
sample_np = np.asarray(sample_pil).astype(np.uint8)
|
49 |
+
samples.append(sample_np)
|
50 |
+
|
51 |
+
random.shuffle(samples) # This is very important for IS(Inception Score) !!!
|
52 |
+
samples = np.stack(samples)
|
53 |
+
assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
|
54 |
+
npz_path = f"{sample_dir}.npz"
|
55 |
+
np.savez(npz_path, arr_0=samples)
|
56 |
+
print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
|
57 |
+
return npz_path
|
58 |
+
|
59 |
+
|
60 |
+
def center_crop_arr(pil_image, image_size):
|
61 |
+
"""
|
62 |
+
Center cropping implementation from ADM.
|
63 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
64 |
+
"""
|
65 |
+
while min(*pil_image.size) >= 2 * image_size:
|
66 |
+
pil_image = pil_image.resize(
|
67 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
68 |
+
)
|
69 |
+
|
70 |
+
scale = image_size / min(*pil_image.size)
|
71 |
+
pil_image = pil_image.resize(
|
72 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
73 |
+
)
|
74 |
+
|
75 |
+
arr = np.array(pil_image)
|
76 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
77 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
78 |
+
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
79 |
+
|
80 |
+
|
81 |
+
def main(args):
|
82 |
+
# Setup PyTorch:
|
83 |
+
assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
|
84 |
+
torch.set_grad_enabled(False)
|
85 |
+
|
86 |
+
# Setup env
|
87 |
+
dist.init_process_group("nccl")
|
88 |
+
rank = dist.get_rank()
|
89 |
+
device = rank % torch.cuda.device_count()
|
90 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
91 |
+
torch.manual_seed(seed)
|
92 |
+
torch.cuda.set_device(device)
|
93 |
+
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
94 |
+
|
95 |
+
# create and load model
|
96 |
+
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to("cuda:{}".format(device))
|
97 |
+
|
98 |
+
# Create folder to save samples:
|
99 |
+
folder_name = f"openai-consistencydecoder-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}"
|
100 |
+
sample_folder_dir = f"{args.sample_dir}/{folder_name}"
|
101 |
+
if rank == 0:
|
102 |
+
os.makedirs(sample_folder_dir, exist_ok=True)
|
103 |
+
print(f"Saving .png samples at {sample_folder_dir}")
|
104 |
+
dist.barrier()
|
105 |
+
|
106 |
+
# Setup data:
|
107 |
+
transform = transforms.Compose([
|
108 |
+
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
|
109 |
+
transforms.ToTensor(),
|
110 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
111 |
+
])
|
112 |
+
if args.dataset == 'imagenet':
|
113 |
+
dataset = ImageFolder(args.data_path, transform=transform)
|
114 |
+
num_fid_samples = 50000
|
115 |
+
elif args.dataset == 'coco':
|
116 |
+
dataset = SingleFolderDataset(args.data_path, transform=transform)
|
117 |
+
num_fid_samples = 5000
|
118 |
+
else:
|
119 |
+
raise Exception("please check dataset")
|
120 |
+
sampler = DistributedSampler(
|
121 |
+
dataset,
|
122 |
+
num_replicas=dist.get_world_size(),
|
123 |
+
rank=rank,
|
124 |
+
shuffle=False,
|
125 |
+
seed=args.global_seed
|
126 |
+
)
|
127 |
+
loader = DataLoader(
|
128 |
+
dataset,
|
129 |
+
batch_size=args.per_proc_batch_size,
|
130 |
+
shuffle=False,
|
131 |
+
sampler=sampler,
|
132 |
+
num_workers=args.num_workers,
|
133 |
+
pin_memory=True,
|
134 |
+
drop_last=False
|
135 |
+
)
|
136 |
+
|
137 |
+
# Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
|
138 |
+
n = args.per_proc_batch_size
|
139 |
+
global_batch_size = n * dist.get_world_size()
|
140 |
+
psnr_val_rgb = []
|
141 |
+
ssim_val_rgb = []
|
142 |
+
|
143 |
+
loader = tqdm(loader) if rank == 0 else loader
|
144 |
+
total = 0
|
145 |
+
for x, _ in loader:
|
146 |
+
rgb_gts = x
|
147 |
+
rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1]
|
148 |
+
x = x.half().to("cuda:{}".format(device))
|
149 |
+
with torch.no_grad():
|
150 |
+
# Map input images to latent space + normalize latents:
|
151 |
+
latent = vae.encode(x).latent_dist.sample().mul_(0.18215)
|
152 |
+
# reconstruct:
|
153 |
+
samples = vae.decode(latent / 0.18215).sample # output value is between [-1, 1]
|
154 |
+
samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
155 |
+
|
156 |
+
# Save samples to disk as individual .png files
|
157 |
+
for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
|
158 |
+
index = i * dist.get_world_size() + rank + total
|
159 |
+
Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
|
160 |
+
# metric
|
161 |
+
rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1]
|
162 |
+
psnr = psnr_loss(rgb_restored, rgb_gt)
|
163 |
+
ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1)
|
164 |
+
psnr_val_rgb.append(psnr)
|
165 |
+
ssim_val_rgb.append(ssim)
|
166 |
+
total += global_batch_size
|
167 |
+
|
168 |
+
# ------------------------------------
|
169 |
+
# Summary
|
170 |
+
# ------------------------------------
|
171 |
+
# Make sure all processes have finished saving their samples
|
172 |
+
dist.barrier()
|
173 |
+
world_size = dist.get_world_size()
|
174 |
+
gather_psnr_val = [None for _ in range(world_size)]
|
175 |
+
gather_ssim_val = [None for _ in range(world_size)]
|
176 |
+
dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
|
177 |
+
dist.all_gather_object(gather_ssim_val, ssim_val_rgb)
|
178 |
+
|
179 |
+
if rank == 0:
|
180 |
+
gather_psnr_val = list(itertools.chain(*gather_psnr_val))
|
181 |
+
gather_ssim_val = list(itertools.chain(*gather_ssim_val))
|
182 |
+
psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
|
183 |
+
ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
|
184 |
+
print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb))
|
185 |
+
|
186 |
+
result_file = f"{sample_folder_dir}_results.txt"
|
187 |
+
print("writing results to {}".format(result_file))
|
188 |
+
with open(result_file, 'w') as f:
|
189 |
+
print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f)
|
190 |
+
|
191 |
+
create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
|
192 |
+
print("Done.")
|
193 |
+
|
194 |
+
dist.barrier()
|
195 |
+
dist.destroy_process_group()
|
196 |
+
|
197 |
+
|
198 |
+
if __name__ == "__main__":
|
199 |
+
parser = argparse.ArgumentParser()
|
200 |
+
parser.add_argument("--data-path", type=str, required=True)
|
201 |
+
parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet')
|
202 |
+
parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
|
203 |
+
parser.add_argument("--sample-dir", type=str, default="reconstructions")
|
204 |
+
parser.add_argument("--per-proc-batch-size", type=int, default=32)
|
205 |
+
parser.add_argument("--global-seed", type=int, default=0)
|
206 |
+
parser.add_argument("--num-workers", type=int, default=4)
|
207 |
+
args = parser.parse_args()
|
208 |
+
main(args)
|
tokenizer/tokenizer_image/cache/vgg.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
|
3 |
+
size 7289
|
tokenizer/tokenizer_image/discriminator.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# taming-transformers: https://github.com/CompVis/taming-transformers
|
3 |
+
# stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
|
4 |
+
# maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
|
5 |
+
import functools
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
try:
|
10 |
+
from kornia.filters import filter2d
|
11 |
+
except:
|
12 |
+
pass
|
13 |
+
|
14 |
+
#################################################################################
|
15 |
+
# PatchGAN #
|
16 |
+
#################################################################################
|
17 |
+
class PatchGANDiscriminator(nn.Module):
|
18 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
19 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
20 |
+
"""
|
21 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
22 |
+
"""Construct a PatchGAN discriminator
|
23 |
+
Parameters:
|
24 |
+
input_nc (int) -- the number of channels in input images
|
25 |
+
ndf (int) -- the number of filters in the last conv layer
|
26 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
27 |
+
norm_layer -- normalization layer
|
28 |
+
"""
|
29 |
+
super(PatchGANDiscriminator, self).__init__()
|
30 |
+
if not use_actnorm:
|
31 |
+
norm_layer = nn.BatchNorm2d
|
32 |
+
else:
|
33 |
+
norm_layer = ActNorm
|
34 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
35 |
+
use_bias = norm_layer.func != nn.BatchNorm2d
|
36 |
+
else:
|
37 |
+
use_bias = norm_layer != nn.BatchNorm2d
|
38 |
+
|
39 |
+
kw = 4
|
40 |
+
padw = 1
|
41 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
42 |
+
nf_mult = 1
|
43 |
+
nf_mult_prev = 1
|
44 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
45 |
+
nf_mult_prev = nf_mult
|
46 |
+
nf_mult = min(2 ** n, 8)
|
47 |
+
sequence += [
|
48 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
49 |
+
norm_layer(ndf * nf_mult),
|
50 |
+
nn.LeakyReLU(0.2, True)
|
51 |
+
]
|
52 |
+
|
53 |
+
nf_mult_prev = nf_mult
|
54 |
+
nf_mult = min(2 ** n_layers, 8)
|
55 |
+
sequence += [
|
56 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
57 |
+
norm_layer(ndf * nf_mult),
|
58 |
+
nn.LeakyReLU(0.2, True)
|
59 |
+
]
|
60 |
+
|
61 |
+
sequence += [
|
62 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
63 |
+
self.main = nn.Sequential(*sequence)
|
64 |
+
|
65 |
+
self.apply(self._init_weights)
|
66 |
+
|
67 |
+
def _init_weights(self, module):
|
68 |
+
if isinstance(module, nn.Conv2d):
|
69 |
+
nn.init.normal_(module.weight.data, 0.0, 0.02)
|
70 |
+
elif isinstance(module, nn.BatchNorm2d):
|
71 |
+
nn.init.normal_(module.weight.data, 1.0, 0.02)
|
72 |
+
nn.init.constant_(module.bias.data, 0)
|
73 |
+
|
74 |
+
def forward(self, input):
|
75 |
+
"""Standard forward."""
|
76 |
+
return self.main(input)
|
77 |
+
|
78 |
+
|
79 |
+
class ActNorm(nn.Module):
|
80 |
+
def __init__(self, num_features, logdet=False, affine=True,
|
81 |
+
allow_reverse_init=False):
|
82 |
+
assert affine
|
83 |
+
super().__init__()
|
84 |
+
self.logdet = logdet
|
85 |
+
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
86 |
+
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
87 |
+
self.allow_reverse_init = allow_reverse_init
|
88 |
+
|
89 |
+
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
|
90 |
+
|
91 |
+
def initialize(self, input):
|
92 |
+
with torch.no_grad():
|
93 |
+
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
94 |
+
mean = (
|
95 |
+
flatten.mean(1)
|
96 |
+
.unsqueeze(1)
|
97 |
+
.unsqueeze(2)
|
98 |
+
.unsqueeze(3)
|
99 |
+
.permute(1, 0, 2, 3)
|
100 |
+
)
|
101 |
+
std = (
|
102 |
+
flatten.std(1)
|
103 |
+
.unsqueeze(1)
|
104 |
+
.unsqueeze(2)
|
105 |
+
.unsqueeze(3)
|
106 |
+
.permute(1, 0, 2, 3)
|
107 |
+
)
|
108 |
+
|
109 |
+
self.loc.data.copy_(-mean)
|
110 |
+
self.scale.data.copy_(1 / (std + 1e-6))
|
111 |
+
|
112 |
+
def forward(self, input, reverse=False):
|
113 |
+
if reverse:
|
114 |
+
return self.reverse(input)
|
115 |
+
if len(input.shape) == 2:
|
116 |
+
input = input[:,:,None,None]
|
117 |
+
squeeze = True
|
118 |
+
else:
|
119 |
+
squeeze = False
|
120 |
+
|
121 |
+
_, _, height, width = input.shape
|
122 |
+
|
123 |
+
if self.training and self.initialized.item() == 0:
|
124 |
+
self.initialize(input)
|
125 |
+
self.initialized.fill_(1)
|
126 |
+
|
127 |
+
h = self.scale * (input + self.loc)
|
128 |
+
|
129 |
+
if squeeze:
|
130 |
+
h = h.squeeze(-1).squeeze(-1)
|
131 |
+
|
132 |
+
if self.logdet:
|
133 |
+
log_abs = torch.log(torch.abs(self.scale))
|
134 |
+
logdet = height*width*torch.sum(log_abs)
|
135 |
+
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
136 |
+
return h, logdet
|
137 |
+
|
138 |
+
return h
|
139 |
+
|
140 |
+
def reverse(self, output):
|
141 |
+
if self.training and self.initialized.item() == 0:
|
142 |
+
if not self.allow_reverse_init:
|
143 |
+
raise RuntimeError(
|
144 |
+
"Initializing ActNorm in reverse direction is "
|
145 |
+
"disabled by default. Use allow_reverse_init=True to enable."
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
self.initialize(output)
|
149 |
+
self.initialized.fill_(1)
|
150 |
+
|
151 |
+
if len(output.shape) == 2:
|
152 |
+
output = output[:,:,None,None]
|
153 |
+
squeeze = True
|
154 |
+
else:
|
155 |
+
squeeze = False
|
156 |
+
|
157 |
+
h = output / self.scale - self.loc
|
158 |
+
|
159 |
+
if squeeze:
|
160 |
+
h = h.squeeze(-1).squeeze(-1)
|
161 |
+
return h
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
#################################################################################
|
166 |
+
# StyleGAN #
|
167 |
+
#################################################################################
|
168 |
+
class StyleGANDiscriminator(nn.Module):
|
169 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
|
170 |
+
super().__init__()
|
171 |
+
channels = {
|
172 |
+
4: 512,
|
173 |
+
8: 512,
|
174 |
+
16: 512,
|
175 |
+
32: 512,
|
176 |
+
64: 256 * channel_multiplier,
|
177 |
+
128: 128 * channel_multiplier,
|
178 |
+
256: 64 * channel_multiplier,
|
179 |
+
512: 32 * channel_multiplier,
|
180 |
+
1024: 16 * channel_multiplier,
|
181 |
+
}
|
182 |
+
|
183 |
+
log_size = int(math.log(image_size, 2))
|
184 |
+
in_channel = channels[image_size]
|
185 |
+
|
186 |
+
blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()]
|
187 |
+
for i in range(log_size, 2, -1):
|
188 |
+
out_channel = channels[2 ** (i - 1)]
|
189 |
+
blocks.append(DiscriminatorBlock(in_channel, out_channel))
|
190 |
+
in_channel = out_channel
|
191 |
+
self.blocks = nn.ModuleList(blocks)
|
192 |
+
|
193 |
+
self.final_conv = nn.Sequential(
|
194 |
+
nn.Conv2d(in_channel, channels[4], 3, padding=1),
|
195 |
+
leaky_relu(),
|
196 |
+
)
|
197 |
+
self.final_linear = nn.Sequential(
|
198 |
+
nn.Linear(channels[4] * 4 * 4, channels[4]),
|
199 |
+
leaky_relu(),
|
200 |
+
nn.Linear(channels[4], 1)
|
201 |
+
)
|
202 |
+
|
203 |
+
def forward(self, x):
|
204 |
+
for block in self.blocks:
|
205 |
+
x = block(x)
|
206 |
+
x = self.final_conv(x)
|
207 |
+
x = x.view(x.shape[0], -1)
|
208 |
+
x = self.final_linear(x)
|
209 |
+
return x
|
210 |
+
|
211 |
+
|
212 |
+
class DiscriminatorBlock(nn.Module):
|
213 |
+
def __init__(self, input_channels, filters, downsample=True):
|
214 |
+
super().__init__()
|
215 |
+
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))
|
216 |
+
|
217 |
+
self.net = nn.Sequential(
|
218 |
+
nn.Conv2d(input_channels, filters, 3, padding=1),
|
219 |
+
leaky_relu(),
|
220 |
+
nn.Conv2d(filters, filters, 3, padding=1),
|
221 |
+
leaky_relu()
|
222 |
+
)
|
223 |
+
|
224 |
+
self.downsample = nn.Sequential(
|
225 |
+
Blur(),
|
226 |
+
nn.Conv2d(filters, filters, 3, padding = 1, stride = 2)
|
227 |
+
) if downsample else None
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
res = self.conv_res(x)
|
231 |
+
x = self.net(x)
|
232 |
+
if exists(self.downsample):
|
233 |
+
x = self.downsample(x)
|
234 |
+
x = (x + res) * (1 / math.sqrt(2))
|
235 |
+
return x
|
236 |
+
|
237 |
+
|
238 |
+
class Blur(nn.Module):
|
239 |
+
def __init__(self):
|
240 |
+
super().__init__()
|
241 |
+
f = torch.Tensor([1, 2, 1])
|
242 |
+
self.register_buffer('f', f)
|
243 |
+
|
244 |
+
def forward(self, x):
|
245 |
+
f = self.f
|
246 |
+
f = f[None, None, :] * f [None, :, None]
|
247 |
+
return filter2d(x, f, normalized=True)
|
248 |
+
|
249 |
+
|
250 |
+
def leaky_relu(p=0.2):
|
251 |
+
return nn.LeakyReLU(p, inplace=True)
|
252 |
+
|
253 |
+
|
254 |
+
def exists(val):
|
255 |
+
return val is not None
|
tokenizer/tokenizer_image/discriminator_patchgan.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# taming-transformers: https://github.com/CompVis/taming-transformers
|
3 |
+
import functools
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
class NLayerDiscriminator(nn.Module):
|
9 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
10 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
11 |
+
"""
|
12 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
13 |
+
"""Construct a PatchGAN discriminator
|
14 |
+
Parameters:
|
15 |
+
input_nc (int) -- the number of channels in input images
|
16 |
+
ndf (int) -- the number of filters in the last conv layer
|
17 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
18 |
+
norm_layer -- normalization layer
|
19 |
+
"""
|
20 |
+
super(NLayerDiscriminator, self).__init__()
|
21 |
+
if not use_actnorm:
|
22 |
+
norm_layer = nn.BatchNorm2d
|
23 |
+
else:
|
24 |
+
norm_layer = ActNorm
|
25 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
26 |
+
use_bias = norm_layer.func != nn.BatchNorm2d
|
27 |
+
else:
|
28 |
+
use_bias = norm_layer != nn.BatchNorm2d
|
29 |
+
|
30 |
+
kw = 4
|
31 |
+
padw = 1
|
32 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
33 |
+
nf_mult = 1
|
34 |
+
nf_mult_prev = 1
|
35 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
36 |
+
nf_mult_prev = nf_mult
|
37 |
+
nf_mult = min(2 ** n, 8)
|
38 |
+
sequence += [
|
39 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
40 |
+
norm_layer(ndf * nf_mult),
|
41 |
+
nn.LeakyReLU(0.2, True)
|
42 |
+
]
|
43 |
+
|
44 |
+
nf_mult_prev = nf_mult
|
45 |
+
nf_mult = min(2 ** n_layers, 8)
|
46 |
+
sequence += [
|
47 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
48 |
+
norm_layer(ndf * nf_mult),
|
49 |
+
nn.LeakyReLU(0.2, True)
|
50 |
+
]
|
51 |
+
|
52 |
+
sequence += [
|
53 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
54 |
+
self.main = nn.Sequential(*sequence)
|
55 |
+
|
56 |
+
self.apply(self._init_weights)
|
57 |
+
|
58 |
+
def _init_weights(self, module):
|
59 |
+
if isinstance(module, nn.Conv2d):
|
60 |
+
nn.init.normal_(module.weight.data, 0.0, 0.02)
|
61 |
+
elif isinstance(module, nn.BatchNorm2d):
|
62 |
+
nn.init.normal_(module.weight.data, 1.0, 0.02)
|
63 |
+
nn.init.constant_(module.bias.data, 0)
|
64 |
+
|
65 |
+
def forward(self, input):
|
66 |
+
"""Standard forward."""
|
67 |
+
return self.main(input)
|
68 |
+
|
69 |
+
|
70 |
+
class ActNorm(nn.Module):
|
71 |
+
def __init__(self, num_features, logdet=False, affine=True,
|
72 |
+
allow_reverse_init=False):
|
73 |
+
assert affine
|
74 |
+
super().__init__()
|
75 |
+
self.logdet = logdet
|
76 |
+
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
77 |
+
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
78 |
+
self.allow_reverse_init = allow_reverse_init
|
79 |
+
|
80 |
+
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
|
81 |
+
|
82 |
+
def initialize(self, input):
|
83 |
+
with torch.no_grad():
|
84 |
+
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
85 |
+
mean = (
|
86 |
+
flatten.mean(1)
|
87 |
+
.unsqueeze(1)
|
88 |
+
.unsqueeze(2)
|
89 |
+
.unsqueeze(3)
|
90 |
+
.permute(1, 0, 2, 3)
|
91 |
+
)
|
92 |
+
std = (
|
93 |
+
flatten.std(1)
|
94 |
+
.unsqueeze(1)
|
95 |
+
.unsqueeze(2)
|
96 |
+
.unsqueeze(3)
|
97 |
+
.permute(1, 0, 2, 3)
|
98 |
+
)
|
99 |
+
|
100 |
+
self.loc.data.copy_(-mean)
|
101 |
+
self.scale.data.copy_(1 / (std + 1e-6))
|
102 |
+
|
103 |
+
def forward(self, input, reverse=False):
|
104 |
+
if reverse:
|
105 |
+
return self.reverse(input)
|
106 |
+
if len(input.shape) == 2:
|
107 |
+
input = input[:,:,None,None]
|
108 |
+
squeeze = True
|
109 |
+
else:
|
110 |
+
squeeze = False
|
111 |
+
|
112 |
+
_, _, height, width = input.shape
|
113 |
+
|
114 |
+
if self.training and self.initialized.item() == 0:
|
115 |
+
self.initialize(input)
|
116 |
+
self.initialized.fill_(1)
|
117 |
+
|
118 |
+
h = self.scale * (input + self.loc)
|
119 |
+
|
120 |
+
if squeeze:
|
121 |
+
h = h.squeeze(-1).squeeze(-1)
|
122 |
+
|
123 |
+
if self.logdet:
|
124 |
+
log_abs = torch.log(torch.abs(self.scale))
|
125 |
+
logdet = height*width*torch.sum(log_abs)
|
126 |
+
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
127 |
+
return h, logdet
|
128 |
+
|
129 |
+
return h
|
130 |
+
|
131 |
+
def reverse(self, output):
|
132 |
+
if self.training and self.initialized.item() == 0:
|
133 |
+
if not self.allow_reverse_init:
|
134 |
+
raise RuntimeError(
|
135 |
+
"Initializing ActNorm in reverse direction is "
|
136 |
+
"disabled by default. Use allow_reverse_init=True to enable."
|
137 |
+
)
|
138 |
+
else:
|
139 |
+
self.initialize(output)
|
140 |
+
self.initialized.fill_(1)
|
141 |
+
|
142 |
+
if len(output.shape) == 2:
|
143 |
+
output = output[:,:,None,None]
|
144 |
+
squeeze = True
|
145 |
+
else:
|
146 |
+
squeeze = False
|
147 |
+
|
148 |
+
h = output / self.scale - self.loc
|
149 |
+
|
150 |
+
if squeeze:
|
151 |
+
h = h.squeeze(-1).squeeze(-1)
|
152 |
+
return h
|
tokenizer/tokenizer_image/discriminator_stylegan.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from:
|
2 |
+
# stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py
|
3 |
+
# stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
|
4 |
+
# maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
try:
|
9 |
+
from kornia.filters import filter2d
|
10 |
+
except:
|
11 |
+
pass
|
12 |
+
|
13 |
+
class Discriminator(nn.Module):
|
14 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
|
15 |
+
super().__init__()
|
16 |
+
channels = {
|
17 |
+
4: 512,
|
18 |
+
8: 512,
|
19 |
+
16: 512,
|
20 |
+
32: 512,
|
21 |
+
64: 256 * channel_multiplier,
|
22 |
+
128: 128 * channel_multiplier,
|
23 |
+
256: 64 * channel_multiplier,
|
24 |
+
512: 32 * channel_multiplier,
|
25 |
+
1024: 16 * channel_multiplier,
|
26 |
+
}
|
27 |
+
|
28 |
+
log_size = int(math.log(image_size, 2))
|
29 |
+
in_channel = channels[image_size]
|
30 |
+
|
31 |
+
blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()]
|
32 |
+
for i in range(log_size, 2, -1):
|
33 |
+
out_channel = channels[2 ** (i - 1)]
|
34 |
+
blocks.append(DiscriminatorBlock(in_channel, out_channel))
|
35 |
+
in_channel = out_channel
|
36 |
+
self.blocks = nn.ModuleList(blocks)
|
37 |
+
|
38 |
+
self.final_conv = nn.Sequential(
|
39 |
+
nn.Conv2d(in_channel, channels[4], 3, padding=1),
|
40 |
+
leaky_relu(),
|
41 |
+
)
|
42 |
+
self.final_linear = nn.Sequential(
|
43 |
+
nn.Linear(channels[4] * 4 * 4, channels[4]),
|
44 |
+
leaky_relu(),
|
45 |
+
nn.Linear(channels[4], 1)
|
46 |
+
)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
for block in self.blocks:
|
50 |
+
x = block(x)
|
51 |
+
x = self.final_conv(x)
|
52 |
+
x = x.view(x.shape[0], -1)
|
53 |
+
x = self.final_linear(x)
|
54 |
+
return x
|
55 |
+
|
56 |
+
|
57 |
+
class DiscriminatorBlock(nn.Module):
|
58 |
+
def __init__(self, input_channels, filters, downsample=True):
|
59 |
+
super().__init__()
|
60 |
+
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))
|
61 |
+
|
62 |
+
self.net = nn.Sequential(
|
63 |
+
nn.Conv2d(input_channels, filters, 3, padding=1),
|
64 |
+
leaky_relu(),
|
65 |
+
nn.Conv2d(filters, filters, 3, padding=1),
|
66 |
+
leaky_relu()
|
67 |
+
)
|
68 |
+
|
69 |
+
self.downsample = nn.Sequential(
|
70 |
+
Blur(),
|
71 |
+
nn.Conv2d(filters, filters, 3, padding = 1, stride = 2)
|
72 |
+
) if downsample else None
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
res = self.conv_res(x)
|
76 |
+
x = self.net(x)
|
77 |
+
if exists(self.downsample):
|
78 |
+
x = self.downsample(x)
|
79 |
+
x = (x + res) * (1 / math.sqrt(2))
|
80 |
+
return x
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
class Blur(nn.Module):
|
85 |
+
def __init__(self):
|
86 |
+
super().__init__()
|
87 |
+
f = torch.Tensor([1, 2, 1])
|
88 |
+
self.register_buffer('f', f)
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
f = self.f
|
92 |
+
f = f[None, None, :] * f [None, :, None]
|
93 |
+
return filter2d(x, f, normalized=True)
|
94 |
+
|
95 |
+
|
96 |
+
def leaky_relu(p=0.2):
|
97 |
+
return nn.LeakyReLU(p, inplace=True)
|
98 |
+
|
99 |
+
|
100 |
+
def exists(val):
|
101 |
+
return val is not None
|
tokenizer/tokenizer_image/lpips.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
2 |
+
|
3 |
+
import os, hashlib
|
4 |
+
import requests
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torchvision import models
|
10 |
+
from collections import namedtuple
|
11 |
+
|
12 |
+
URL_MAP = {
|
13 |
+
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
|
14 |
+
}
|
15 |
+
|
16 |
+
CKPT_MAP = {
|
17 |
+
"vgg_lpips": "vgg.pth"
|
18 |
+
}
|
19 |
+
|
20 |
+
MD5_MAP = {
|
21 |
+
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
|
22 |
+
}
|
23 |
+
|
24 |
+
def download(url, local_path, chunk_size=1024):
|
25 |
+
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
|
26 |
+
with requests.get(url, stream=True) as r:
|
27 |
+
total_size = int(r.headers.get("content-length", 0))
|
28 |
+
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
|
29 |
+
with open(local_path, "wb") as f:
|
30 |
+
for data in r.iter_content(chunk_size=chunk_size):
|
31 |
+
if data:
|
32 |
+
f.write(data)
|
33 |
+
pbar.update(chunk_size)
|
34 |
+
|
35 |
+
|
36 |
+
def md5_hash(path):
|
37 |
+
with open(path, "rb") as f:
|
38 |
+
content = f.read()
|
39 |
+
return hashlib.md5(content).hexdigest()
|
40 |
+
|
41 |
+
|
42 |
+
def get_ckpt_path(name, root, check=False):
|
43 |
+
assert name in URL_MAP
|
44 |
+
path = os.path.join(root, CKPT_MAP[name])
|
45 |
+
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
|
46 |
+
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
|
47 |
+
download(URL_MAP[name], path)
|
48 |
+
md5 = md5_hash(path)
|
49 |
+
assert md5 == MD5_MAP[name], md5
|
50 |
+
return path
|
51 |
+
|
52 |
+
|
53 |
+
class LPIPS(nn.Module):
|
54 |
+
# Learned perceptual metric
|
55 |
+
def __init__(self, use_dropout=True):
|
56 |
+
super().__init__()
|
57 |
+
self.scaling_layer = ScalingLayer()
|
58 |
+
self.chns = [64, 128, 256, 512, 512] # vg16 features
|
59 |
+
self.net = vgg16(pretrained=True, requires_grad=False)
|
60 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
61 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
62 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
63 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
64 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
65 |
+
self.load_from_pretrained()
|
66 |
+
for param in self.parameters():
|
67 |
+
param.requires_grad = False
|
68 |
+
|
69 |
+
def load_from_pretrained(self, name="vgg_lpips"):
|
70 |
+
ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
|
71 |
+
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
72 |
+
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
73 |
+
|
74 |
+
@classmethod
|
75 |
+
def from_pretrained(cls, name="vgg_lpips"):
|
76 |
+
if name != "vgg_lpips":
|
77 |
+
raise NotImplementedError
|
78 |
+
model = cls()
|
79 |
+
ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
|
80 |
+
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
81 |
+
return model
|
82 |
+
|
83 |
+
def forward(self, input, target):
|
84 |
+
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
85 |
+
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
86 |
+
feats0, feats1, diffs = {}, {}, {}
|
87 |
+
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
88 |
+
for kk in range(len(self.chns)):
|
89 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
90 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
91 |
+
|
92 |
+
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
|
93 |
+
val = res[0]
|
94 |
+
for l in range(1, len(self.chns)):
|
95 |
+
val += res[l]
|
96 |
+
return val
|
97 |
+
|
98 |
+
|
99 |
+
class ScalingLayer(nn.Module):
|
100 |
+
def __init__(self):
|
101 |
+
super(ScalingLayer, self).__init__()
|
102 |
+
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
103 |
+
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
104 |
+
|
105 |
+
def forward(self, inp):
|
106 |
+
return (inp - self.shift) / self.scale
|
107 |
+
|
108 |
+
|
109 |
+
class NetLinLayer(nn.Module):
|
110 |
+
""" A single linear layer which does a 1x1 conv """
|
111 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
112 |
+
super(NetLinLayer, self).__init__()
|
113 |
+
layers = [nn.Dropout(), ] if (use_dropout) else []
|
114 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
|
115 |
+
self.model = nn.Sequential(*layers)
|
116 |
+
|
117 |
+
|
118 |
+
class vgg16(torch.nn.Module):
|
119 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
120 |
+
super(vgg16, self).__init__()
|
121 |
+
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
122 |
+
self.slice1 = torch.nn.Sequential()
|
123 |
+
self.slice2 = torch.nn.Sequential()
|
124 |
+
self.slice3 = torch.nn.Sequential()
|
125 |
+
self.slice4 = torch.nn.Sequential()
|
126 |
+
self.slice5 = torch.nn.Sequential()
|
127 |
+
self.N_slices = 5
|
128 |
+
for x in range(4):
|
129 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
130 |
+
for x in range(4, 9):
|
131 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
132 |
+
for x in range(9, 16):
|
133 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
134 |
+
for x in range(16, 23):
|
135 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
136 |
+
for x in range(23, 30):
|
137 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
138 |
+
if not requires_grad:
|
139 |
+
for param in self.parameters():
|
140 |
+
param.requires_grad = False
|
141 |
+
|
142 |
+
def forward(self, X):
|
143 |
+
h = self.slice1(X)
|
144 |
+
h_relu1_2 = h
|
145 |
+
h = self.slice2(h)
|
146 |
+
h_relu2_2 = h
|
147 |
+
h = self.slice3(h)
|
148 |
+
h_relu3_3 = h
|
149 |
+
h = self.slice4(h)
|
150 |
+
h_relu4_3 = h
|
151 |
+
h = self.slice5(h)
|
152 |
+
h_relu5_3 = h
|
153 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
154 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
155 |
+
return out
|
156 |
+
|
157 |
+
|
158 |
+
def normalize_tensor(x,eps=1e-10):
|
159 |
+
norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
|
160 |
+
return x/(norm_factor+eps)
|
161 |
+
|
162 |
+
|
163 |
+
def spatial_average(x, keepdim=True):
|
164 |
+
return x.mean([2,3],keepdim=keepdim)
|
tokenizer/tokenizer_image/reconstruction_vq_ddp.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
3 |
+
torch.backends.cudnn.allow_tf32 = True
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.distributed as dist
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from torch.utils.data.distributed import DistributedSampler
|
8 |
+
from torchvision import transforms
|
9 |
+
from tqdm import tqdm
|
10 |
+
import os
|
11 |
+
from PIL import Image
|
12 |
+
import numpy as np
|
13 |
+
import argparse
|
14 |
+
import itertools
|
15 |
+
|
16 |
+
from skimage.metrics import peak_signal_noise_ratio as psnr_loss
|
17 |
+
from skimage.metrics import structural_similarity as ssim_loss
|
18 |
+
from dataset.augmentation import center_crop_arr
|
19 |
+
from dataset.build import build_dataset
|
20 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
def create_npz_from_sample_folder(sample_dir, num=50000):
|
25 |
+
"""
|
26 |
+
Builds a single .npz file from a folder of .png samples.
|
27 |
+
"""
|
28 |
+
samples = []
|
29 |
+
for i in tqdm(range(num), desc="Building .npz file from samples"):
|
30 |
+
sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
|
31 |
+
sample_np = np.asarray(sample_pil).astype(np.uint8)
|
32 |
+
samples.append(sample_np)
|
33 |
+
samples = np.stack(samples)
|
34 |
+
assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
|
35 |
+
npz_path = f"{sample_dir}.npz"
|
36 |
+
np.savez(npz_path, arr_0=samples)
|
37 |
+
print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
|
38 |
+
return npz_path
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
def main(args):
|
43 |
+
# Setup PyTorch:
|
44 |
+
assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
|
45 |
+
torch.set_grad_enabled(False)
|
46 |
+
|
47 |
+
# Setup DDP:
|
48 |
+
dist.init_process_group("nccl")
|
49 |
+
rank = dist.get_rank()
|
50 |
+
device = rank % torch.cuda.device_count()
|
51 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
52 |
+
torch.manual_seed(seed)
|
53 |
+
torch.cuda.set_device(device)
|
54 |
+
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
55 |
+
|
56 |
+
# create and load model
|
57 |
+
vq_model = VQ_models[args.vq_model](
|
58 |
+
codebook_size=args.codebook_size,
|
59 |
+
codebook_embed_dim=args.codebook_embed_dim)
|
60 |
+
vq_model.to(device)
|
61 |
+
vq_model.eval()
|
62 |
+
checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
|
63 |
+
if "ema" in checkpoint: # ema
|
64 |
+
model_weight = checkpoint["ema"]
|
65 |
+
elif "model" in checkpoint: # ddp
|
66 |
+
model_weight = checkpoint["model"]
|
67 |
+
elif "state_dict" in checkpoint:
|
68 |
+
model_weight = checkpoint["state_dict"]
|
69 |
+
else:
|
70 |
+
raise Exception("please check model weight")
|
71 |
+
vq_model.load_state_dict(model_weight)
|
72 |
+
del checkpoint
|
73 |
+
|
74 |
+
# Create folder to save samples:
|
75 |
+
folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}"
|
76 |
+
f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}")
|
77 |
+
sample_folder_dir = f"{args.sample_dir}/{folder_name}"
|
78 |
+
if rank == 0:
|
79 |
+
os.makedirs(sample_folder_dir, exist_ok=True)
|
80 |
+
print(f"Saving .png samples at {sample_folder_dir}")
|
81 |
+
dist.barrier()
|
82 |
+
|
83 |
+
# Setup data:
|
84 |
+
transform = transforms.Compose([
|
85 |
+
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
|
86 |
+
transforms.ToTensor(),
|
87 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
88 |
+
])
|
89 |
+
|
90 |
+
if args.dataset == 'imagenet':
|
91 |
+
dataset = build_dataset(args, transform=transform)
|
92 |
+
num_fid_samples = 50000
|
93 |
+
elif args.dataset == 'coco':
|
94 |
+
dataset = build_dataset(args, transform=transform)
|
95 |
+
num_fid_samples = 5000
|
96 |
+
elif args.dataset == 'imagenet_code':
|
97 |
+
dataset = build_dataset(args)
|
98 |
+
num_fid_samples = 50000
|
99 |
+
else:
|
100 |
+
raise Exception("please check dataset")
|
101 |
+
|
102 |
+
sampler = DistributedSampler(
|
103 |
+
dataset,
|
104 |
+
num_replicas=dist.get_world_size(),
|
105 |
+
rank=rank,
|
106 |
+
shuffle=False,
|
107 |
+
seed=args.global_seed
|
108 |
+
)
|
109 |
+
loader = DataLoader(
|
110 |
+
dataset,
|
111 |
+
batch_size=args.per_proc_batch_size,
|
112 |
+
shuffle=False,
|
113 |
+
sampler=sampler,
|
114 |
+
num_workers=args.num_workers,
|
115 |
+
pin_memory=True,
|
116 |
+
drop_last=False
|
117 |
+
)
|
118 |
+
|
119 |
+
# Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
|
120 |
+
n = args.per_proc_batch_size
|
121 |
+
global_batch_size = n * dist.get_world_size()
|
122 |
+
|
123 |
+
psnr_val_rgb = []
|
124 |
+
ssim_val_rgb = []
|
125 |
+
loader = tqdm(loader) if rank == 0 else loader
|
126 |
+
total = 0
|
127 |
+
# for x, _ in loader:
|
128 |
+
for batch in loader:
|
129 |
+
x = batch['condition_imgs'].repeat(1,3,1,1)
|
130 |
+
# import pdb
|
131 |
+
# pdb.set_trace()
|
132 |
+
if args.image_size_eval != args.image_size:
|
133 |
+
rgb_gts = F.interpolate(x, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
|
134 |
+
else:
|
135 |
+
rgb_gts = x
|
136 |
+
rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1]
|
137 |
+
x = x.to(device, non_blocking=True)
|
138 |
+
with torch.no_grad():
|
139 |
+
latent, _, [_, _, indices] = vq_model.encode(x.float())
|
140 |
+
import pdb;pdb.set_trace()
|
141 |
+
samples = vq_model.decode_code(indices, latent.shape) # output value is between [-1, 1]
|
142 |
+
if args.image_size_eval != args.image_size:
|
143 |
+
samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
|
144 |
+
samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
145 |
+
|
146 |
+
# Save samples to disk as individual .png files
|
147 |
+
for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
|
148 |
+
index = i * dist.get_world_size() + rank + total
|
149 |
+
# Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
|
150 |
+
# metric
|
151 |
+
rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1]
|
152 |
+
psnr = psnr_loss(rgb_restored, rgb_gt)
|
153 |
+
ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1)
|
154 |
+
psnr_val_rgb.append(psnr)
|
155 |
+
ssim_val_rgb.append(ssim)
|
156 |
+
|
157 |
+
total += global_batch_size
|
158 |
+
|
159 |
+
# ------------------------------------
|
160 |
+
# Summary
|
161 |
+
# ------------------------------------
|
162 |
+
# Make sure all processes have finished saving their samples
|
163 |
+
dist.barrier()
|
164 |
+
world_size = dist.get_world_size()
|
165 |
+
gather_psnr_val = [None for _ in range(world_size)]
|
166 |
+
gather_ssim_val = [None for _ in range(world_size)]
|
167 |
+
dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
|
168 |
+
dist.all_gather_object(gather_ssim_val, ssim_val_rgb)
|
169 |
+
|
170 |
+
if rank == 0:
|
171 |
+
gather_psnr_val = list(itertools.chain(*gather_psnr_val))
|
172 |
+
gather_ssim_val = list(itertools.chain(*gather_ssim_val))
|
173 |
+
psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
|
174 |
+
ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
|
175 |
+
print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb))
|
176 |
+
|
177 |
+
result_file = f"{sample_folder_dir}_results.txt"
|
178 |
+
print("writing results to {}".format(result_file))
|
179 |
+
with open(result_file, 'w') as f:
|
180 |
+
print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f)
|
181 |
+
|
182 |
+
create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
|
183 |
+
print("Done.")
|
184 |
+
|
185 |
+
dist.barrier()
|
186 |
+
dist.destroy_process_group()
|
187 |
+
|
188 |
+
|
189 |
+
if __name__ == "__main__":
|
190 |
+
parser = argparse.ArgumentParser()
|
191 |
+
parser.add_argument("--data-path", type=str, default=None)
|
192 |
+
parser.add_argument("--code-path", type=str, required=True)
|
193 |
+
parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco', 'imagenet_code'], default='imagenet')
|
194 |
+
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
|
195 |
+
parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
|
196 |
+
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
|
197 |
+
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
|
198 |
+
parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256)
|
199 |
+
parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
|
200 |
+
parser.add_argument("--sample-dir", type=str, default="reconstructions")
|
201 |
+
parser.add_argument("--per-proc-batch-size", type=int, default=32)
|
202 |
+
parser.add_argument("--global-seed", type=int, default=0)
|
203 |
+
parser.add_argument("--num-workers", type=int, default=4)
|
204 |
+
parser.add_argument("--condition", type=str, choices=['canny', 'hed'], default='canny')
|
205 |
+
parser.add_argument("--get-condition-img", type=bool, default=False)
|
206 |
+
args = parser.parse_args()
|
207 |
+
main(args)
|
tokenizer/tokenizer_image/vq_demo.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
10 |
+
from dataset.augmentation import center_crop_arr
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
# Setup PyTorch:
|
15 |
+
torch.manual_seed(args.seed)
|
16 |
+
torch.set_grad_enabled(False)
|
17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
+
|
19 |
+
# create and load model
|
20 |
+
model = VQ_models[args.vq_model](
|
21 |
+
codebook_size=args.codebook_size,
|
22 |
+
codebook_embed_dim=args.codebook_embed_dim)
|
23 |
+
model.to(device)
|
24 |
+
model.eval()
|
25 |
+
checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
|
26 |
+
if "ema" in checkpoint: # ema
|
27 |
+
model_weight = checkpoint["ema"]
|
28 |
+
elif "model" in checkpoint: # ddp
|
29 |
+
model_weight = checkpoint["model"]
|
30 |
+
elif "state_dict" in checkpoint:
|
31 |
+
model_weight = checkpoint["state_dict"]
|
32 |
+
else:
|
33 |
+
raise Exception("please check model weight")
|
34 |
+
model.load_state_dict(model_weight)
|
35 |
+
del checkpoint
|
36 |
+
|
37 |
+
# output dir
|
38 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
39 |
+
out_path = args.image_path.replace('.jpg', '_{}.jpg'.format(args.suffix))
|
40 |
+
out_path = out_path.replace('.jpeg', '_{}.jpeg'.format(args.suffix))
|
41 |
+
out_path = out_path.replace('.png', '_{}.png'.format(args.suffix))
|
42 |
+
out_filename = out_path.split('/')[-1]
|
43 |
+
out_path = os.path.join(args.output_dir, out_filename)
|
44 |
+
|
45 |
+
# load image
|
46 |
+
pil_image = Image.open(args.image_path).convert("RGB")
|
47 |
+
img = center_crop_arr(pil_image, args.image_size)
|
48 |
+
# # preprocess
|
49 |
+
# size_org = img.size
|
50 |
+
# img = img.resize((input_size, input_size))
|
51 |
+
img = np.array(img) / 255.
|
52 |
+
x = 2.0 * img - 1.0 # x value is between [-1, 1]
|
53 |
+
x = torch.tensor(x)
|
54 |
+
x = x.unsqueeze(dim=0)
|
55 |
+
x = torch.einsum('nhwc->nchw', x)
|
56 |
+
x_input = x.float().to("cuda")
|
57 |
+
|
58 |
+
# inference
|
59 |
+
with torch.no_grad():
|
60 |
+
latent, _, [_, _, indices] = model.encode(x_input)
|
61 |
+
output = model.decode_code(indices, latent.shape) # output value is between [-1, 1]
|
62 |
+
|
63 |
+
# postprocess
|
64 |
+
output = F.interpolate(output, size=[args.image_size, args.image_size], mode='bicubic').permute(0, 2, 3, 1)[0]
|
65 |
+
sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
|
66 |
+
|
67 |
+
# save
|
68 |
+
Image.fromarray(sample).save(out_path)
|
69 |
+
print("Reconstructed image is saved to {}".format(out_path))
|
70 |
+
|
71 |
+
|
72 |
+
if __name__ == "__main__":
|
73 |
+
parser = argparse.ArgumentParser()
|
74 |
+
parser.add_argument("--image-path", type=str, default="assets/example.jpg")
|
75 |
+
parser.add_argument("--output-dir", type=str, default="output_vq_demo")
|
76 |
+
parser.add_argument("--suffix", type=str, default="tokenizer_image")
|
77 |
+
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
|
78 |
+
parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
|
79 |
+
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
|
80 |
+
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
|
81 |
+
parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512, 1024], default=512)
|
82 |
+
parser.add_argument("--seed", type=int, default=0)
|
83 |
+
args = parser.parse_args()
|
84 |
+
main(args)
|