imw34531 committed on
Commit
87e21d1
1 Parent(s): c10cc8d

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. .github/workflows/bot-autolint.yaml +50 -0
  3. .github/workflows/ci.yaml +54 -0
  4. .gitignore +178 -0
  5. .pre-commit-config.yaml +62 -0
  6. CIs/add_license_all.sh +2 -0
  7. Dockerfile +20 -0
  8. LICENSE +117 -0
  9. README.md +231 -12
  10. app/app_sana.py +488 -0
  11. app/app_sana_multithread.py +565 -0
  12. app/safety_check.py +72 -0
  13. app/sana_pipeline.py +324 -0
  14. asset/Sana.jpg +3 -0
  15. asset/docs/metrics_toolkit.md +118 -0
  16. asset/example_data/00000000.txt +1 -0
  17. asset/examples.py +69 -0
  18. asset/model-incremental.jpg +0 -0
  19. asset/model_paths.txt +2 -0
  20. asset/samples.txt +125 -0
  21. asset/samples_mini.txt +10 -0
  22. configs/sana_app_config/Sana_1600M_app.yaml +107 -0
  23. configs/sana_app_config/Sana_600M_app.yaml +105 -0
  24. configs/sana_base.yaml +140 -0
  25. configs/sana_config/1024ms/Sana_1600M_img1024.yaml +109 -0
  26. configs/sana_config/1024ms/Sana_600M_img1024.yaml +105 -0
  27. configs/sana_config/512ms/Sana_1600M_img512.yaml +108 -0
  28. configs/sana_config/512ms/Sana_600M_img512.yaml +107 -0
  29. configs/sana_config/512ms/ci_Sana_600M_img512.yaml +107 -0
  30. configs/sana_config/512ms/sample_dataset.yaml +107 -0
  31. diffusion/__init__.py +9 -0
  32. diffusion/data/__init__.py +2 -0
  33. diffusion/data/builder.py +76 -0
  34. diffusion/data/datasets/__init__.py +3 -0
  35. diffusion/data/datasets/sana_data.py +467 -0
  36. diffusion/data/datasets/sana_data_multi_scale.py +265 -0
  37. diffusion/data/datasets/utils.py +506 -0
  38. diffusion/data/transforms.py +46 -0
  39. diffusion/data/wids/__init__.py +16 -0
  40. diffusion/data/wids/wids.py +1051 -0
  41. diffusion/data/wids/wids_dl.py +174 -0
  42. diffusion/data/wids/wids_lru.py +81 -0
  43. diffusion/data/wids/wids_mmtar.py +168 -0
  44. diffusion/data/wids/wids_specs.py +192 -0
  45. diffusion/data/wids/wids_tar.py +98 -0
  46. diffusion/dpm_solver.py +69 -0
  47. diffusion/flow_euler_sampler.py +74 -0
  48. diffusion/iddpm.py +76 -0
  49. diffusion/lcm_scheduler.py +457 -0
  50. diffusion/model/__init__.py +1 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ asset/Sana.jpg filter=lfs diff=lfs merge=lfs -text
.github/workflows/bot-autolint.yaml ADDED
@@ -0,0 +1,50 @@
+ name: Auto Lint (triggered by "auto lint" label)
+ on:
+   pull_request:
+     types:
+       - opened
+       - edited
+       - closed
+       - reopened
+       - synchronize
+       - labeled
+       - unlabeled
+ # run only one unit test for a branch / tag.
+ concurrency:
+   group: ci-lint-${{ github.ref }}
+   cancel-in-progress: true
+ jobs:
+   lint-by-label:
+     if: contains(github.event.pull_request.labels.*.name, 'lint wanted')
+     runs-on: ubuntu-latest
+     steps:
+       - name: Check out Git repository
+         uses: actions/checkout@v4
+         with:
+           token: ${{ secrets.PAT }}
+           ref: ${{ github.event.pull_request.head.ref }}
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: '3.10'
+       - name: Test pre-commit hooks
+         continue-on-error: true
+         uses: pre-commit/action@v3.0.0 # sync with https://github.com/Efficient-Large-Model/VILA-Internal/blob/main/.github/workflows/pre-commit.yaml
+         with:
+           extra_args: --all-files
+       - name: Check if there are any changes
+         id: verify_diff
+         run: |
+           git diff --quiet . || echo "changed=true" >> $GITHUB_OUTPUT
+       - name: Commit files
+         if: steps.verify_diff.outputs.changed == 'true'
+         run: |
+           git config --local user.email "action@github.com"
+           git config --local user.name "GitHub Action"
+           git add .
+           git commit -m "[CI-Lint] Fix code style issues with pre-commit ${{ github.sha }}" -a
+           git push
+       - name: Remove label(s) after lint
+         uses: actions-ecosystem/action-remove-labels@v1
+         with:
+           labels: lint wanted
.github/workflows/ci.yaml ADDED
@@ -0,0 +1,54 @@
+ name: ci
+ on:
+   pull_request:
+   push:
+     branches: [main, feat/Sana-public, feat/Sana-public-for-NVLab]
+ concurrency:
+   group: ci-${{ github.workflow }}-${{ github.ref }}
+   cancel-in-progress: true
+ # if: ${{ github.repository == 'Efficient-Large-Model/Sana' }}
+ jobs:
+   pre-commit:
+     runs-on: ubuntu-latest
+     steps:
+       - name: Check out Git repository
+         uses: actions/checkout@v4
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: 3.10.10
+       - name: Test pre-commit hooks
+         uses: pre-commit/action@v3.0.1
+   tests-bash:
+     # needs: pre-commit
+     runs-on: self-hosted
+     steps:
+       - name: Check out Git repository
+         uses: actions/checkout@v4
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: 3.10.10
+       - name: Set up the environment
+         run: |
+           bash environment_setup.sh
+       - name: Run tests with Slurm
+         run: |
+           sana-run --pty -m ci -J tests-bash bash tests/bash/entry.sh
+
+ # tests-python:
+ #   needs: pre-commit
+ #   runs-on: self-hosted
+ #   steps:
+ #     - name: Check out Git repository
+ #       uses: actions/checkout@v4
+ #     - name: Set up Python
+ #       uses: actions/setup-python@v5
+ #       with:
+ #         python-version: 3.10.10
+ #     - name: Set up the environment
+ #       run: |
+ #         ./environment_setup.sh
+ #     - name: Run tests with Slurm
+ #       run: |
+ #         sana-run --pty -m ci -J tests-python pytest tests/python
.gitignore ADDED
@@ -0,0 +1,178 @@
+ # Sana related files
+ .idea/
+ *.png
+ *.json
+ tmp*
+ output*
+ output/
+ outputs/
+ wandb/
+ .vscode/
+ private/
+ ldm_ae*
+ data/*
+ *.pth
+ .gradio/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,62 @@
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v5.0.0
+     hooks:
+       - id: trailing-whitespace
+         name: (Common) Remove trailing whitespaces
+       - id: mixed-line-ending
+         name: (Common) Fix mixed line ending
+         args: [--fix=lf]
+       - id: end-of-file-fixer
+         name: (Common) Remove extra EOF newlines
+       - id: check-merge-conflict
+         name: (Common) Check for merge conflicts
+       - id: requirements-txt-fixer
+         name: (Common) Sort "requirements.txt"
+       - id: fix-encoding-pragma
+         name: (Python) Remove encoding pragmas
+         args: [--remove]
+       # - id: debug-statements
+       #   name: (Python) Check for debugger imports
+       - id: check-json
+         name: (JSON) Check syntax
+       - id: check-yaml
+         name: (YAML) Check syntax
+       - id: check-toml
+         name: (TOML) Check syntax
+   # - repo: https://github.com/shellcheck-py/shellcheck-py
+   #   rev: v0.10.0.1
+   #   hooks:
+   #     - id: shellcheck
+   - repo: https://github.com/google/yamlfmt
+     rev: v0.13.0
+     hooks:
+       - id: yamlfmt
+   - repo: https://github.com/executablebooks/mdformat
+     rev: 0.7.16
+     hooks:
+       - id: mdformat
+         name: (Markdown) Format docs with mdformat
+   - repo: https://github.com/asottile/pyupgrade
+     rev: v3.2.2
+     hooks:
+       - id: pyupgrade
+         name: (Python) Update syntax for newer versions
+         args: [--py37-plus]
+   - repo: https://github.com/psf/black
+     rev: 22.10.0
+     hooks:
+       - id: black
+         name: (Python) Format code with black
+   - repo: https://github.com/pycqa/isort
+     rev: 5.12.0
+     hooks:
+       - id: isort
+         name: (Python) Sort imports with isort
+   - repo: https://github.com/pre-commit/mirrors-clang-format
+     rev: v15.0.4
+     hooks:
+       - id: clang-format
+         name: (C/C++/CUDA) Format code with clang-format
+         args: [-style=google, -i]
+         types_or: [c, c++, cuda]
CIs/add_license_all.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ addlicense -s -c 'NVIDIA CORPORATION & AFFILIATES' -ignore "**/*__init__.py" **/*.py
Dockerfile ADDED
@@ -0,0 +1,20 @@
+ FROM nvcr.io/nvidia/pytorch:24.06-py3
+
+ WORKDIR /app
+
+ RUN curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o ~/miniconda.sh \
+     && sh ~/miniconda.sh -b -p /opt/conda \
+     && rm ~/miniconda.sh
+
+ ENV PATH /opt/conda/bin:$PATH
+ COPY pyproject.toml pyproject.toml
+ COPY diffusion diffusion
+ COPY configs configs
+ COPY sana sana
+ COPY app app
+
+ COPY environment_setup.sh environment_setup.sh
+ RUN ./environment_setup.sh sana
+
+ # COPY server.py server.py
+ CMD ["conda", "run", "-n", "sana", "--no-capture-output", "python", "-u", "-W", "ignore", "app/app_sana.py", "--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml", "--model_path=hf://Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth"]
LICENSE ADDED
@@ -0,0 +1,117 @@
1
+ Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+
3
+
4
+ Nvidia Source Code License-NC
5
+
6
+ =======================================================================
7
+
8
+ 1. Definitions
9
+
10
+ “Licensor” means any person or entity that distributes its Work.
11
+
12
+ “Work” means (a) the original work of authorship made available under
13
+ this license, which may include software, documentation, or other
14
+ files, and (b) any additions to or derivative works thereof
15
+ that are made available under this license.
16
+
17
+ “NVIDIA Processors” means any central processing unit (CPU),
18
+ graphics processing unit (GPU), field-programmable gate array (FPGA),
19
+ application-specific integrated circuit (ASIC) or any combination
20
+ thereof designed, made, sold, or provided by NVIDIA or its affiliates.
21
+
22
+ The terms “reproduce,” “reproduction,” “derivative works,” and
23
+ “distribution” have the meaning as provided under U.S. copyright law;
24
+ provided, however, that for the purposes of this license, derivative
25
+ works shall not include works that remain separable from, or merely
26
+ link (or bind by name) to the interfaces of, the Work.
27
+
28
+ Works are “made available” under this license by including in or with
29
+ the Work either (a) a copyright notice referencing the applicability
30
+ of this license to the Work, or (b) a copy of this license.
31
+
32
+ "Safe Model" means ShieldGemma-2B, which is a series of safety
33
+ content moderation models designed to moderate four categories of
34
+ harmful content: sexually explicit material, dangerous content,
35
+ hate speech, and harassment, and which you separately obtain
36
+ from Google at https://huggingface.co/google/shieldgemma-2b.
37
+
38
+
39
+ 2. License Grant
40
+
41
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
42
+ license, each Licensor grants to you a perpetual, worldwide,
43
+ non-exclusive, royalty-free, copyright license to use, reproduce,
44
+ prepare derivative works of, publicly display, publicly perform,
45
+ sublicense and distribute its Work and any resulting derivative
46
+ works in any form.
47
+
48
+ 3. Limitations
49
+
50
+ 3.1 Redistribution. You may reproduce or distribute the Work only if
51
+ (a) you do so under this license, (b) you include a complete copy of
52
+ this license with your distribution, and (c) you retain without
53
+ modification any copyright, patent, trademark, or attribution notices
54
+ that are present in the Work.
55
+
56
+ 3.2 Derivative Works. You may specify that additional or different
57
+ terms apply to the use, reproduction, and distribution of your
58
+ derivative works of the Work (“Your Terms”) only if (a) Your Terms
59
+ provide that the use limitation in Section 3.3 applies to your
60
+ derivative works, and (b) you identify the specific derivative works
61
+ that are subject to Your Terms. Notwithstanding Your Terms, this
62
+ license (including the redistribution requirements in Section 3.1)
63
+ will continue to apply to the Work itself.
64
+
65
+ 3.3 Use Limitation. The Work and any derivative works thereof only may
66
+ be used or intended for use non-commercially and with NVIDIA Processors,
67
+ in accordance with Section 3.4, below. Notwithstanding the foregoing,
68
+ NVIDIA Corporation and its affiliates may use the Work and any
69
+ derivative works commercially. As used herein, “non-commercially”
70
+ means for research or evaluation purposes only.
71
+
72
+ 3.4 You shall filter your input content to the Work and any derivative
73
+ works thereof through the Safe Model to ensure that no content described
74
+ as Not Safe For Work (NSFW) is processed or generated. You shall not use
75
+ the Work to process or generate NSFW content. You are solely responsible
76
+ for any damages and liabilities arising from your failure to adequately
77
+ filter content in accordance with this section. As used herein,
78
+ “Not Safe For Work” or “NSFW” means content, videos or website pages
79
+ that contain potentially disturbing subject matter, including but not
80
+ limited to content that is sexually explicit, dangerous, hate,
81
+ or harassment.
82
+
83
+ 3.5 Patent Claims. If you bring or threaten to bring a patent claim
84
+ against any Licensor (including any claim, cross-claim or counterclaim
85
+ in a lawsuit) to enforce any patents that you allege are infringed by
86
+ any Work, then your rights under this license from such Licensor
87
+ (including the grant in Section 2.1) will terminate immediately.
88
+
89
+ 3.6 Trademarks. This license does not grant any rights to use any
90
+ Licensor’s or its affiliates’ names, logos, or trademarks, except as
91
+ necessary to reproduce the notices described in this license.
92
+
93
+ 3.7 Termination. If you violate any term of this license, then your
94
+ rights under this license (including the grant in Section 2.1) will
95
+ terminate immediately.
96
+
97
+ 4. Disclaimer of Warranty.
98
+
99
+ THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY
100
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
101
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
102
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES
103
+ UNDER THIS LICENSE.
104
+
105
+ 5. Limitation of Liability.
106
+
107
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
108
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
109
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
110
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
111
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
112
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
113
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
114
+ DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE
115
+ POSSIBILITY OF SUCH DAMAGES.
116
+
117
+ =======================================================================
README.md CHANGED
@@ -1,12 +1,231 @@
- ---
- title: Nvlabs Sana
- emoji: 😻
- colorFrom: purple
- colorTo: yellow
- sdk: gradio
- sdk_version: 5.6.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <p align="center" style="border-radius: 10px">
+ <img src="asset/logo.png" width="35%" alt="logo"/>
+ </p>
+
+ # ⚡️Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer
+
+ <div align="center">
+ <a href="https://nvlabs.github.io/Sana/"><img src="https://img.shields.io/static/v1?label=Project&message=Github&color=blue&logo=github-pages"></a> &ensp;
+ <a href="https://hanlab.mit.edu/projects/sana/"><img src="https://img.shields.io/static/v1?label=Page&message=MIT&color=darkred&logo=github-pages"></a> &ensp;
+ <a href="https://arxiv.org/abs/2410.10629"><img src="https://img.shields.io/static/v1?label=Arxiv&message=Sana&color=red&logo=arxiv"></a> &ensp;
+ <a href="https://nv-sana.mit.edu/"><img src="https://img.shields.io/static/v1?label=Demo&message=MIT&color=yellow"></a> &ensp;
+ <a href="https://discord.gg/rde6eaE5Ta"><img src="https://img.shields.io/static/v1?label=Discuss&message=Discord&color=purple&logo=discord"></a> &ensp;
+ </div>
+
+ <p align="center" border-radius="10px">
+ <img src="asset/Sana.jpg" width="90%" alt="teaser_page1"/>
+ </p>
+
+ ## 💡 Introduction
+
+ We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096 × 4096 resolution.
+ Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, and it is deployable on a laptop GPU.
+ Core designs include:
+
+ (1) [**DC-AE**](https://hanlab.mit.edu/projects/dc-ae): unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. \
+ (2) **Linear DiT**: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. \
+ (3) **Decoder-only text encoder**: we replaced T5 with a modern decoder-only small LLM as the text encoder and designed complex human instructions with in-context learning to enhance image-text alignment. \
+ (4) **Efficient training and sampling**: we propose **Flow-DPM-Solver** to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence.
+
+ As a result, Sana-0.6B is very competitive with modern giant diffusion models (e.g., Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024 × 1024 image. Sana enables content creation at low cost.
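To make the DC-AE point concrete, here is a rough back-of-the-envelope token count for a 1024 × 1024 image. The patch sizes used below (patch 2 for a conventional 8× AE pipeline, patch 1 for the 32× DC-AE) are illustrative assumptions, not the exact Sana configuration:

```python
# Rough latent-token arithmetic (illustrative assumptions: 8x AE + patch-2 DiT
# baseline vs. the 32x DC-AE with patch size 1).
def latent_tokens(image_size: int, ae_compression: int, patch_size: int) -> int:
    latent_size = image_size // ae_compression   # spatial size of the latent grid
    return (latent_size // patch_size) ** 2      # number of DiT tokens

print(latent_tokens(1024, 8, 2))   # 4096 tokens for a conventional 8x AE pipeline
print(latent_tokens(1024, 32, 1))  # 1024 tokens with the 32x DC-AE
```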
+
+ <p align="center" border-radius="10px">
+ <img src="asset/model-incremental.jpg" width="90%" alt="teaser_page2"/>
+ </p>
+
+ ## 🔥🔥 News
+
+ - (🔥 New) \[2024/11\] 1.6B [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released.
+ - (🔥 New) \[2024/11\] Training, inference, and metrics code is released.
+ - (🔥 New) \[2024/11\] [`diffusers`](https://github.com/huggingface/diffusers/pull/9982) integration is in progress.
+ - \[2024/10\] [Demo](https://nv-sana.mit.edu/) is released.
+ - \[2024/10\] [DC-AE Code](https://github.com/mit-han-lab/efficientvit/blob/master/applications/dc_ae/README.md) and [weights](https://huggingface.co/collections/mit-han-lab/dc-ae-670085b9400ad7197bb1009b) are released!
+ - \[2024/10\] [Paper](https://arxiv.org/abs/2410.10629) is on arXiv!
+
+ ## Performance
+
+ | Methods (1024x1024) | Throughput (samples/s) | Latency (s) | Params (B) | Speedup | FID 👆 | CLIP 👆 | GenEval 👆 | DPG 👆 |
+ |---|---|---|---|---|---|---|---|---|
+ | FLUX-dev | 0.04 | 23.0 | 12.0 | 1.0× | 10.15 | 27.47 | _0.67_ | _84.0_ |
+ | **Sana-0.6B** | 1.7 | 0.9 | 0.6 | **39.5×** | <u>5.81</u> | 28.36 | 0.64 | 83.6 |
+ | **Sana-1.6B** | 1.0 | 1.2 | 1.6 | **23.3×** | **5.76** | <u>28.67</u> | <u>0.66</u> | **84.8** |
+
+ <details>
+ <summary><h3>Click to show all</h3></summary>
+
+ | Methods | Throughput (samples/s) | Latency (s) | Params (B) | Speedup | FID 👆 | CLIP 👆 | GenEval 👆 | DPG 👆 |
+ |---|---|---|---|---|---|---|---|---|
+ | _**512 × 512 resolution**_ | | | | | | | | |
+ | PixArt-α | 1.5 | 1.2 | 0.6 | 1.0× | 6.14 | 27.55 | 0.48 | 71.6 |
+ | PixArt-Σ | 1.5 | 1.2 | 0.6 | 1.0× | _6.34_ | _27.62_ | <u>0.52</u> | _79.5_ |
+ | **Sana-0.6B** | 6.7 | 0.8 | 0.6 | 5.0× | <u>5.67</u> | <u>27.92</u> | _0.64_ | <u>84.3</u> |
+ | **Sana-1.6B** | 3.8 | 0.6 | 1.6 | 2.5× | **5.16** | **28.19** | **0.66** | **85.5** |
+ | _**1024 × 1024 resolution**_ | | | | | | | | |
+ | LUMINA-Next | 0.12 | 9.1 | 2.0 | 2.8× | 7.58 | 26.84 | 0.46 | 74.6 |
+ | SDXL | 0.15 | 6.5 | 2.6 | 3.5× | 6.63 | _29.03_ | 0.55 | 74.7 |
+ | PlayGroundv2.5 | 0.21 | 5.3 | 2.6 | 4.9× | _6.09_ | **29.13** | 0.56 | 75.5 |
+ | Hunyuan-DiT | 0.05 | 18.2 | 1.5 | 1.2× | 6.54 | 28.19 | 0.63 | 78.9 |
+ | PixArt-Σ | 0.4 | 2.7 | 0.6 | 9.3× | 6.15 | 28.26 | 0.54 | 80.5 |
+ | DALLE3 | - | - | - | - | - | - | _0.67_ | 83.5 |
+ | SD3-medium | 0.28 | 4.4 | 2.0 | 6.5× | 11.92 | 27.83 | 0.62 | <u>84.1</u> |
+ | FLUX-dev | 0.04 | 23.0 | 12.0 | 1.0× | 10.15 | 27.47 | _0.67_ | _84.0_ |
+ | FLUX-schnell | 0.5 | 2.1 | 12.0 | 11.6× | 7.94 | 28.14 | **0.71** | **84.8** |
+ | **Sana-0.6B** | 1.7 | 0.9 | 0.6 | **39.5×** | <u>5.81</u> | 28.36 | 0.64 | 83.6 |
+ | **Sana-1.6B** | 1.0 | 1.2 | 1.6 | **23.3×** | **5.76** | <u>28.67</u> | <u>0.66</u> | **84.8** |
+
+ </details>
+
+ ## Contents
+
+ - [Env](#-1-dependencies-and-installation)
+ - [Demo](#-2-how-to-play-with-sana-inference)
+ - [Training](#-3-how-to-train-sana)
+ - [Testing](#-4-metric-toolkit)
+ - [TODO](#to-do-list)
+ - [Citation](#bibtex)
+
+ # 🔧 1. Dependencies and Installation
+
+ - Python >= 3.10.0 (we recommend [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
+ - [PyTorch >= 2.0.1+cu12.1](https://pytorch.org/)
+
+ ```bash
+ git clone https://github.com/NVlabs/Sana.git
+ cd Sana
+
+ ./environment_setup.sh sana
+ # or you can install each component step by step following environment_setup.sh
+ ```
+
+ # 💻 2. How to Play with Sana (Inference)
+
+ ## 💰Hardware requirement
+
+ - 9GB VRAM is required for the 0.6B model and 12GB VRAM for the 1.6B model. Our upcoming quantized version will require less than 8GB for inference.
+ - All tests are done on A100 GPUs; numbers may differ on other GPU models.
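As a quick sanity check before loading a model, you can compare the free VRAM reported by the current GPU against the requirements above (a small sketch using PyTorch's standard CUDA API):

```python
# Check free GPU memory against the requirements listed above (9GB / 12GB).
import torch

if torch.cuda.is_available():
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    free_gib = free_bytes / 1024**3
    print(f"Free VRAM: {free_gib:.1f} GiB (of {total_bytes / 1024**3:.1f} GiB total)")
    if free_gib < 9:
        print("Warning: below the 9GB needed for the 0.6B model.")
else:
    print("No CUDA device available.")
```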
+
+ ## 🔛 Quick start with [Gradio](https://www.gradio.app/guides/quickstart)
+
+ ```bash
+ # official online demo
+ DEMO_PORT=15432 \
+ python app/app_sana.py \
+     --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
+     --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
+ ```
+
+ ```python
+ import torch
+ from app.sana_pipeline import SanaPipeline
+ from torchvision.utils import save_image
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ generator = torch.Generator(device=device).manual_seed(42)
+
+ sana = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
+ sana.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")
+ prompt = 'a cyberpunk cat with a neon sign that says "Sana"'
+
+ image = sana(
+     prompt=prompt,
+     height=1024,
+     width=1024,
+     guidance_scale=5.0,
+     pag_guidance_scale=2.0,
+     num_inference_steps=18,
+     generator=generator,
+ )
+ save_image(image, 'output/sana.png', nrow=1, normalize=True, value_range=(-1, 1))
+ ```
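The loaded pipeline can be reused for several prompts without reloading the checkpoint. A minimal sketch that simply repeats the call signature shown above (the prompt list is illustrative):

```python
# Generate a few prompts with the already-loaded `sana` pipeline and `generator`
# from the example above.
prompts = [
    'a cyberpunk cat with a neon sign that says "Sana"',
    "portrait photo of a girl, photograph, highly detailed face, depth of field",
]
for i, p in enumerate(prompts):
    image = sana(
        prompt=p,
        height=1024,
        width=1024,
        guidance_scale=5.0,
        pag_guidance_scale=2.0,
        num_inference_steps=18,
        generator=generator,
    )
    save_image(image, f"output/sana_{i}.png", nrow=1, normalize=True, value_range=(-1, 1))
```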
+
+ ## 🔛 Run inference with TXT or JSON files
+
+ ```bash
+ # Run samples in a txt file
+ python scripts/inference.py \
+     --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
+     --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth \
+     --txt_file=asset/samples_mini.txt
+
+ # Run samples in a json file
+ python scripts/inference.py \
+     --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
+     --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth \
+     --json_file=asset/samples_mini.json
+ ```
+
+ where each line of [`asset/samples_mini.txt`](asset/samples_mini.txt) contains a prompt to generate.
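For reference, such a prompt file is plain text with one prompt per line; an illustrative way to create your own (the prompts are taken from the demo examples):

```python
# Write a small prompt list in the samples_mini.txt format (one prompt per line).
prompts = [
    'a cyberpunk cat with a neon sign that says "Sana"',
    "portrait photo of a girl, photograph, highly detailed face, depth of field",
]
with open("my_prompts.txt", "w") as f:
    f.write("\n".join(prompts) + "\n")
```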
+
+ # 🔥 3. How to Train Sana
+
+ ## 💰Hardware requirement
+
+ - 32GB VRAM is required to train both the 0.6B and the 1.6B models.
+
+ We provide a training example here; you can also select your desired config file from the [config files dir](configs/sana_config) based on your data structure.
+
+ To launch Sana training, you will first need to prepare data in the following format:
+
+ ```bash
+ asset/example_data
+ ├── AAA.txt
+ ├── AAA.png
+ ├── BCC.txt
+ ├── BCC.png
+ ├── ......
+ ├── CCC.txt
+ └── CCC.png
+ ```
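A small script can sanity-check this layout before training; this is only a sketch and assumes each `.txt` file holds the caption for the `.png` that shares its stem, as in `asset/example_data`:

```python
# Verify that every .png in the data dir has a matching caption .txt (and vice versa).
from pathlib import Path

data_dir = Path("asset/example_data")
images = {p.stem for p in data_dir.glob("*.png")}
captions = {p.stem for p in data_dir.glob("*.txt")}
print("images without captions:", sorted(images - captions))
print("captions without images:", sorted(captions - images))
```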
+
+ Then Sana's training can be launched via
+
+ ```bash
+ # Example of training Sana 0.6B with 512x512 resolution
+ bash train_scripts/train.sh \
+     configs/sana_config/512ms/Sana_600M_img512.yaml \
+     --data.data_dir="[asset/example_data]" \
+     --data.type=SanaImgDataset \
+     --model.multi_scale=false \
+     --train.train_batch_size=32
+
+ # Example of training Sana 1.6B with 1024x1024 resolution
+ bash train_scripts/train.sh \
+     configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
+     --data.data_dir="[asset/example_data]" \
+     --data.type=SanaImgDataset \
+     --model.multi_scale=false \
+     --train.train_batch_size=8
+ ```
+
+ # 💻 4. Metric toolkit
+
+ Refer to the [Toolkit Manual](asset/docs/metrics_toolkit.md).
+
+ # 💪To-Do List
+
+ We will try our best to release:
+
+ - \[x\] Training code
+ - \[x\] Inference code
+ - \[+\] Model zoo
+ - \[ \] [Diffusers integration](https://github.com/huggingface/diffusers/pull/9982) (in progress)
+ - \[ \] ComfyUI
+ - \[ \] Laptop development
+
+ # 🤗Acknowledgements
+
+ - Thanks to [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha), [PixArt-Σ](https://github.com/PixArt-alpha/PixArt-sigma) and [Efficient-ViT](https://github.com/mit-han-lab/efficientvit) for their wonderful work and codebase!
+
+ # 📖BibTeX
+
+ ```
+ @misc{xie2024sana,
+       title={Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer},
+       author={Enze Xie and Junsong Chen and Junyu Chen and Han Cai and Haotian Tang and Yujun Lin and Zhekai Zhang and Muyang Li and Ligeng Zhu and Yao Lu and Song Han},
+       year={2024},
+       eprint={2410.10629},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV},
+       url={https://arxiv.org/abs/2410.10629},
+ }
+ ```
app/app_sana.py ADDED
@@ -0,0 +1,488 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # SPDX-License-Identifier: Apache-2.0
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import os
21
+ import random
22
+ import time
23
+ import uuid
24
+ from datetime import datetime
25
+
26
+ import gradio as gr
27
+ import numpy as np
28
+ import spaces
29
+ import torch
30
+ from PIL import Image
31
+ from torchvision.utils import make_grid, save_image
32
+ from transformers import AutoModelForCausalLM, AutoTokenizer
33
+
34
+ from app import safety_check
35
+ from app.sana_pipeline import SanaPipeline
36
+
37
+ MAX_SEED = np.iinfo(np.int32).max
38
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
39
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
40
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
41
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
42
+ DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
43
+ os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
44
+
45
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
46
+
47
+ style_list = [
48
+ {
49
+ "name": "(No style)",
50
+ "prompt": "{prompt}",
51
+ "negative_prompt": "",
52
+ },
53
+ {
54
+ "name": "Cinematic",
55
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
56
+ "cinemascope, moody, epic, gorgeous, film grain, grainy",
57
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
58
+ },
59
+ {
60
+ "name": "Photographic",
61
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
62
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
63
+ },
64
+ {
65
+ "name": "Anime",
66
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
67
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
68
+ },
69
+ {
70
+ "name": "Manga",
71
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
72
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
73
+ },
74
+ {
75
+ "name": "Digital Art",
76
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
77
+ "negative_prompt": "photo, photorealistic, realism, ugly",
78
+ },
79
+ {
80
+ "name": "Pixel art",
81
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
82
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
83
+ },
84
+ {
85
+ "name": "Fantasy art",
86
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
87
+ "majestic, magical, fantasy art, cover art, dreamy",
88
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
89
+ "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
90
+ "disfigured, sloppy, duplicate, mutated, black and white",
91
+ },
92
+ {
93
+ "name": "Neonpunk",
94
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
95
+ "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
96
+ "ultra detailed, intricate, professional",
97
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
98
+ },
99
+ {
100
+ "name": "3D Model",
101
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
102
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
103
+ },
104
+ ]
105
+
106
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
107
+ STYLE_NAMES = list(styles.keys())
108
+ DEFAULT_STYLE_NAME = "(No style)"
109
+ SCHEDULE_NAME = ["Flow_DPM_Solver"]
110
+ DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
111
+ NUM_IMAGES_PER_PROMPT = 1
112
+ TEST_TIMES = 0
113
+ INFER_SPEED = 0
114
+ FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
115
+
116
+
117
+ def read_inference_count():
118
+ global TEST_TIMES
119
+ try:
120
+ with open(FILENAME) as f:
121
+ count = int(f.read().strip())
122
+ except FileNotFoundError:
123
+ count = 0
124
+ TEST_TIMES = count
125
+
126
+ return count
127
+
128
+
129
+ def write_inference_count(count):
130
+ with open(FILENAME, "w") as f:
131
+ f.write(str(count))
132
+
133
+
134
+ def run_inference(num_imgs=1):
135
+ TEST_TIMES = read_inference_count()
136
+ TEST_TIMES += int(num_imgs)
137
+ write_inference_count(TEST_TIMES)
138
+
139
+ return (
140
+ f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
141
+ f"16px; color:red; font-weight: bold;'>{TEST_TIMES}</span>"
142
+ )
143
+
144
+
145
+ def update_inference_count():
146
+ count = read_inference_count()
147
+ return (
148
+ f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
149
+ f"16px; color:red; font-weight: bold;'>{count}</span>"
150
+ )
151
+
152
+
153
+ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
154
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
155
+ if not negative:
156
+ negative = ""
157
+ return p.replace("{prompt}", positive), n + negative
158
+
159
+
160
+ def get_args():
161
+ parser = argparse.ArgumentParser()
162
+ parser.add_argument("--config", type=str, help="config")
163
+ parser.add_argument(
164
+ "--model_path",
165
+ nargs="?",
166
+ default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
167
+ type=str,
168
+ help="Path to the model file (positional)",
169
+ )
170
+ parser.add_argument("--output", default="./", type=str)
171
+ parser.add_argument("--bs", default=1, type=int)
172
+ parser.add_argument("--image_size", default=1024, type=int)
173
+ parser.add_argument("--cfg_scale", default=5.0, type=float)
174
+ parser.add_argument("--pag_scale", default=2.0, type=float)
175
+ parser.add_argument("--seed", default=42, type=int)
176
+ parser.add_argument("--step", default=-1, type=int)
177
+ parser.add_argument("--custom_image_size", default=None, type=int)
178
+ parser.add_argument(
179
+ "--shield_model_path",
180
+ type=str,
181
+ help="The path to shield model, we employ ShieldGemma-2B by default.",
182
+ default="google/shieldgemma-2b",
183
+ )
184
+
185
+ return parser.parse_known_args()[0]
186
+
187
+
188
+ args = get_args()
189
+
190
+ if torch.cuda.is_available():
191
+ weight_dtype = torch.float16
192
+ model_path = args.model_path
193
+ pipe = SanaPipeline(args.config)
194
+ pipe.from_pretrained(model_path)
195
+ pipe.register_progress_bar(gr.Progress())
196
+
197
+ # safety checker
198
+ safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
199
+ safety_checker_model = AutoModelForCausalLM.from_pretrained(
200
+ args.shield_model_path,
201
+ device_map="auto",
202
+ torch_dtype=torch.bfloat16,
203
+ ).to(device)
204
+
205
+
206
+ def save_image_sana(img, seed="", save_img=False):
207
+ unique_name = f"{str(uuid.uuid4())}_{seed}.png"
208
+ save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
209
+ os.umask(0o000) # file permission: 666; dir permission: 777
210
+ os.makedirs(save_path, exist_ok=True)
211
+ unique_name = os.path.join(save_path, unique_name)
212
+ if save_img:
213
+ save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
214
+
215
+ return unique_name
216
+
217
+
218
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
219
+ if randomize_seed:
220
+ seed = random.randint(0, MAX_SEED)
221
+ return seed
222
+
223
+
224
+ @torch.no_grad()
225
+ @torch.inference_mode()
226
+ @spaces.GPU(enable_queue=True)
227
+ def generate(
228
+ prompt: str = None,
229
+ negative_prompt: str = "",
230
+ style: str = DEFAULT_STYLE_NAME,
231
+ use_negative_prompt: bool = False,
232
+ num_imgs: int = 1,
233
+ seed: int = 0,
234
+ height: int = 1024,
235
+ width: int = 1024,
236
+ flow_dpms_guidance_scale: float = 5.0,
237
+ flow_dpms_pag_guidance_scale: float = 2.0,
238
+ flow_dpms_inference_steps: int = 20,
239
+ randomize_seed: bool = False,
240
+ ):
241
+ global TEST_TIMES
242
+ global INFER_SPEED
243
+ # seed = 823753551
244
+ seed = int(randomize_seed_fn(seed, randomize_seed))
245
+ generator = torch.Generator(device=device).manual_seed(seed)
246
+ print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
247
+ if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
248
+ prompt = "A red heart."
249
+
250
+ print(prompt)
251
+
252
+ num_inference_steps = flow_dpms_inference_steps
253
+ guidance_scale = flow_dpms_guidance_scale
254
+ pag_guidance_scale = flow_dpms_pag_guidance_scale
255
+
256
+ if not use_negative_prompt:
257
+ negative_prompt = None # type: ignore
258
+ prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
259
+
260
+ pipe.progress_fn(0, desc="Sana Start")
261
+
262
+ time_start = time.time()
263
+ images = pipe(
264
+ prompt=prompt,
265
+ height=height,
266
+ width=width,
267
+ negative_prompt=negative_prompt,
268
+ guidance_scale=guidance_scale,
269
+ pag_guidance_scale=pag_guidance_scale,
270
+ num_inference_steps=num_inference_steps,
271
+ num_images_per_prompt=num_imgs,
272
+ generator=generator,
273
+ )
274
+
275
+ pipe.progress_fn(1.0, desc="Sana End")
276
+ INFER_SPEED = (time.time() - time_start) / num_imgs
277
+
278
+ save_img = False
279
+ if save_img:
280
+ img = [save_image_sana(img, seed, save_img=save_img) for img in images]
281
+ print(img)
282
+ else:
283
+ if num_imgs > 1:
284
+ nrow = 2
285
+ else:
286
+ nrow = 1
287
+ img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
288
+ img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
289
+ img = [Image.fromarray(img.astype(np.uint8))]
290
+
291
+ torch.cuda.empty_cache()
292
+
293
+ return (
294
+ img,
295
+ seed,
296
+ f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
297
+ )
298
+
299
+
300
+ TEST_TIMES = read_inference_count()
301
+ model_size = "1.6" if "D20" in args.model_path else "0.6"
302
+ title = f"""
303
+ <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
304
+ <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
305
+ </div>
306
+ """
307
+ DESCRIPTION = f"""
308
+ <p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
309
+ <p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
310
+ <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github(coming soon)]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
311
+ <p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, running on an A6000 node.</p>
312
+ <p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
313
+ """
314
+ if model_size == "0.6":
315
+ DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
316
+ if not torch.cuda.is_available():
317
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
318
+
319
+ examples = [
320
+ 'a cyberpunk cat with a neon sign that says "Sana"',
321
+ "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
322
+ "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
323
+ "portrait photo of a girl, photograph, highly detailed face, depth of field",
324
+ 'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
325
+ "🐶 Wearing 🕶 flying on the 🌈",
326
+ "👧 with 🌹 in the ❄️",
327
+ "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
328
+ "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
329
+ "Astronaut in a jungle, cold color palette, muted colors, detailed",
330
+ "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
331
+ ]
332
+
333
+ css = """
334
+ .gradio-container{max-width: 640px !important}
335
+ h1{text-align:center}
336
+ """
337
+ with gr.Blocks(css=css) as demo:
338
+ gr.Markdown(title)
339
+ gr.Markdown(DESCRIPTION)
340
+ gr.DuplicateButton(
341
+ value="Duplicate Space for private use",
342
+ elem_id="duplicate-button",
343
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
344
+ )
345
+ info_box = gr.Markdown(
346
+ value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
347
+ )
348
+ demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
349
+ # with gr.Row(equal_height=False):
350
+ with gr.Group():
351
+ with gr.Row():
352
+ prompt = gr.Text(
353
+ label="Prompt",
354
+ show_label=False,
355
+ max_lines=1,
356
+ placeholder="Enter your prompt",
357
+ container=False,
358
+ )
359
+ run_button = gr.Button("Run", scale=0)
360
+ result = gr.Gallery(label="Result", show_label=False, columns=NUM_IMAGES_PER_PROMPT, format="png")
361
+ speed_box = gr.Markdown(
362
+ value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
363
+ )
364
+ with gr.Accordion("Advanced options", open=False):
365
+ with gr.Group():
366
+ with gr.Row(visible=True):
367
+ height = gr.Slider(
368
+ label="Height",
369
+ minimum=256,
370
+ maximum=MAX_IMAGE_SIZE,
371
+ step=32,
372
+ value=1024,
373
+ )
374
+ width = gr.Slider(
375
+ label="Width",
376
+ minimum=256,
377
+ maximum=MAX_IMAGE_SIZE,
378
+ step=32,
379
+ value=1024,
380
+ )
381
+ with gr.Row():
382
+ flow_dpms_inference_steps = gr.Slider(
383
+ label="Sampling steps",
384
+ minimum=5,
385
+ maximum=40,
386
+ step=1,
387
+ value=18,
388
+ )
389
+ flow_dpms_guidance_scale = gr.Slider(
390
+ label="CFG Guidance scale",
391
+ minimum=1,
392
+ maximum=10,
393
+ step=0.1,
394
+ value=5.0,
395
+ )
396
+ flow_dpms_pag_guidance_scale = gr.Slider(
397
+ label="PAG Guidance scale",
398
+ minimum=1,
399
+ maximum=4,
400
+ step=0.5,
401
+ value=2.0,
402
+ )
403
+ with gr.Row():
404
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
405
+ negative_prompt = gr.Text(
406
+ label="Negative prompt",
407
+ max_lines=1,
408
+ placeholder="Enter a negative prompt",
409
+ visible=True,
410
+ )
411
+ style_selection = gr.Radio(
412
+ show_label=True,
413
+ container=True,
414
+ interactive=True,
415
+ choices=STYLE_NAMES,
416
+ value=DEFAULT_STYLE_NAME,
417
+ label="Image Style",
418
+ )
419
+ seed = gr.Slider(
420
+ label="Seed",
421
+ minimum=0,
422
+ maximum=MAX_SEED,
423
+ step=1,
424
+ value=0,
425
+ )
426
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
427
+ with gr.Row(visible=True):
428
+ schedule = gr.Radio(
429
+ show_label=True,
430
+ container=True,
431
+ interactive=True,
432
+ choices=SCHEDULE_NAME,
433
+ value=DEFAULT_SCHEDULE_NAME,
434
+ label="Sampler Schedule",
435
+ visible=True,
436
+ )
437
+ num_imgs = gr.Slider(
438
+ label="Num Images",
439
+ minimum=1,
440
+ maximum=6,
441
+ step=1,
442
+ value=1,
443
+ )
444
+
445
+ run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box)
446
+
447
+ gr.Examples(
448
+ examples=examples,
449
+ inputs=prompt,
450
+ outputs=[result, seed],
451
+ fn=generate,
452
+ cache_examples=CACHE_EXAMPLES,
453
+ )
454
+
455
+ use_negative_prompt.change(
456
+ fn=lambda x: gr.update(visible=x),
457
+ inputs=use_negative_prompt,
458
+ outputs=negative_prompt,
459
+ api_name=False,
460
+ )
461
+
462
+ gr.on(
463
+ triggers=[
464
+ prompt.submit,
465
+ negative_prompt.submit,
466
+ run_button.click,
467
+ ],
468
+ fn=generate,
469
+ inputs=[
470
+ prompt,
471
+ negative_prompt,
472
+ style_selection,
473
+ use_negative_prompt,
474
+ num_imgs,
475
+ seed,
476
+ height,
477
+ width,
478
+ flow_dpms_guidance_scale,
479
+ flow_dpms_pag_guidance_scale,
480
+ flow_dpms_inference_steps,
481
+ randomize_seed,
482
+ ],
483
+ outputs=[result, seed, speed_box],
484
+ api_name="run",
485
+ )
486
+
487
+ if __name__ == "__main__":
488
+ demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True)
app/app_sana_multithread.py ADDED
@@ -0,0 +1,565 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # SPDX-License-Identifier: Apache-2.0
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import os
21
+ import random
22
+ import uuid
23
+ from datetime import datetime
24
+
25
+ import gradio as gr
26
+ import numpy as np
27
+ import spaces
28
+ import torch
29
+ from diffusers import FluxPipeline
30
+ from PIL import Image
31
+ from torchvision.utils import make_grid, save_image
32
+ from transformers import AutoModelForCausalLM, AutoTokenizer
33
+
34
+ from app import safety_check
35
+ from app.sana_pipeline import SanaPipeline
36
+
37
+ MAX_SEED = np.iinfo(np.int32).max
38
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
39
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
40
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
41
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
42
+ DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
43
+ os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
44
+
45
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
46
+
47
+ style_list = [
48
+ {
49
+ "name": "(No style)",
50
+ "prompt": "{prompt}",
51
+ "negative_prompt": "",
52
+ },
53
+ {
54
+ "name": "Cinematic",
55
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
56
+ "cinemascope, moody, epic, gorgeous, film grain, grainy",
57
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
58
+ },
59
+ {
60
+ "name": "Photographic",
61
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
62
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
63
+ },
64
+ {
65
+ "name": "Anime",
66
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
67
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
68
+ },
69
+ {
70
+ "name": "Manga",
71
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
72
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
73
+ },
74
+ {
75
+ "name": "Digital Art",
76
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
77
+ "negative_prompt": "photo, photorealistic, realism, ugly",
78
+ },
79
+ {
80
+ "name": "Pixel art",
81
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
82
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
83
+ },
84
+ {
85
+ "name": "Fantasy art",
86
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
87
+ "majestic, magical, fantasy art, cover art, dreamy",
88
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
89
+ "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
90
+ "disfigured, sloppy, duplicate, mutated, black and white",
91
+ },
92
+ {
93
+ "name": "Neonpunk",
94
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
95
+ "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
96
+ "ultra detailed, intricate, professional",
97
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
98
+ },
99
+ {
100
+ "name": "3D Model",
101
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
102
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
103
+ },
104
+ ]
105
+
106
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
107
+ STYLE_NAMES = list(styles.keys())
108
+ DEFAULT_STYLE_NAME = "(No style)"
109
+ SCHEDULE_NAME = ["Flow_DPM_Solver"]
110
+ DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
111
+ NUM_IMAGES_PER_PROMPT = 1
112
+ TEST_TIMES = 0
113
+ FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
114
+
115
+
116
+ def set_env(seed=0):
117
+ torch.manual_seed(seed)
118
+ torch.set_grad_enabled(False)
119
+
120
+
121
+ def read_inference_count():
122
+ global TEST_TIMES
123
+ try:
124
+ with open(FILENAME) as f:
125
+ count = int(f.read().strip())
126
+ except FileNotFoundError:
127
+ count = 0
128
+ TEST_TIMES = count
129
+
130
+ return count
131
+
132
+
133
+ def write_inference_count(count):
134
+ with open(FILENAME, "w") as f:
135
+ f.write(str(count))
136
+
137
+
138
+ def run_inference(num_imgs=1):
139
+ TEST_TIMES = read_inference_count()
140
+ TEST_TIMES += int(num_imgs)
141
+ write_inference_count(TEST_TIMES)
142
+
143
+ return (
144
+ f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
145
+ f"16px; color:red; font-weight: bold;'>{TEST_TIMES}</span>"
146
+ )
147
+
148
+
149
+ def update_inference_count():
150
+ count = read_inference_count()
151
+ return (
152
+ f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
153
+ f"16px; color:red; font-weight: bold;'>{count}</span>"
154
+ )
155
+
156
+
157
+ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
158
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
159
+ if not negative:
160
+ negative = ""
161
+ return p.replace("{prompt}", positive), n + negative
162
+
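+ # Illustrative note: apply_style() substitutes the user prompt into the selected style template and
+ # concatenates the style's negative prompt with the user's negative prompt (no separator). For example,
+ # assuming the "Cinematic" style defined above:
+ #   apply_style("Cinematic", "a cat", "blurry")
+ #   -> ("cinematic still a cat . emotional, harmonious, ...", "anime, cartoon, ...blurry")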
163
+
164
+ def get_args():
165
+ parser = argparse.ArgumentParser()
166
+ parser.add_argument("--config", type=str, help="config")
167
+ parser.add_argument(
168
+ "--model_path",
169
+ nargs="?",
170
+ default="output/Sana_D20/SANA.pth",
171
+ type=str,
172
+ help="Path to the model checkpoint (.pth) file",
173
+ )
174
+ parser.add_argument("--output", default="./", type=str)
175
+ parser.add_argument("--bs", default=1, type=int)
176
+ parser.add_argument("--image_size", default=1024, type=int)
177
+ parser.add_argument("--cfg_scale", default=5.0, type=float)
178
+ parser.add_argument("--pag_scale", default=2.0, type=float)
179
+ parser.add_argument("--seed", default=42, type=int)
180
+ parser.add_argument("--step", default=-1, type=int)
181
+ parser.add_argument("--custom_image_size", default=None, type=int)
182
+ parser.add_argument(
183
+ "--shield_model_path",
184
+ type=str,
185
+ help="Path to the shield model; ShieldGemma-2B is used by default.",
186
+ default="google/shieldgemma-2b",
187
+ )
188
+
189
+ return parser.parse_args()
190
+
191
+
192
+ args = get_args()
193
+
194
+ if torch.cuda.is_available():
195
+ weight_dtype = torch.float16
196
+ model_path = args.model_path
197
+ pipe = SanaPipeline(args.config)
198
+ pipe.from_pretrained(model_path)
199
+ pipe.register_progress_bar(gr.Progress())
200
+
201
+ repo_name = "black-forest-labs/FLUX.1-dev"
202
+ pipe2 = FluxPipeline.from_pretrained(repo_name, torch_dtype=torch.float16).to("cuda")
203
+
204
+ # safety checker
205
+ safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
206
+ safety_checker_model = AutoModelForCausalLM.from_pretrained(
207
+ args.shield_model_path,
208
+ device_map="auto",
209
+ torch_dtype=torch.bfloat16,
210
+ ).to(device)
211
+
212
+ set_env(42)
213
+
214
+
215
+ def save_image_sana(img, seed="", save_img=False):
216
+ unique_name = f"{str(uuid.uuid4())}_{seed}.png"
217
+ save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
218
+ os.umask(0o000) # file permission: 666; dir permission: 777
219
+ os.makedirs(save_path, exist_ok=True)
220
+ unique_name = os.path.join(save_path, unique_name)
221
+ if save_img:
222
+ save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
223
+
224
+ return unique_name
225
+
226
+
227
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
228
+ if randomize_seed:
229
+ seed = random.randint(0, MAX_SEED)
230
+ return seed
231
+
232
+
233
+ @spaces.GPU(enable_queue=True)
234
+ async def generate_2(
235
+ prompt: str = None,
236
+ negative_prompt: str = "",
237
+ style: str = DEFAULT_STYLE_NAME,
238
+ use_negative_prompt: bool = False,
239
+ num_imgs: int = 1,
240
+ seed: int = 0,
241
+ height: int = 1024,
242
+ width: int = 1024,
243
+ flow_dpms_guidance_scale: float = 5.0,
244
+ flow_dpms_pag_guidance_scale: float = 2.0,
245
+ flow_dpms_inference_steps: int = 20,
246
+ randomize_seed: bool = False,
247
+ ):
248
+ seed = int(randomize_seed_fn(seed, randomize_seed))
249
+ generator = torch.Generator(device=device).manual_seed(seed)
250
+ print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
251
+ if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
252
+ prompt = "A red heart."
253
+
254
+ print(prompt)
255
+
256
+ if not use_negative_prompt:
257
+ negative_prompt = None # type: ignore
258
+ prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
259
+
260
+ with torch.no_grad():
261
+ images = pipe2(
262
+ prompt=prompt,
263
+ height=height,
264
+ width=width,
265
+ guidance_scale=3.5,
266
+ num_inference_steps=50,
267
+ num_images_per_prompt=num_imgs,
268
+ max_sequence_length=256,
269
+ generator=generator,
270
+ ).images
271
+
272
+ save_img = False
273
+ img = images
274
+ if save_img:
275
+ img = [save_image_sana(img, seed, save_img=save_img) for img in images]
276
+ print(img)
277
+ torch.cuda.empty_cache()
278
+
279
+ return img
280
+
281
+
282
+ @spaces.GPU(enable_queue=True)
283
+ async def generate(
284
+ prompt: str = None,
285
+ negative_prompt: str = "",
286
+ style: str = DEFAULT_STYLE_NAME,
287
+ use_negative_prompt: bool = False,
288
+ num_imgs: int = 1,
289
+ seed: int = 0,
290
+ height: int = 1024,
291
+ width: int = 1024,
292
+ flow_dpms_guidance_scale: float = 5.0,
293
+ flow_dpms_pag_guidance_scale: float = 2.0,
294
+ flow_dpms_inference_steps: int = 20,
295
+ randomize_seed: bool = False,
296
+ ):
297
+ global TEST_TIMES
298
+ # seed = 823753551
299
+ seed = int(randomize_seed_fn(seed, randomize_seed))
300
+ generator = torch.Generator(device=device).manual_seed(seed)
301
+ print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
302
+ if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
303
+ prompt = "A red heart."
304
+
305
+ print(prompt)
306
+
307
+ num_inference_steps = flow_dpms_inference_steps
308
+ guidance_scale = flow_dpms_guidance_scale
309
+ pag_guidance_scale = flow_dpms_pag_guidance_scale
310
+
311
+ if not use_negative_prompt:
312
+ negative_prompt = None # type: ignore
313
+ prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
314
+
315
+ pipe.progress_fn(0, desc="Sana Start")
316
+
317
+ with torch.no_grad():
318
+ images = pipe(
319
+ prompt=prompt,
320
+ height=height,
321
+ width=width,
322
+ negative_prompt=negative_prompt,
323
+ guidance_scale=guidance_scale,
324
+ pag_guidance_scale=pag_guidance_scale,
325
+ num_inference_steps=num_inference_steps,
326
+ num_images_per_prompt=num_imgs,
327
+ generator=generator,
328
+ )
329
+
330
+ pipe.progress_fn(1.0, desc="Sana End")
331
+
332
+ save_img = False
333
+ if save_img:
334
+ img = [save_image_sana(img, seed, save_img=save_img) for img in images]
335
+ print(img)
336
+ else:
337
+ if num_imgs > 1:
338
+ nrow = 2
339
+ else:
340
+ nrow = 1
341
+ img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
342
+ img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()  # scale the [0, 1] grid from make_grid to an HWC uint8 array
343
+ img = [Image.fromarray(img.astype(np.uint8))]
344
+
345
+ torch.cuda.empty_cache()
346
+
347
+ return img
348
+
349
+
350
+ TEST_TIMES = read_inference_count()
351
+ model_size = "1.6" if "D20" in args.model_path else "0.6"
352
+ title = f"""
353
+ <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
354
+ <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
355
+ </div>
356
+ """
357
+ DESCRIPTION = f"""
358
+ <p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
359
+ <p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
360
+ <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[GitHub (coming soon)]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
361
+ <p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space</p>
362
+ <p style="font-size: 16px; font-weight: bold;">Prompts flagged as unsafe will return a 'Red Heart' image instead.</p>
363
+ """
364
+ if model_size == "0.6":
365
+ DESCRIPTION += "\n<p>The 0.6B model's text-rendering ability is limited.</p>"
366
+ if not torch.cuda.is_available():
367
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
368
+
369
+ examples = [
370
+ 'a cyberpunk cat with a neon sign that says "Sana"',
371
+ "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
372
+ "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
373
+ "portrait photo of a girl, photograph, highly detailed face, depth of field",
374
+ 'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
375
+ "🐶 Wearing 🕶 flying on the 🌈",
376
+ # "👧 with 🌹 in the ❄️",
377
+ # "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
378
+ # "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
379
+ # "Astronaut in a jungle, cold color palette, muted colors, detailed",
380
+ # "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
381
+ ]
382
+
383
+ css = """
384
+ .gradio-container{max-width: 1024px !important}
385
+ h1{text-align:center}
386
+ """
387
+ with gr.Blocks(css=css) as demo:
388
+ gr.Markdown(title)
389
+ gr.Markdown(DESCRIPTION)
390
+ gr.DuplicateButton(
391
+ value="Duplicate Space for private use",
392
+ elem_id="duplicate-button",
393
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
394
+ )
395
+ info_box = gr.Markdown(
396
+ value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
397
+ )
398
+ demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
399
+ # with gr.Row(equal_height=False):
400
+ with gr.Group():
401
+ with gr.Row():
402
+ prompt = gr.Text(
403
+ label="Prompt",
404
+ show_label=False,
405
+ max_lines=1,
406
+ placeholder="Enter your prompt",
407
+ container=False,
408
+ )
409
+ run_button = gr.Button("Run-sana", scale=0)
410
+ run_button2 = gr.Button("Run-flux", scale=0)
411
+
412
+ with gr.Row():
413
+ result = gr.Gallery(label="Result from Sana", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp")
414
+ result_2 = gr.Gallery(
415
+ label="Result from FLUX", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp"
416
+ )
417
+
418
+ with gr.Accordion("Advanced options", open=False):
419
+ with gr.Group():
420
+ with gr.Row(visible=True):
421
+ height = gr.Slider(
422
+ label="Height",
423
+ minimum=256,
424
+ maximum=MAX_IMAGE_SIZE,
425
+ step=32,
426
+ value=1024,
427
+ )
428
+ width = gr.Slider(
429
+ label="Width",
430
+ minimum=256,
431
+ maximum=MAX_IMAGE_SIZE,
432
+ step=32,
433
+ value=1024,
434
+ )
435
+ with gr.Row():
436
+ flow_dpms_inference_steps = gr.Slider(
437
+ label="Sampling steps",
438
+ minimum=5,
439
+ maximum=40,
440
+ step=1,
441
+ value=18,
442
+ )
443
+ flow_dpms_guidance_scale = gr.Slider(
444
+ label="CFG Guidance scale",
445
+ minimum=1,
446
+ maximum=10,
447
+ step=0.1,
448
+ value=5.0,
449
+ )
450
+ flow_dpms_pag_guidance_scale = gr.Slider(
451
+ label="PAG Guidance scale",
452
+ minimum=1,
453
+ maximum=4,
454
+ step=0.5,
455
+ value=2.0,
456
+ )
457
+ with gr.Row():
458
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
459
+ negative_prompt = gr.Text(
460
+ label="Negative prompt",
461
+ max_lines=1,
462
+ placeholder="Enter a negative prompt",
463
+ visible=True,
464
+ )
465
+ style_selection = gr.Radio(
466
+ show_label=True,
467
+ container=True,
468
+ interactive=True,
469
+ choices=STYLE_NAMES,
470
+ value=DEFAULT_STYLE_NAME,
471
+ label="Image Style",
472
+ )
473
+ seed = gr.Slider(
474
+ label="Seed",
475
+ minimum=0,
476
+ maximum=MAX_SEED,
477
+ step=1,
478
+ value=0,
479
+ )
480
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
481
+ with gr.Row(visible=True):
482
+ schedule = gr.Radio(
483
+ show_label=True,
484
+ container=True,
485
+ interactive=True,
486
+ choices=SCHEDULE_NAME,
487
+ value=DEFAULT_SCHEDULE_NAME,
488
+ label="Sampler Schedule",
489
+ visible=True,
490
+ )
491
+ num_imgs = gr.Slider(
492
+ label="Num Images",
493
+ minimum=1,
494
+ maximum=6,
495
+ step=1,
496
+ value=1,
497
+ )
498
+
499
+ run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box)
500
+
501
+ gr.Examples(
502
+ examples=examples,
503
+ inputs=prompt,
504
+ outputs=[result],
505
+ fn=generate,
506
+ cache_examples=CACHE_EXAMPLES,
507
+ )
508
+ gr.Examples(
509
+ examples=examples,
510
+ inputs=prompt,
511
+ outputs=[result_2],
512
+ fn=generate_2,
513
+ cache_examples=CACHE_EXAMPLES,
514
+ )
515
+
516
+ use_negative_prompt.change(
517
+ fn=lambda x: gr.update(visible=x),
518
+ inputs=use_negative_prompt,
519
+ outputs=negative_prompt,
520
+ api_name=False,
521
+ )
522
+
523
+ run_button.click(
524
+ fn=generate,
525
+ inputs=[
526
+ prompt,
527
+ negative_prompt,
528
+ style_selection,
529
+ use_negative_prompt,
530
+ num_imgs,
531
+ seed,
532
+ height,
533
+ width,
534
+ flow_dpms_guidance_scale,
535
+ flow_dpms_pag_guidance_scale,
536
+ flow_dpms_inference_steps,
537
+ randomize_seed,
538
+ ],
539
+ outputs=[result],
540
+ queue=True,
541
+ )
542
+
543
+ run_button2.click(
544
+ fn=generate_2,
545
+ inputs=[
546
+ prompt,
547
+ negative_prompt,
548
+ style_selection,
549
+ use_negative_prompt,
550
+ num_imgs,
551
+ seed,
552
+ height,
553
+ width,
554
+ flow_dpms_guidance_scale,
555
+ flow_dpms_pag_guidance_scale,
556
+ flow_dpms_inference_steps,
557
+ randomize_seed,
558
+ ],
559
+ outputs=[result_2],
560
+ queue=True,
561
+ )
562
+
563
+
564
+ if __name__ == "__main__":
565
+ demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True)
app/safety_check.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is copied from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py
17
+ import json
18
+ from functools import lru_cache
19
+ from typing import Dict, Sequence
20
+
21
+ import torch
22
+ import transformers
23
+ from torch.nn.functional import softmax
24
+ from transformers import AutoModelForCausalLM, AutoTokenizer
25
+
26
+ __all__ = ["is_dangerous"]
27
+
28
+ __rank, __local_rank, __world_size, __device = (
29
+ 0,
30
+ 0,
31
+ 1,
32
+ "cuda" if torch.cuda.is_available() else "cpu",
33
+ )
34
+
35
+
36
+ safety_policy = """
37
+ The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
38
+ """
39
+
40
+
41
+ def is_dangerous(tokenizer, model, user_prompt, threshold=0.5):
42
+ prompt = f"""You are a policy expert trying to help determine whether a user
43
+ prompt is in violation of the defined safety policies.
44
+
45
+ <start_of_turn>
46
+ Human Question: {user_prompt.strip()}
47
+ <end_of_turn>
48
+
49
+ Our safety principle is defined in the below:
50
+
51
+ {safety_policy.strip()}
52
+
53
+ Does the human question violate the above principle? Your answer must start
54
+ with 'Yes' or 'No'. And then walk through step by step to be sure we answer
55
+ correctly.
56
+ """
57
+
58
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)  # follow the shield model's device instead of assuming CUDA
59
+ with torch.no_grad():
60
+ logits = model(**inputs).logits
61
+
62
+ # Extract the logits for the Yes and No tokens
63
+ vocab = tokenizer.get_vocab()
64
+ selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
65
+
66
+ # Convert these logits to a probability with softmax
67
+ probabilities = softmax(selected_logits, dim=0)
68
+
69
+ # Return probability of 'Yes'
70
+ score = probabilities[0].item()
71
+
72
+ return score > threshold
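+
+ # Illustrative usage sketch, mirroring how the Gradio demos wire this up
+ # (model/tokenizer names follow the demos' default ShieldGemma-2B checkpoint):
+ #   tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")
+ #   model = AutoModelForCausalLM.from_pretrained("google/shieldgemma-2b", torch_dtype=torch.bfloat16, device_map="auto")
+ #   if is_dangerous(tokenizer, model, user_prompt, threshold=0.5):
+ #       user_prompt = "A red heart."  # the demos fall back to a safe prompt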
app/sana_pipeline.py ADDED
@@ -0,0 +1,324 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ import argparse
17
+ import warnings
18
+ from dataclasses import dataclass, field
19
+ from typing import Optional, Tuple
20
+
21
+ import pyrallis
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ warnings.filterwarnings("ignore") # ignore warning
26
+
27
+
28
+ from diffusion import DPMS, FlowEuler
29
+ from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
30
+ from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
31
+ from diffusion.model.utils import prepare_prompt_ar, resize_and_crop_tensor
32
+ from diffusion.utils.config import SanaConfig
33
+ from diffusion.utils.logger import get_root_logger
34
+
35
+ # from diffusion.utils.misc import read_config
36
+ from tools.download import find_model
37
+
38
+
39
+ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
40
+ guidance_type = default_guidance_type
41
+ if not (pag_scale > 1.0 and attn_type == "linear"):
42
+ guidance_type = "classifier-free"
43
+ return guidance_type
44
+
45
+
46
+ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
47
+ """Returns binned height and width."""
48
+ ar = float(height / width)
49
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
50
+ default_hw = ratios[closest_ratio]
51
+ return int(default_hw[0]), int(default_hw[1])
52
+
53
+
54
+ @dataclass
55
+ class SanaInference(SanaConfig):
56
+ config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml" # config
57
+ model_path: str = field(
58
+ default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model checkpoint (.pth) file"}
59
+ )
60
+ output: str = "./output"
61
+ bs: int = 1
62
+ image_size: int = 1024
63
+ cfg_scale: float = 5.0
64
+ pag_scale: float = 2.0
65
+ seed: int = 42
66
+ step: int = -1
67
+ custom_image_size: Optional[int] = None
68
+ shield_model_path: str = field(
69
+ default="google/shieldgemma-2b",
70
+ metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
71
+ )
72
+
73
+
74
+ class SanaPipeline(nn.Module):
75
+ def __init__(
76
+ self,
77
+ config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
78
+ ):
79
+ super().__init__()
80
+ config = pyrallis.load(SanaInference, open(config))
81
+ self.args = self.config = config
82
+
83
+ # set some hyper-parameters
84
+ self.image_size = self.config.model.image_size
85
+
86
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
87
+ logger = get_root_logger()
88
+ self.logger = logger
89
+ self.progress_fn = lambda progress, desc: None
90
+
91
+ self.latent_size = self.image_size // config.vae.vae_downsample_rate
92
+ self.max_sequence_length = config.text_encoder.model_max_length
93
+ self.flow_shift = config.scheduler.flow_shift
94
+ guidance_type = "classifier-free_PAG"
95
+
96
+ if config.model.mixed_precision == "fp16":
97
+ weight_dtype = torch.float16
98
+ elif config.model.mixed_precision == "bf16":
99
+ weight_dtype = torch.bfloat16
100
+ elif config.model.mixed_precision == "fp32":
101
+ weight_dtype = torch.float32
102
+ else:
103
+ raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")
104
+ self.weight_dtype = weight_dtype
105
+
106
+ self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
107
+ self.vis_sampler = self.config.scheduler.vis_sampler
108
+ logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
109
+ self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
110
+ logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")
111
+
112
+ # 1. build vae and text encoder
113
+ self.vae = self.build_vae(config.vae)
114
+ self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)
115
+
116
+ # 2. build Sana model
117
+ self.model = self.build_sana_model(config).to(self.device)
118
+
119
+ # 3. pre-compute null embedding
120
+ with torch.no_grad():
121
+ null_caption_token = self.tokenizer(
122
+ "", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
123
+ ).to(self.device)
124
+ self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
125
+ 0
126
+ ]
127
+
128
+ def build_vae(self, config):
129
+ vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.weight_dtype)
130
+ return vae
131
+
132
+ def build_text_encoder(self, config):
133
+ tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
134
+ return tokenizer, text_encoder
135
+
136
+ def build_sana_model(self, config):
137
+ # model setting
138
+ pred_sigma = getattr(config.scheduler, "pred_sigma", True)
139
+ learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
140
+ model_kwargs = {
141
+ "input_size": self.latent_size,
142
+ "pe_interpolation": config.model.pe_interpolation,
143
+ "config": config,
144
+ "model_max_length": config.text_encoder.model_max_length,
145
+ "qk_norm": config.model.qk_norm,
146
+ "micro_condition": config.model.micro_condition,
147
+ "caption_channels": self.text_encoder.config.hidden_size,
148
+ "y_norm": config.text_encoder.y_norm,
149
+ "attn_type": config.model.attn_type,
150
+ "ffn_type": config.model.ffn_type,
151
+ "mlp_ratio": config.model.mlp_ratio,
152
+ "mlp_acts": list(config.model.mlp_acts),
153
+ "in_channels": config.vae.vae_latent_dim,
154
+ "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
155
+ "use_pe": config.model.use_pe,
156
+ "pred_sigma": pred_sigma,
157
+ "learn_sigma": learn_sigma,
158
+ "use_fp32_attention": config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
159
+ }
160
+ model = build_model(config.model.model, **model_kwargs)
161
+ model = model.to(self.weight_dtype)
162
+
163
+ self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
164
+ self.logger.info(
165
+ f"{model.__class__.__name__}:{config.model.model},"
166
+ f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
167
+ )
168
+ return model
169
+
170
+ def from_pretrained(self, model_path):
171
+ state_dict = find_model(model_path)
172
+ state_dict = state_dict.get("state_dict", state_dict)
173
+ if "pos_embed" in state_dict:
174
+ del state_dict["pos_embed"]
175
+ missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
176
+ self.model.eval().to(self.weight_dtype)
177
+
178
+ self.logger.info("Generating sample from ckpt: %s" % model_path)
179
+ self.logger.warning(f"Missing keys: {missing}")
180
+ self.logger.warning(f"Unexpected keys: {unexpected}")
181
+
182
+ def register_progress_bar(self, progress_fn=None):
183
+ self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn
184
+
185
+ @torch.inference_mode()
186
+ def forward(
187
+ self,
188
+ prompt=None,
189
+ height=1024,
190
+ width=1024,
191
+ negative_prompt="",
192
+ num_inference_steps=20,
193
+ guidance_scale=5,
194
+ pag_guidance_scale=2.5,
195
+ num_images_per_prompt=1,
196
+ generator=torch.Generator().manual_seed(42),
197
+ latents=None,
198
+ ):
199
+ self.ori_height, self.ori_width = height, width
200
+ self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
201
+ self.latent_size_h, self.latent_size_w = (
202
+ self.height // self.config.vae.vae_downsample_rate,
203
+ self.width // self.config.vae.vae_downsample_rate,
204
+ )
205
+ self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
206
+
207
+ # 1. pre-compute negative embedding
208
+ if negative_prompt != "":
209
+ null_caption_token = self.tokenizer(
210
+ negative_prompt,
211
+ max_length=self.max_sequence_length,
212
+ padding="max_length",
213
+ truncation=True,
214
+ return_tensors="pt",
215
+ ).to(self.device)
216
+ self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
217
+ 0
218
+ ]
219
+
220
+ if prompt is None:
221
+ prompt = [""]
222
+ prompts = prompt if isinstance(prompt, list) else [prompt]
223
+ samples = []
224
+
225
+ for prompt in prompts:
226
+ # data prepare
227
+ prompts, hw, ar = (
228
+ [],
229
+ torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
230
+ num_images_per_prompt, 1
231
+ ),
232
+ torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
233
+ )
234
+ for _ in range(num_images_per_prompt):
235
+ with torch.no_grad():
236
+ prompts.append(
237
+ prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
238
+ )
239
+
240
+ # prepare text feature
241
+ if not self.config.text_encoder.chi_prompt:
242
+ max_length_all = self.config.text_encoder.model_max_length
243
+ prompts_all = prompts
244
+ else:
245
+ chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
246
+ prompts_all = [chi_prompt + prompt for prompt in prompts]
247
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
248
+ max_length_all = (
249
+ num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
250
+ ) # magic number 2: [bos], [_]
251
+
252
+ caption_token = self.tokenizer(
253
+ prompts_all,
254
+ max_length=max_length_all,
255
+ padding="max_length",
256
+ truncation=True,
257
+ return_tensors="pt",
258
+ ).to(device=self.device)
259
+ select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
260
+ caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
261
+ :, :, select_index
262
+ ].to(self.weight_dtype)
263
+ emb_masks = caption_token.attention_mask[:, select_index]
264
+ null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
265
+
266
+ n = len(prompts)
267
+ if latents is None:
268
+ z = torch.randn(
269
+ n,
270
+ self.config.vae.vae_latent_dim,
271
+ self.latent_size_h,
272
+ self.latent_size_w,
273
+ generator=generator,
274
+ device=self.device,
275
+ dtype=self.weight_dtype,
276
+ )
277
+ else:
278
+ z = latents.to(self.weight_dtype).to(self.device)
279
+ model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
280
+ if self.vis_sampler == "flow_euler":
281
+ flow_solver = FlowEuler(
282
+ self.model,
283
+ condition=caption_embs,
284
+ uncondition=null_y,
285
+ cfg_scale=guidance_scale,
286
+ model_kwargs=model_kwargs,
287
+ )
288
+ sample = flow_solver.sample(
289
+ z,
290
+ steps=num_inference_steps,
291
+ )
292
+ elif self.vis_sampler == "flow_dpm-solver":
293
+ scheduler = DPMS(
294
+ self.model,
295
+ condition=caption_embs,
296
+ uncondition=null_y,
297
+ guidance_type=self.guidance_type,
298
+ cfg_scale=guidance_scale,
299
+ pag_scale=pag_guidance_scale,
300
+ pag_applied_layers=self.config.model.pag_applied_layers,
301
+ model_type="flow",
302
+ model_kwargs=model_kwargs,
303
+ schedule="FLOW",
304
+ )
305
+ scheduler.register_progress_bar(self.progress_fn)
306
+ sample = scheduler.sample(
307
+ z,
308
+ steps=num_inference_steps,
309
+ order=2,
310
+ skip_type="time_uniform_flow",
311
+ method="multistep",
312
+ flow_shift=self.flow_shift,
313
+ )
314
+
315
+ sample = sample.to(self.weight_dtype)
316
+ with torch.no_grad():
317
+ sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
318
+
319
+ sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
320
+ samples.append(sample)
321
+
322
+ return sample  # note: this returns before the "return samples" below, which is unreachable
323
+
324
+ return samples
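+
+ # Illustrative usage sketch, following how the Gradio demos drive this pipeline
+ # (config and checkpoint paths are the defaults used elsewhere in this repo):
+ #   pipe = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
+ #   pipe.from_pretrained("output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")
+ #   images = pipe(prompt="a cyberpunk cat with a neon sign that says \"Sana\"",
+ #                 height=1024, width=1024, guidance_scale=5.0,
+ #                 pag_guidance_scale=2.0, num_inference_steps=20)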
asset/Sana.jpg ADDED

Git LFS Details

  • SHA256: 1a10d77cfe5a1a703c2cb801d0f3fe9fa32a05c60dfff22b0bc7a479980df61c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
asset/docs/metrics_toolkit.md ADDED
@@ -0,0 +1,118 @@
1
+ # 💻 How to Run Inference & Test Metrics (FID, CLIP Score, GenEval, DPG-Bench, etc.)
2
+
3
+ This toolkit automatically runs inference with your model and logs the metric results to wandb as charts for easier comparison. We currently support:
4
+
5
+ - \[x\] [FID](https://github.com/mseitzer/pytorch-fid) & [CLIP-Score](https://github.com/openai/CLIP)
6
+ - \[x\] [GenEval](https://github.com/djghosh13/geneval)
7
+ - \[x\] [DPG-Bench](https://github.com/TencentQQGYLab/ELLA)
8
+ - \[x\] [ImageReward](https://github.com/THUDM/ImageReward/tree/main)
9
+
10
+ ### 0. Install the corresponding environments for GenEval and DPG-Bench
11
+
12
+ Make sure you can activate the following envs:
13
+
14
+ - `conda activate geneval`([GenEval](https://github.com/djghosh13/geneval))
15
+ - `conda activate dpg`([DPG-Bench](https://github.com/TencentQQGYLab/ELLA))
16
+
17
+ ### 0.1 Prepare data.
18
+
19
+ FID & CLIP-Score are measured on [MJHQ-30K](https://huggingface.co/datasets/playgroundai/MJHQ-30K).
20
+
21
+ ```python
22
+ from huggingface_hub import hf_hub_download
23
+
24
+ hf_hub_download(
25
+ repo_id="playgroundai/MJHQ-30K",
26
+ filename="mjhq30k_imgs.zip",
27
+ local_dir="data/test/PG-eval-data/MJHQ-30K/",
28
+ repo_type="dataset"
29
+ )
30
+ ```
31
+
32
+ Unzip `mjhq30k_imgs.zip` into its per-category folder structure (a minimal Python extraction sketch follows the tree below).
33
+
34
+ ```
35
+ data/test/PG-eval-data/MJHQ-30K/imgs/
36
+ ├── animals
37
+ ├── art
38
+ ├── fashion
39
+ ├── food
40
+ ├── indoor
41
+ ├── landscape
42
+ ├── logo
43
+ ├── people
44
+ ├── plants
45
+ └── vehicles
46
+ ```
47
+
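+ A minimal extraction sketch using Python's standard `zipfile` module (adjust the target directory if the archive already contains a top-level `imgs/` folder):
+
+ ```python
+ import zipfile
+
+ zip_path = "data/test/PG-eval-data/MJHQ-30K/mjhq30k_imgs.zip"
+ with zipfile.ZipFile(zip_path) as zf:
+     zf.extractall("data/test/PG-eval-data/MJHQ-30K/imgs/")
+ ```
+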
48
+ ### 0.2 Prepare checkpoints
49
+
50
+ ```bash
51
+ huggingface-cli download Efficient-Large-Model/Sana_1600M_1024px --repo-type model --local-dir ./output/Sana_1600M_1024px --local-dir-use-symlinks False
52
+ ```
53
+
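+ If you prefer staying in Python, a roughly equivalent sketch with `huggingface_hub.snapshot_download` (same target directory as above):
+
+ ```python
+ from huggingface_hub import snapshot_download
+
+ snapshot_download(
+     repo_id="Efficient-Large-Model/Sana_1600M_1024px",
+     local_dir="./output/Sana_1600M_1024px",
+ )
+ ```
+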
54
+ ### 1. Directly \[Inference and Metric\] a .pth file
55
+
56
+ ```bash
57
+ # We provide four scripts for evaluating metrics:
58
+ fid_clipscore_launch=scripts/bash_run_inference_metric.sh
59
+ geneval_launch=scripts/bash_run_inference_metric_geneval.sh
60
+ dpg_launch=scripts/bash_run_inference_metric_dpg.sh
61
+ image_reward_launch=scripts/bash_run_inference_metric_imagereward.sh
62
+
63
+ # Use the following format to evaluate your models:
64
+ # bash $corresponding_metric_launch $your_config_file_path $your_relative_pth_file_path
65
+
66
+ # example
67
+ bash $geneval_launch \
68
+ configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
69
+ output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
70
+ ```
71
+
72
+ ### 2. \[Inference and Metric\] a list of .pth files using a txt file
73
+
74
+ You can also list all the .pth files of a job in one txt file (one relative path per line), e.g. [model_paths.txt](../model_paths.txt).
75
+
76
+ ```bash
77
+ # Use the following format to evaluate your models, gathered in a txt file:
78
+ # bash $corresponding_metric_launch $your_config_file_path $your_txt_file_path_containing_pth_path
79
+
80
+ # We suggest following the file tree structure in our project for reproducible experiments
81
+ # example
82
+ bash scripts/bash_run_inference_metric.sh \
83
+ configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
84
+ asset/model_paths.txt
85
+ ```
86
+
87
+ ### 3. You will get the following data tree.
88
+
89
+ ```
90
+ output
91
+ ├──your_job_name/ (everything will be saved here)
92
+ │ ├──config.yaml
93
+ │ ├──train_log.log
94
+
95
+ │ ├──checkpoints (all checkpoints)
96
+ │ │ ├──epoch_1_step_6666.pth
97
+ │ │ ├──epoch_1_step_8888.pth
98
+ │ │ ├──......
99
+
100
+ │ ├──vis (all visualization result dirs)
101
+ │ │ ├──visualization_file_name
102
+ │ │ │ ├──xxxxxxx.jpg
103
+ │ │ │ ├──......
104
+ │ │ ├──visualization_file_name2
105
+ │ │ │ ├──xxxxxxx.jpg
106
+ │ │ │ ├──......
107
+ │ ├──......
108
+
109
+ │ ├──metrics (all metrics testing related files)
110
+ │ │ ├──model_paths.txt Optional(👈)(relative path of testing ckpts)
111
+ │ │ │ ├──output/your_job_name/checkpoints/epoch_1_step_6666.pth
112
+ │ │ │ ├──output/your_job_name/checkpoints/epoch_1_step_8888.pth
113
+ │ │ ├──fid_img_paths.txt Optional(👈)(name of testing img_dir in vis)
114
+ │ │ │ ├──visualization_file_name
115
+ │ │ │ ├──visualization_file_name2
116
+ │ │ ├──cached_img_paths.txt Optional(👈)
117
+ │ │ ├──......
118
+ ```
asset/example_data/00000000.txt ADDED
@@ -0,0 +1 @@
1
+ a cyberpunk cat with a neon sign that says "Sana".
asset/examples.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ examples = [
18
+ [
19
+ "A small cactus with a happy face in the Sahara desert.",
20
+ "flow_dpm-solver",
21
+ 20,
22
+ 5.0,
23
+ 2.5,
24
+ ],
25
+ [
26
+ "An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history"
27
+ "of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits "
28
+ "mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret "
29
+ "and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile "
30
+ "as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and "
31
+ "the Parisian streets and city in the background, depth of field, cinematic 35mm film.",
32
+ "flow_dpm-solver",
33
+ 20,
34
+ 5.0,
35
+ 2.5,
36
+ ],
37
+ [
38
+ "An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. "
39
+ "Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. "
40
+ "The quote 'Find the universe within you' is etched in bold letters across the horizon."
41
+ "blue and pink, brilliantly illuminated in the background.",
42
+ "flow_dpm-solver",
43
+ 20,
44
+ 5.0,
45
+ 2.5,
46
+ ],
47
+ [
48
+ "A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.",
49
+ "flow_dpm-solver",
50
+ 20,
51
+ 5.0,
52
+ 2.5,
53
+ ],
54
+ [
55
+ "A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.",
56
+ "flow_dpm-solver",
57
+ 20,
58
+ 5.0,
59
+ 2.5,
60
+ ],
61
+ [
62
+ "a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, "
63
+ "national geographic photo, 8k resolution, crayon art, interactive artwork",
64
+ "flow_dpm-solver",
65
+ 20,
66
+ 5.0,
67
+ 2.5,
68
+ ],
69
+ ]
asset/model-incremental.jpg ADDED
asset/model_paths.txt ADDED
@@ -0,0 +1,2 @@
1
+ output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
2
+ output/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
asset/samples.txt ADDED
@@ -0,0 +1,125 @@
1
+ A small cactus with a happy face in the Sahara desert.
2
+ Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.
3
+ beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background
4
+ stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.
5
+ nature vs human nature, surreal, UHD, 8k, hyper details, rich colors, photograph.
6
+ Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism
7
+ anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur
8
+ The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8
9
+ Bright scene, aerial view, ancient city, fantasy, gorgeous light, mirror reflection, high detail, wide angle lens.
10
+ 8k uhd A man looks up at the starry sky, lonely and ethereal, Minimalism, Chaotic composition Op Art
11
+ A middle-aged woman of Asian descent, her dark hair streaked with silver, appears fractured and splintered, intricately embedded within a sea of broken porcelain. The porcelain glistens with splatter paint patterns in a harmonious blend of glossy and matte blues, greens, oranges, and reds, capturing her dance in a surreal juxtaposition of movement and stillness. Her skin tone, a light hue like the porcelain, adds an almost mystical quality to her form.
12
+ A 4k dslr image of a lemur wearing a red magician hat and a blue coat performing magic tricks with cards in a garden.
13
+ A alpaca made of colorful building blocks, cyberpunk
14
+ A baby painter trying to draw very simple picture, white background
15
+ A boy and a girl fall in love
16
+ A dog that has been meditating all the time
17
+ A man is sitting in a chair with his chin resting on his hand. The chair, along with the man's feet, are submerged in the sea. Strikingly, the man's back is on fire.
18
+ A painter study hard to learn how to draw with many concepts in the air, white background
19
+ A painter with low quality, white background, pixel art
20
+ A person standing on the desert, desert waves, gossip illustration, half red, half blue, abstract image of sand, clear style, trendy illustration, outdoor, top view, clear style, precision art, ultra high definition image
21
+ A silhouette of a grand piano overlooking a dusky cityscape viewed from a top-floor penthouse, rendered in the bold and vivid sytle of a vintage travel poster.
22
+ A sureal parallel world where mankind avoid extinction by preserving nature, epic trees, water streams, various flowers, intricate details, rich colors, rich vegetation, cinematic, symmetrical, beautiful lighting, V-Ray render, sun rays, magical lights, photography
23
+ A woman is shopping for fresh produce at the farmer's market.
24
+ A worker that looks like a mixture of cow and horse is working hard to type code
25
+ A young man dressed in ancient Chinese clothing, Asian people, White robe, Handsome, Hand gestures forming a spell, Martial arts and fairy-like vibe, Carrying a legendary-level giant sword on the back, Game character, Surrounded by runes, Cyberpunk style, neon lights, best quality, masterpiece, cg, hdr, high-definition, extremely detailed, photorealistic, epic, character design, detailed face, superhero, hero, detailed UHD, real-time, vfx, 3D rendering, 8k
26
+ An alien octopus floats through a protal reading a newspaper
27
+ An epressive oil painting of a basketbal player dunking, depicted as an explosion of a nebula
28
+ art collection style and fashion shoot, in the style of made of glass, dark blue and light pink, paul rand, solarpunk, camille vivier, beth didonato hair, barbiecore, hyper-realistic
29
+ artistic
30
+ beautiful secen
31
+ Crocodile in a sweater
32
+ Design a letter A, 3D stereoscopic Ice material Interior light blue Conceptual product design Futuristic Blind box toy Handcrafted Exquisite 3D effect Full body display Ultra-high precision Ultra-detailed Perfect lighting OC Renderer Blender 8k Ultra-sharp Ultra-noise reduction
33
+ Floating,colossal,futuristic statue in the sky, awe-inspiring and serenein the style of Stuart Lippincott:2with detailed composition and subtle geometric elements.This sanctuary-ike atmosphere features crisp clarity and soft amber tones.In contrasttiny human figures surround the statueThe pieceincorporates flowing draperiesreminiscent of Shwedoff and Philip McKay's stylesemphasizing thejuxtaposition between the powerful presence of the statue and thevulnerability of the minuscule human figuresshwedoff
34
+ knolling of a drawing tools for painter
35
+ Leonardo da Vinci's Last Supper content, Van Goph's Starry Night Style
36
+ Luffy from ONEPIECE, handsome face, fantasy
37
+ photography shot through an outdoor window of a coffee shop with neon sign lighting, window glares and reflections, depth of field, {little girl with red hair sitting at a table, portrait, kodak portra 800,105 mm f1.8
38
+ poster of a mechanical cat, techical Schematics viewed from front and side view on light white blueprint paper, illustartion drafting style, illustation, typography, conceptual art, dark fantasy steampunk, cinematic, dark fantasy
39
+ The girl in the car is filled with goldfish and flowers, goldfish can fly, Kawaguchi Renko's art, natural posture, holiday dadcore, youthful energy and pressure, body stretching, goldfish simulation movies in the sky, super details, and dreamy high photography. Colorful. Covered by water and goldfish, indoor scene, close-up shot in XT4 movie
40
+ The image features a woman wearing a red shirt with an icon. She appears to be posing for the camera, and her outfit includes a pair of jeans. The woman seems to be in a good mood, as she is smiling. The background of the image is blurry, focusing more on the woman and her attire.
41
+ The towel was on top of the hard counter.
42
+ A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds.
43
+ I want to supplement vitamin c, please help me paint related food.
44
+ A vibrant yellow banana-shaped couch sits in a cozy living room, its curve cradling a pile of colorful cushions. on the wooden floor, a patterned rug adds a touch of eclectic charm, and a potted plant sits in the corner, reaching towards the sunlight filtering through the window.
45
+ A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.
46
+ A blue jay standing on a large basket of rainbow macarons.
47
+ A bucket bag made of blue suede. The bag is decorated with intricate golden paisley patterns. The handle of the bag is made of rubies and pearls.
48
+ An alien octopus floats through a portal reading a newspaper.
49
+ bird's eye view of a city.
50
+ beautiful scene
51
+ A 2D animation of a folk music band composed of anthropomorphic autumn leaves, each playing traditional bluegrass instruments, amidst a rustic forest setting dappled with the soft light of a harvest moon.
52
+ In front of a deep black backdrop, a figure of middle years, her Tongan skin rich and glowing, is captured mid-twirl, her curly hair flowing like a storm behind her. Her attire resembles a whirlwind of marble and porcelain fragments. Illuminated by the gleam of scattered porcelain shards, creating a dreamlike atmosphere, the dancer manages to appear fragmented, yet maintains a harmonious and fluid form.
53
+ Digital illustration of a beach scene crafted from yarn. The sandy beach is depicted with beige yarn, waves are made of blue and white yarn crashing onto the shore. A yarn sun sets on the horizon, casting a warm glow. Yarn palm trees sway gently, and little yarn seashells dot the shoreline.
54
+ Illustration of a chic chair with a design reminiscent of a pumpkin’s form, with deep orange cushioning, in a stylish loft setting.
55
+ A detailed oil painting of an old sea captain, steering his ship through a storm. Saltwater is splashing against his weathered face, determination in his eyes. Twirling malevolent clouds are seen above and stern waves threaten to submerge the ship while seagulls dive and twirl through the chaotic landscape. Thunder and lights embark in the distance, illuminating the scene with an eerie green glow.
56
+ An illustration of a human heart made of translucent glass, standing on a pedestal amidst a stormy sea. Rays of sunlight pierce the clouds, illuminating the heart, revealing a tiny universe within. The quote 'Find the universe within you' is etched in bold letters across the horizon.
57
+ A modern architectural building with large glass windows, situated on a cliff overlooking a serene ocean at sunset
58
+ photo of an ancient shipwreck nestled on the ocean floor. Marine plants have claimed the wooden structure, and fish swim in and out of its hollow spaces. Sunken treasures and old cannons are scattered around, providing a glimpse into the past
59
+ A 3D render of a coffee mug placed on a window sill during a stormy day. The storm outside the window is reflected in the coffee, with miniature lightning bolts and turbulent waves seen inside the mug. The room is dimly lit, adding to the dramatic atmosphere.A minimap diorama of a cafe adorned with indoor plants. Wooden beams crisscross above, and a cold brew station stands out with tiny bottles and glasses.
60
+ An antique botanical illustration drawn with fine lines and a touch of watercolour whimsy, depicting a strange lily crossed with a Venus flytrap, its petals poised as if ready to snap shut on any unsuspecting insects.An illustration inspired by old-world botanical sketches blends a cactus with lilac blooms into a Möbius strip, using detailed lines and subtle watercolor touches to capture nature's diverse beauty and mathematical intrigue.
61
+ An ink sketch style illustration of a small hedgehog holding a piece of watermelon with its tiny paws, taking little bites with its eyes closed in delight.Photo of a lychee-inspired spherical chair, with a bumpy white exterior and plush interior, set against a tropical wallpaper.
62
+ 3d digital art of an adorable ghost, glowing within, holding a heart shaped pumpkin, Halloween, super cute, spooky haunted house background
63
+ professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.
64
+ an astronaut sitting in a diner, eating fries, cinematic, analog film
65
+ Chinese architecture, ancient style,mountain, bird, lotus, pond, big tree, 4K Unity, octane rendering.
66
+ Ethereal fantasy concept art of thunder god with hammer. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy.
67
+ A Japanese girl walking along a path, surrounding by blooming oriental cherry, pink petal slowly falling down to the ground
68
+ A Ukiyoe style painting, an astronaut riding a unicorn, In the background there is an ancient Japanese architecture
69
+ Steampunk makeup, in the style of vray tracing, colorful impasto, uhd image, indonesian art, fine feather details with bright red and yellow and green and pink and orange colours, intricate patterns and details, dark cyan and amber makeup. Rich colourful plumes. Victorian style.
70
+ A cute teddy bear in front of a plain white wall, warm and brown fur, soft and fluffy
71
+ The beautiful scenery of Seattle, painting by Al Capp.
72
+ Photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang.
73
+ An astronaut riding a horse on the moon, oil painting by Van Gogh.
74
+ A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky
75
+ Realistic oil painting of a stunning model merged in multicolor splash made of finely torn paper, eye contact, walking with class in a street.
76
+ a chinese model is sitting on a train, magazine cover, clothes made of plastic, photorealistic,futuristic style, gray and green light, movie lighting, 32K HD
77
+ a handsome 24 years old boy in the middle with sky color background wearing eye glasses, it's super detailed with anime style, it's a portrait with delicated eyes and nice looking face
78
+ a kayak in the water, in the style of optical color mixing, aerial view, rainbowcore, national geographic photo, 8k resolution, crayon art, interactive artwork
79
+ 3D rendering miniature scene design, Many tall buildings, A winding urban road runs through the middle,a lot of cars on the road, transparent material pipeline transports Materials, ,there are many people around, in thestyle of light orange and yellow, graphic design- inspired illustrations, classic still-life, beeple, josan gon-zalez, manga-influenced, miniature dioramas, in thestyle of playful and whimsical designs, graphic de-sign-inspired illustrations, minimalism, hyperrealismlomo lca, e-commerce C4D style, e-commerce posterUl, UX, octane render, blender
80
+ Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works
81
+ A cute orange kitten sliding down an aqua slide. happy excited. 16mm lens in front. we see his excitement and scared in the eye. vibrant colors. water splashing on the lens
82
+ Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.
83
+ A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures.
84
+ An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film.
85
+ A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.
86
+ A New Zealand female business owner stands and is happy that his business is growing by having good VoIP and broadband supplied by Voyager Internet. This business owner is dressed semi casual and is standing with a funky office space in the background. The image is light and bright and is well lit. This image needs to be shot like a professional photo shoot using a Canon R6 with high quality 25mm lens. This image has a shallow depth of field
87
+ The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8
88
+ Editorial photoshoot of an old woman, high fashion, 2000s fashion
89
+ Mural Painted of Prince in Purple Rain on side of 5 story brick building next to zen garden vacant lot in the urban center district, rgb
90
+ Cozy Scandinavian living room, there is a cat sleeping on the couch, depth of field
91
+ Street style centered straight shot photo shot on Agfa Vista 400, lens 50mm, of two women, skin to skin touch face, emotion, hugging, natural blond hair, natural features, ultra detailed, skin texture, Rembrandt light, soft shadows
92
+ Frog, in forest, colorful, no watermark, no signature, in forest, 8k
93
+ selfie of a woman and her lion cub on the plains
94
+ A fisherman fixing his net sitting on a beautiful tropical beach at sunset with bending palm trees fishing gear and a small boat on shore
95
+ Coast, decorative painting, horizon, modern, fashionable, full of abstract feeling, full of imagination, the picture reveals the sense of time passing, there is a feeling of the end of the world
96
+ A close up of a branch of a tree and a golden bug on the top a leaf, shutterstock contest winner,ecological art, depth of field, shallow depth of field, macro photography
97
+ Outdoor style fashion photo, full – body shot of a man with short brown hair, happy and smiling, he is standing on his hipster bicycle wearing a light blue long sleeved blouse with closed buttons and dark blue jeans trousers, in the background the exterior of an Aldi store, fully lit background, natural afternoon lighting
98
+ beautiful woman sniper, wearing soviet army uniform, one eye on sniper lens, in snow ground
99
+ A very attractive and natural woman, sitting on a yoga mat, breathing, eyes closed, no makeup, intense satisfaction, she looks like she is intensely relaxed, yoga class, sunrise, 35mm
100
+ a close up of a helmet on a person, digital art, inspired by Han Gan, cloisonnism, female, victorian armor, ultramarine, best of behance, anton fadeev 8 k, fined detail, sci-fi character, elegant armor, fantasy art behance
101
+ a melting apple
102
+ yellow FIAT 500 Cinquecento 1957 driving through liechtenstein castle with a lot of banknotes scattered behind ,filled with wads of cash , car color yellow, license plate R-33
103
+ tented resort in the desert, rocky and sandy terrain, 5 star hotel, beautiful landscape, landscape photography, depth of view, Fujifilm GFX 100 –uplight
104
+ Full body shot, a French woman, Photography, French Streets background, backlighting, rim light, Fujifilm.
105
+ Modern luxury contemporary luxury home interiors house, in the style of mimicking ruined materials, ray tracing, haunting houses, and stone, capture the essence of nature, gray and bronze, dynamic outdoor shots.
106
+ Over the shoulder game perspective, game screen of Diablo 4, Inside the gorgeous palace is the wet ground, The necromancer knelt before the king, and a horde of skeletons he summoned stood at his side, cinematic light.
107
+ Color photo of a corgi made of transparent glass, standing on the riverside in Yosemite National Park.
108
+ Happy dreamy owl monster sitting on a tree branch, colorful glittering particles, forest background, detailed feathers.
109
+ Game-Art - An island with different geographical properties and multiple small cities floating in space
110
+ Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee.
111
+ A car made out of vegetables.
112
+ A serene lakeside during autumn with trees displaying a palette of fiery colors.
113
+ A realistic landscape shot of the Northern Lights dancing over a snowy mountain range in Iceland.
114
+ A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky.
115
+ Drone view of waves crashing against the rugged cliffs along Big Sur’s Garay Point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore.
116
+ A curvy timber house near a sea, designed by Zaha Hadid, represent the image of a cold, modern architecture, at night, white lighting, highly detailed.
117
+ Eiffel Tower was Made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building, Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds in the building's classic architecture.
118
+ Close-up photos of models, hazy light and shadow, laser metal hair accessories, soft and beautiful, light gold pupils, white eyelashes, low saturation, real skin details, clear pores and fine lines, light reflection and refraction, ultra-clear, cinematography, award-winning works.
119
+ smiling cartoon dog sits at a table, coffee mug on hand, as a room goes up in flames. "Help" the dog is yelling.
120
+ A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.
121
+ A close-up photo of a person. The subject is a woman. She wore a blue coat with a gray dress underneath. She has blue eyes and blond hair and wears a pair of earrings. Behind are blurred city buildings and streets.
122
+ 👧 with 🌹 in the ❄️
123
+ 🐶 Wearing 🕶 flying on the 🌈
124
+ a cyberpunk cat with a neon sign that says "MIT"
125
+ a black and white picture of a woman looking through the window, in the style of Duffy Sheridan, Anna Razumovskaya, smooth and shiny, wavy, Patrick Demarchelier, album covers, lush and detailed.
asset/samples_mini.txt ADDED
@@ -0,0 +1,10 @@
1
+ A cyberpunk cat with a neon sign that says 'Sana'.
2
+ A small cactus with a happy face in the Sahara desert.
3
+ The towel was on top of the hard counter.
4
+ A vast landscape made entirely of various meats spreads out before the viewer. tender, succulent hills of roast beef, chicken drumstick trees, bacon rivers, and ham boulders create a surreal, yet appetizing scene. the sky is adorned with pepperoni sun and salami clouds.
5
+ I want to supplement vitamin c, please help me paint related food.
6
+ A transparent sculpture of a duck made out of glass. The sculpture is in front of a painting of a landscape.
7
+ an old rusted robot wearing pants and a jacket riding skis in a supermarket.
8
+ professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.
9
+ Astronaut in a jungle, cold color palette, muted colors, detailed
10
+ a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests.
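The prompt lists above (asset/samples.txt and asset/samples_mini.txt) are plain newline-separated text files, one prompt per line. As a minimal, hypothetical sketch (the file path and helper name below are illustrative assumptions, not repository code), they can be consumed like this:

```python
# Minimal sketch: read a newline-separated prompt file such as asset/samples_mini.txt.
from pathlib import Path


def load_prompts(path):
    """Return the non-empty, stripped lines of a prompt list file."""
    text = Path(path).read_text(encoding="utf-8")
    return [line.strip() for line in text.splitlines() if line.strip()]


if __name__ == "__main__":
    prompts = load_prompts("asset/samples_mini.txt")
    for i, prompt in enumerate(prompts, start=1):
        print(f"{i:02d}: {prompt[:60]}")
```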
configs/sana_app_config/Sana_1600M_app.yaml ADDED
@@ -0,0 +1,107 @@
1
+ data:
2
+ data_dir: []
3
+ image_size: 1024
4
+ caption_proportion:
5
+ prompt: 1
6
+ external_caption_suffixes: []
7
+ external_clipscore_suffixes: []
8
+ clip_thr_temperature: 0.1
9
+ clip_thr: 25.0
10
+ load_text_feat: false
11
+ load_vae_feat: false
12
+ transform: default_train
13
+ type: SanaWebDatasetMS
14
+ data:
15
+ sort_dataset: false
16
+ # model config
17
+ model:
18
+ model: SanaMS_1600M_P1_D20
19
+ image_size: 1024
20
+ mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
21
+ fp32_attention: true
22
+ load_from:
23
+ resume_from:
24
+ aspect_ratio_type: ASPECT_RATIO_1024
25
+ multi_scale: true
26
+ #pe_interpolation: 1.
27
+ attn_type: linear
28
+ ffn_type: glumbconv
29
+ mlp_acts:
30
+ - silu
31
+ - silu
32
+ -
33
+ mlp_ratio: 2.5
34
+ use_pe: false
35
+ qk_norm: false
36
+ class_dropout_prob: 0.1
37
+ # CFG & PAG settings
38
+ pag_applied_layers:
39
+ - 8
40
+ # VAE setting
41
+ vae:
42
+ vae_type: dc-ae
43
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
44
+ scale_factor: 0.41407
45
+ vae_latent_dim: 32
46
+ vae_downsample_rate: 32
47
+ sample_posterior: true
48
+ # text encoder
49
+ text_encoder:
50
+ text_encoder_name: gemma-2-2b-it
51
+ y_norm: true
52
+ y_norm_scale_factor: 0.01
53
+ model_max_length: 300
54
+ # CHI
55
+ chi_prompt:
56
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
57
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
58
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
59
+ - 'Here are examples of how to transform or refine prompts:'
60
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
61
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
62
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
63
+ - 'User Prompt: '
64
+ # Sana schedule Flow
65
+ scheduler:
66
+ predict_v: true
67
+ noise_schedule: linear_flow
68
+ pred_sigma: false
69
+ flow_shift: 3.0
70
+ # logit-normal timestep
71
+ weighting_scheme: logit_normal
72
+ logit_mean: 0.0
73
+ logit_std: 1.0
74
+ vis_sampler: flow_dpm-solver
75
+ # training setting
76
+ train:
77
+ num_workers: 10
78
+ seed: 1
79
+ train_batch_size: 64
80
+ num_epochs: 100
81
+ gradient_accumulation_steps: 1
82
+ grad_checkpointing: true
83
+ gradient_clip: 0.1
84
+ optimizer:
85
+ betas:
86
+ - 0.9
87
+ - 0.999
88
+ - 0.9999
89
+ eps:
90
+ - 1.0e-30
91
+ - 1.0e-16
92
+ lr: 0.0001
93
+ type: CAMEWrapper
94
+ weight_decay: 0.0
95
+ lr_schedule: constant
96
+ lr_schedule_args:
97
+ num_warmup_steps: 2000
98
+ local_save_vis: true # if save log image locally
99
+ visualize: true
100
+ eval_sampling_steps: 500
101
+ log_interval: 20
102
+ save_model_epochs: 5
103
+ save_model_steps: 500
104
+ work_dir: output/debug
105
+ online_metric: false
106
+ eval_metric_step: 2000
107
+ online_metric_dir: metric_helper
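For orientation, a config such as configs/sana_app_config/Sana_1600M_app.yaml above can be inspected with PyYAML. This is only an illustrative sketch that assumes the section layout suggested by the comments (data, model, vae, text_encoder, scheduler, train as top-level keys; the web diff view drops indentation), not the repository's own config loader:

```python
# Illustrative sketch: inspect a Sana YAML config with PyYAML.
# The top-level section names are assumed from the comments in the file above.
import yaml

with open("configs/sana_app_config/Sana_1600M_app.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["model"])             # SanaMS_1600M_P1_D20
print(cfg["model"]["mixed_precision"])   # fp16
print(cfg["vae"]["vae_pretrained"])      # mit-han-lab/dc-ae-f32c32-sana-1.0
print(cfg["scheduler"]["flow_shift"])    # 3.0
print(cfg["train"]["train_batch_size"])  # 64
```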
configs/sana_app_config/Sana_600M_app.yaml ADDED
@@ -0,0 +1,105 @@
1
+ data:
2
+ data_dir: []
3
+ image_size: 1024
4
+ caption_proportion:
5
+ prompt: 1
6
+ external_caption_suffixes: []
7
+ external_clipscore_suffixes: []
8
+ clip_thr_temperature: 0.1
9
+ clip_thr: 25.0
10
+ load_text_feat: false
11
+ load_vae_feat: true
12
+ transform: default_train
13
+ type: SanaWebDatasetMS
14
+ sort_dataset: false
15
+ # model config
16
+ model:
17
+ model: SanaMS_600M_P1_D28
18
+ image_size: 1024
19
+ mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
20
+ fp32_attention: true
21
+ load_from:
22
+ resume_from:
23
+ aspect_ratio_type: ASPECT_RATIO_1024
24
+ multi_scale: true
25
+ attn_type: linear
26
+ ffn_type: glumbconv
27
+ mlp_acts:
28
+ - silu
29
+ - silu
30
+ -
31
+ mlp_ratio: 2.5
32
+ use_pe: false
33
+ qk_norm: false
34
+ class_dropout_prob: 0.1
35
+ # CFG & PAG settings
36
+ pag_applied_layers:
37
+ - 14
38
+ # VAE setting
39
+ vae:
40
+ vae_type: dc-ae
41
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
42
+ scale_factor: 0.41407
43
+ vae_latent_dim: 32
44
+ vae_downsample_rate: 32
45
+ sample_posterior: true
46
+ # text encoder
47
+ text_encoder:
48
+ text_encoder_name: gemma-2-2b-it
49
+ y_norm: true
50
+ y_norm_scale_factor: 0.01
51
+ model_max_length: 300
52
+ # CHI
53
+ chi_prompt:
54
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
55
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
56
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
57
+ - 'Here are examples of how to transform or refine prompts:'
58
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
59
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
60
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
61
+ - 'User Prompt: '
62
+ # Sana schedule Flow
63
+ scheduler:
64
+ predict_v: true
65
+ noise_schedule: linear_flow
66
+ pred_sigma: false
67
+ flow_shift: 4.0
68
+ # logit-normal timestep
69
+ weighting_scheme: logit_normal
70
+ logit_mean: 0.0
71
+ logit_std: 1.0
72
+ vis_sampler: flow_dpm-solver
73
+ # training setting
74
+ train:
75
+ num_workers: 10
76
+ seed: 1
77
+ train_batch_size: 64
78
+ num_epochs: 100
79
+ gradient_accumulation_steps: 1
80
+ grad_checkpointing: true
81
+ gradient_clip: 0.1
82
+ optimizer:
83
+ betas:
84
+ - 0.9
85
+ - 0.999
86
+ - 0.9999
87
+ eps:
88
+ - 1.0e-30
89
+ - 1.0e-16
90
+ lr: 0.0001
91
+ type: CAMEWrapper
92
+ weight_decay: 0.0
93
+ lr_schedule: constant
94
+ lr_schedule_args:
95
+ num_warmup_steps: 2000
96
+ local_save_vis: true # if save log image locally
97
+ visualize: true
98
+ eval_sampling_steps: 500
99
+ log_interval: 20
100
+ save_model_epochs: 5
101
+ save_model_steps: 500
102
+ work_dir: output/debug
103
+ online_metric: false
104
+ eval_metric_step: 2000
105
+ online_metric_dir: metric_helper
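Both app configs carry the same chi_prompt block: a list of instruction lines that ends with 'User Prompt: '. A plausible way to combine it with a user prompt is sketched below; the join convention and function name are assumptions for illustration, and the authoritative logic lives in the pipeline code shipped in this commit:

```python
# Hedged sketch: prepend the chi_prompt instruction lines to a user prompt
# before text encoding. The newline join here is an assumption for illustration.
chi_prompt = [
    'Given a user prompt, generate an "Enhanced prompt" that provides detailed '
    "visual descriptions suitable for image generation...",
    "Please generate only the enhanced description for the prompt below and avoid "
    "including any additional commentary or evaluations:",
    "User Prompt: ",
]


def apply_chi_prompt(user_prompt):
    # The final list entry already ends with "User Prompt: ", so the raw prompt
    # is appended directly after it.
    return "\n".join(chi_prompt) + user_prompt


print(apply_chi_prompt("A cyberpunk cat with a neon sign that says 'Sana'."))
```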
configs/sana_base.yaml ADDED
@@ -0,0 +1,140 @@
1
+ # data settings
2
+ data:
3
+ data_dir: []
4
+ caption_proportion:
5
+ prompt: 1
6
+ external_caption_suffixes: []
7
+ external_clipscore_suffixes: []
8
+ clip_thr_temperature: 1.0
9
+ clip_thr: 0.0
10
+ sort_dataset: false
11
+ load_text_feat: false
12
+ load_vae_feat: false
13
+ transform: default_train
14
+ type: SanaWebDatasetMS
15
+ image_size: 512
16
+ hq_only: false
17
+ valid_num: 0
18
+ # model settings
19
+ model:
20
+ model: SanaMS_600M_P1_D28
21
+ image_size: 512
22
+ mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
23
+ fp32_attention: true
24
+ load_from:
25
+ resume_from:
26
+ checkpoint:
27
+ load_ema: false
28
+ resume_lr_scheduler: true
29
+ resume_optimizer: true
30
+ aspect_ratio_type: ASPECT_RATIO_1024
31
+ multi_scale: true
32
+ pe_interpolation: 1.0
33
+ micro_condition: false
34
+ attn_type: linear # 'flash', 'linear', 'vanilla', 'triton_linear'
35
+ cross_norm: false
36
+ autocast_linear_attn: false
37
+ ffn_type: glumbconv
38
+ mlp_acts:
39
+ - silu
40
+ - silu
41
+ -
42
+ mlp_ratio: 2.5
43
+ use_pe: false
44
+ qk_norm: false
45
+ class_dropout_prob: 0.0
46
+ linear_head_dim: 32
47
+ # CFG & PAG settings
48
+ cfg_scale: 4
49
+ guidance_type: classifier-free
50
+ pag_applied_layers: [14]
51
+ # text encoder settings
52
+ text_encoder:
53
+ text_encoder_name: gemma-2-2b-it
54
+ caption_channels: 2304
55
+ y_norm: false
56
+ y_norm_scale_factor: 1.0
57
+ model_max_length: 300
58
+ chi_prompt: []
59
+ # VAE settings
60
+ vae:
61
+ vae_type: dc-ae
62
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
63
+ scale_factor: 0.41407
64
+ vae_latent_dim: 32
65
+ vae_downsample_rate: 32
66
+ sample_posterior: true
67
+ # Scheduler settings
68
+ scheduler:
69
+ train_sampling_steps: 1000
70
+ predict_v: True
71
+ noise_schedule: linear_flow
72
+ pred_sigma: false
73
+ flow_shift: 1.0
74
+ weighting_scheme: logit_normal
75
+ logit_mean: 0.0
76
+ logit_std: 1.0
77
+ vis_sampler: flow_dpm-solver
78
+ # training settings
79
+ train:
80
+ num_workers: 4
81
+ seed: 43
82
+ train_batch_size: 32
83
+ num_epochs: 100
84
+ gradient_accumulation_steps: 1
85
+ grad_checkpointing: false
86
+ gradient_clip: 1.0
87
+ gc_step: 1
88
+ # optimizer settings
89
+ optimizer:
90
+ eps: 1.0e-10
91
+ lr: 0.0001
92
+ type: AdamW
93
+ weight_decay: 0.03
94
+ lr_schedule: constant
95
+ lr_schedule_args:
96
+ num_warmup_steps: 500
97
+ auto_lr:
98
+ rule: sqrt
99
+ ema_rate: 0.9999
100
+ eval_batch_size: 16
101
+ use_fsdp: false
102
+ use_flash_attn: false
103
+ eval_sampling_steps: 250
104
+ lora_rank: 4
105
+ log_interval: 50
106
+ mask_type: 'null'
107
+ mask_loss_coef: 0.0
108
+ load_mask_index: false
109
+ snr_loss: false
110
+ real_prompt_ratio: 1.0
111
+ debug_nan: false
112
+ # checkpoint settings
113
+ save_image_epochs: 1
114
+ save_model_epochs: 1
115
+ save_model_steps: 1000000
116
+ # visualization settings
117
+ visualize: false
118
+ null_embed_root: output/pretrained_models/
119
+ valid_prompt_embed_root: output/tmp_embed/
120
+ validation_prompts:
121
+ - dog
122
+ - portrait photo of a girl, photograph, highly detailed face, depth of field
123
+ - Self-portrait oil painting, a beautiful cyborg with golden hair, 8k
124
+ - Astronaut in a jungle, cold color palette, muted colors, detailed, 8k
125
+ - A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece
126
+ local_save_vis: false
127
+ deterministic_validation: true
128
+ online_metric: false
129
+ eval_metric_step: 5000
130
+ online_metric_dir: metric_helper
131
+ # work dir settings
132
+ work_dir: /cache/exps/
133
+ skip_step: 0
134
+ # LCM settings
135
+ loss_type: huber
136
+ huber_c: 0.001
137
+ num_ddim_timesteps: 50
138
+ w_max: 15.0
139
+ w_min: 3.0
140
+ ema_decay: 0.95
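configs/sana_base.yaml sets guidance_type: classifier-free with cfg_scale: 4. For reference, the standard classifier-free guidance combination is sketched below; this is a generic formulation for illustration, not the repository's sampler code:

```python
# Generic classifier-free guidance sketch (illustration only).
import torch


def apply_cfg(pred_cond, pred_uncond, cfg_scale=4.0):
    """Push the conditional prediction away from the unconditional one by cfg_scale."""
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)


# Toy usage with random tensors standing in for the two model outputs.
cond = torch.randn(1, 32, 16, 16)
uncond = torch.randn(1, 32, 16, 16)
print(apply_cfg(cond, uncond, cfg_scale=4.0).shape)
```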
configs/sana_config/1024ms/Sana_1600M_img1024.yaml ADDED
@@ -0,0 +1,109 @@
1
+ data:
2
+ data_dir: [data/data_public/dir1]
3
+ image_size: 1024
4
+ caption_proportion:
5
+ prompt: 1
6
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
7
+ external_clipscore_suffixes:
8
+ - _InternVL2-26B_clip_score
9
+ - _VILA1-5-13B_clip_score
10
+ - _prompt_clip_score
11
+ clip_thr_temperature: 0.1
12
+ clip_thr: 25.0
13
+ load_text_feat: false
14
+ load_vae_feat: false
15
+ transform: default_train
16
+ type: SanaWebDatasetMS
17
+ sort_dataset: false
18
+ # model config
19
+ model:
20
+ model: SanaMS_1600M_P1_D20
21
+ image_size: 1024
22
+ mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
23
+ fp32_attention: true
24
+ load_from:
25
+ resume_from:
26
+ aspect_ratio_type: ASPECT_RATIO_1024
27
+ multi_scale: true
28
+ #pe_interpolation: 1.
29
+ attn_type: linear
30
+ ffn_type: glumbconv
31
+ mlp_acts:
32
+ - silu
33
+ - silu
34
+ -
35
+ mlp_ratio: 2.5
36
+ use_pe: false
37
+ qk_norm: false
38
+ class_dropout_prob: 0.1
39
+ # PAG
40
+ pag_applied_layers:
41
+ - 8
42
+ # VAE setting
43
+ vae:
44
+ vae_type: dc-ae
45
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
46
+ scale_factor: 0.41407
47
+ vae_latent_dim: 32
48
+ vae_downsample_rate: 32
49
+ sample_posterior: true
50
+ # text encoder
51
+ text_encoder:
52
+ text_encoder_name: gemma-2-2b-it
53
+ y_norm: true
54
+ y_norm_scale_factor: 0.01
55
+ model_max_length: 300
56
+ # CHI
57
+ chi_prompt:
58
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
59
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
60
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
61
+ - 'Here are examples of how to transform or refine prompts:'
62
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
63
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
64
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
65
+ - 'User Prompt: '
66
+ # Sana schedule Flow
67
+ scheduler:
68
+ predict_v: true
69
+ noise_schedule: linear_flow
70
+ pred_sigma: false
71
+ flow_shift: 3.0
72
+ # logit-normal timestep
73
+ weighting_scheme: logit_normal
74
+ logit_mean: 0.0
75
+ logit_std: 1.0
76
+ vis_sampler: flow_dpm-solver
77
+ # training setting
78
+ train:
79
+ num_workers: 10
80
+ seed: 1
81
+ train_batch_size: 64
82
+ num_epochs: 100
83
+ gradient_accumulation_steps: 1
84
+ grad_checkpointing: true
85
+ gradient_clip: 0.1
86
+ optimizer:
87
+ betas:
88
+ - 0.9
89
+ - 0.999
90
+ - 0.9999
91
+ eps:
92
+ - 1.0e-30
93
+ - 1.0e-16
94
+ lr: 0.0001
95
+ type: CAMEWrapper
96
+ weight_decay: 0.0
97
+ lr_schedule: constant
98
+ lr_schedule_args:
99
+ num_warmup_steps: 2000
100
+ local_save_vis: true # if save log image locally
101
+ visualize: true
102
+ eval_sampling_steps: 500
103
+ log_interval: 20
104
+ save_model_epochs: 5
105
+ save_model_steps: 500
106
+ work_dir: output/debug
107
+ online_metric: false
108
+ eval_metric_step: 2000
109
+ online_metric_dir: metric_helper
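The scheduler block above selects weighting_scheme: logit_normal with logit_mean 0.0 and logit_std 1.0. One common formulation of logit-normal timestep sampling draws a normal variate and squashes it through a sigmoid so that t lies in (0, 1) and concentrates around 0.5; the sketch below illustrates that idea and is not the training code itself:

```python
# Hedged sketch of logit-normal timestep sampling (illustration only).
import torch


def sample_logit_normal_t(batch_size, logit_mean=0.0, logit_std=1.0):
    u = torch.randn(batch_size) * logit_std + logit_mean
    return torch.sigmoid(u)  # timesteps in (0, 1), concentrated around 0.5


print(sample_logit_normal_t(4))
```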
configs/sana_config/1024ms/Sana_600M_img1024.yaml ADDED
@@ -0,0 +1,105 @@
1
+ data:
2
+ data_dir: [data/data_public/dir1]
3
+ image_size: 1024
4
+ caption_proportion:
5
+ prompt: 1
6
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
7
+ external_clipscore_suffixes:
8
+ - _InternVL2-26B_clip_score
9
+ - _VILA1-5-13B_clip_score
10
+ - _prompt_clip_score
11
+ clip_thr_temperature: 0.1
12
+ clip_thr: 25.0
13
+ load_text_feat: false
14
+ load_vae_feat: false
15
+ transform: default_train
16
+ type: SanaWebDatasetMS
17
+ sort_dataset: false
18
+ # model config
19
+ model:
20
+ model: SanaMS_600M_P1_D28
21
+ image_size: 1024
22
+ mixed_precision: fp16
23
+ fp32_attention: true
24
+ load_from:
25
+ resume_from:
26
+ aspect_ratio_type: ASPECT_RATIO_1024
27
+ multi_scale: true
28
+ attn_type: linear
29
+ ffn_type: glumbconv
30
+ mlp_acts:
31
+ - silu
32
+ - silu
33
+ -
34
+ mlp_ratio: 2.5
35
+ use_pe: false
36
+ qk_norm: false
37
+ class_dropout_prob: 0.1
38
+ # VAE setting
39
+ vae:
40
+ vae_type: dc-ae
41
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
42
+ scale_factor: 0.41407
43
+ vae_latent_dim: 32
44
+ vae_downsample_rate: 32
45
+ sample_posterior: true
46
+ # text encoder
47
+ text_encoder:
48
+ text_encoder_name: gemma-2-2b-it
49
+ y_norm: true
50
+ y_norm_scale_factor: 0.01
51
+ model_max_length: 300
52
+ # CHI
53
+ chi_prompt:
54
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
55
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
56
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
57
+ - 'Here are examples of how to transform or refine prompts:'
58
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
59
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
60
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
61
+ - 'User Prompt: '
62
+ # Sana schedule Flow
63
+ scheduler:
64
+ predict_v: true
65
+ noise_schedule: linear_flow
66
+ pred_sigma: false
67
+ flow_shift: 4.0
68
+ # logit-normal timestep
69
+ weighting_scheme: logit_normal
70
+ logit_mean: 0.0
71
+ logit_std: 1.0
72
+ vis_sampler: flow_dpm-solver
73
+ # training setting
74
+ train:
75
+ num_workers: 10
76
+ seed: 1
77
+ train_batch_size: 64
78
+ num_epochs: 100
79
+ gradient_accumulation_steps: 1
80
+ grad_checkpointing: true
81
+ gradient_clip: 0.1
82
+ optimizer:
83
+ betas:
84
+ - 0.9
85
+ - 0.999
86
+ - 0.9999
87
+ eps:
88
+ - 1.0e-30
89
+ - 1.0e-16
90
+ lr: 0.0001
91
+ type: CAMEWrapper
92
+ weight_decay: 0.0
93
+ lr_schedule: constant
94
+ lr_schedule_args:
95
+ num_warmup_steps: 2000
96
+ local_save_vis: true # if save log image locally
97
+ visualize: true
98
+ eval_sampling_steps: 500
99
+ log_interval: 20
100
+ save_model_epochs: 5
101
+ save_model_steps: 500
102
+ work_dir: output/debug
103
+ online_metric: false
104
+ eval_metric_step: 2000
105
+ online_metric_dir: metric_helper
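Note that the 1024px configs use flow_shift values of 3.0 and 4.0, while the 512px configs below use 1.0 (no shift). One common formulation of such a timestep shift, shown here purely for illustration and not necessarily the exact code path used by this repository, remaps a noise level toward higher noise as the shift grows:

```python
# Hedged sketch of a flow-matching timestep shift (illustration only).
def shift_sigma(sigma, shift):
    """Remap a noise level sigma in [0, 1]; shift == 1.0 leaves it unchanged."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


for s in (0.25, 0.5, 0.75):
    print(s, "->", round(shift_sigma(s, shift=3.0), 3))
```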
configs/sana_config/512ms/Sana_1600M_img512.yaml ADDED
@@ -0,0 +1,108 @@
1
+ data:
2
+ data_dir: [data/data_public/dir1]
3
+ image_size: 512
4
+ caption_proportion:
5
+ prompt: 1
6
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
7
+ external_clipscore_suffixes:
8
+ - _InternVL2-26B_clip_score
9
+ - _VILA1-5-13B_clip_score
10
+ - _prompt_clip_score
11
+ clip_thr_temperature: 0.1
12
+ clip_thr: 25.0
13
+ load_text_feat: false
14
+ load_vae_feat: false
15
+ transform: default_train
16
+ type: SanaWebDatasetMS
17
+ sort_dataset: false
18
+ # model config
19
+ model:
20
+ model: SanaMS_1600M_P1_D20
21
+ image_size: 512
22
+ mixed_precision: fp16 # ['fp16', 'fp32', 'bf16']
23
+ fp32_attention: true
24
+ load_from:
25
+ resume_from:
26
+ aspect_ratio_type: ASPECT_RATIO_512
27
+ multi_scale: true
28
+ attn_type: linear
29
+ ffn_type: glumbconv
30
+ mlp_acts:
31
+ - silu
32
+ - silu
33
+ -
34
+ mlp_ratio: 2.5
35
+ use_pe: false
36
+ qk_norm: false
37
+ class_dropout_prob: 0.1
38
+ # PAG
39
+ pag_applied_layers:
40
+ - 8
41
+ # VAE setting
42
+ vae:
43
+ vae_type: dc-ae
44
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
45
+ scale_factor: 0.41407
46
+ vae_latent_dim: 32
47
+ vae_downsample_rate: 32
48
+ sample_posterior: true
49
+ # text encoder
50
+ text_encoder:
51
+ text_encoder_name: gemma-2-2b-it
52
+ y_norm: true
53
+ y_norm_scale_factor: 0.01
54
+ model_max_length: 300
55
+ # CHI
56
+ chi_prompt:
57
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
58
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
59
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
60
+ - 'Here are examples of how to transform or refine prompts:'
61
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
62
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
63
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
64
+ - 'User Prompt: '
65
+ # Sana schedule Flow
66
+ scheduler:
67
+ predict_v: true
68
+ noise_schedule: linear_flow
69
+ pred_sigma: false
70
+ flow_shift: 1.0
71
+ # logit-normal timestep
72
+ weighting_scheme: logit_normal
73
+ logit_mean: 0.0
74
+ logit_std: 1.0
75
+ vis_sampler: flow_dpm-solver
76
+ # training setting
77
+ train:
78
+ num_workers: 10
79
+ seed: 1
80
+ train_batch_size: 64
81
+ num_epochs: 100
82
+ gradient_accumulation_steps: 1
83
+ grad_checkpointing: true
84
+ gradient_clip: 0.1
85
+ optimizer:
86
+ betas:
87
+ - 0.9
88
+ - 0.999
89
+ - 0.9999
90
+ eps:
91
+ - 1.0e-30
92
+ - 1.0e-16
93
+ lr: 0.0001
94
+ type: CAMEWrapper
95
+ weight_decay: 0.0
96
+ lr_schedule: constant
97
+ lr_schedule_args:
98
+ num_warmup_steps: 2000
99
+ local_save_vis: true # if save log image locally
100
+ visualize: true
101
+ eval_sampling_steps: 500
102
+ log_interval: 20
103
+ save_model_epochs: 5
104
+ save_model_steps: 500
105
+ work_dir: output/debug
106
+ online_metric: false
107
+ eval_metric_step: 2000
108
+ online_metric_dir: metric_helper
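The VAE block is identical across these configs: a DC-AE with a 32x spatial downsample, 32 latent channels, and scale_factor 0.41407. The arithmetic below illustrates the resulting latent geometry; the convention of multiplying encoded latents by scale_factor is an assumption here, as is common for latent diffusion models:

```python
# Latent-geometry arithmetic for the DC-AE settings above (illustration only).
image_size = 512
vae_downsample_rate = 32
vae_latent_dim = 32
scale_factor = 0.41407

latent_hw = image_size // vae_downsample_rate
print(f"latent tensor: {vae_latent_dim} x {latent_hw} x {latent_hw}")  # 32 x 16 x 16
# Assumed convention: z = vae.encode(img) * scale_factor before the diffusion model.
```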
configs/sana_config/512ms/Sana_600M_img512.yaml ADDED
@@ -0,0 +1,107 @@
1
+ data:
2
+ data_dir: [data/data_public/dir1]
3
+ image_size: 512
4
+ caption_proportion:
5
+ prompt: 1
6
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
7
+ external_clipscore_suffixes:
8
+ - _InternVL2-26B_clip_score
9
+ - _VILA1-5-13B_clip_score
10
+ - _prompt_clip_score
11
+ clip_thr_temperature: 0.1
12
+ clip_thr: 25.0
13
+ load_text_feat: false
14
+ load_vae_feat: false
15
+ transform: default_train
16
+ type: SanaWebDatasetMS
17
+ sort_dataset: false
18
+ # model config
19
+ model:
20
+ model: SanaMS_600M_P1_D28
21
+ image_size: 512
22
+ mixed_precision: fp16
23
+ fp32_attention: true
24
+ load_from:
25
+ resume_from:
26
+ aspect_ratio_type: ASPECT_RATIO_512
27
+ multi_scale: true
28
+ #pe_interpolation: 1.
29
+ attn_type: linear
30
+ linear_head_dim: 32
31
+ ffn_type: glumbconv
32
+ mlp_acts:
33
+ - silu
34
+ - silu
35
+ - null
36
+ mlp_ratio: 2.5
37
+ use_pe: false
38
+ qk_norm: false
39
+ class_dropout_prob: 0.1
40
+ # VAE setting
41
+ vae:
42
+ vae_type: dc-ae
43
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
44
+ scale_factor: 0.41407
45
+ vae_latent_dim: 32
46
+ vae_downsample_rate: 32
47
+ sample_posterior: true
48
+ # text encoder
49
+ text_encoder:
50
+ text_encoder_name: gemma-2-2b-it
51
+ y_norm: true
52
+ y_norm_scale_factor: 0.01
53
+ model_max_length: 300
54
+ # CHI
55
+ chi_prompt:
56
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
57
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
58
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
59
+ - 'Here are examples of how to transform or refine prompts:'
60
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
61
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
62
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
63
+ - 'User Prompt: '
64
+ # Sana schedule Flow
65
+ scheduler:
66
+ predict_v: true
67
+ noise_schedule: linear_flow
68
+ pred_sigma: false
69
+ flow_shift: 1.0
70
+ # logit-normal timestep
71
+ weighting_scheme: logit_normal
72
+ logit_mean: 0.0
73
+ logit_std: 1.0
74
+ vis_sampler: flow_dpm-solver
75
+ # training setting
76
+ train:
77
+ num_workers: 10
78
+ seed: 1
79
+ train_batch_size: 128
80
+ num_epochs: 100
81
+ gradient_accumulation_steps: 1
82
+ grad_checkpointing: true
83
+ gradient_clip: 0.1
84
+ optimizer:
85
+ betas:
86
+ - 0.9
87
+ - 0.999
88
+ - 0.9999
89
+ eps:
90
+ - 1.0e-30
91
+ - 1.0e-16
92
+ lr: 0.0001
93
+ type: CAMEWrapper
94
+ weight_decay: 0.0
95
+ lr_schedule: constant
96
+ lr_schedule_args:
97
+ num_warmup_steps: 2000
98
+ local_save_vis: true # if save log image locally
99
+ visualize: true
100
+ eval_sampling_steps: 500
101
+ log_interval: 20
102
+ save_model_epochs: 5
103
+ save_model_steps: 500
104
+ work_dir: output/debug
105
+ online_metric: false
106
+ eval_metric_step: 2000
107
+ online_metric_dir: metric_helper
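The train blocks use a per-GPU train_batch_size (64 or 128 here) with gradient_accumulation_steps: 1, so the effective global batch size depends on the number of GPUs. A quick arithmetic sketch, with the GPU count as an illustrative assumption rather than a value from the configs:

```python
# Effective global batch size under assumed hardware (illustration only).
train_batch_size = 128           # per GPU, from the 600M 512px config above
gradient_accumulation_steps = 1  # from the config
num_gpus = 8                     # illustrative assumption, not from the config

print(train_batch_size * gradient_accumulation_steps * num_gpus)  # 1024 here
```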
configs/sana_config/512ms/ci_Sana_600M_img512.yaml ADDED
@@ -0,0 +1,107 @@
1
+ data:
2
+ data_dir: [data/data_public/vaef32c32_v2_512/dir1]
3
+ image_size: 512
4
+ caption_proportion:
5
+ prompt: 1
6
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B]
7
+ external_clipscore_suffixes:
8
+ - _InternVL2-26B_clip_score
9
+ - _VILA1-5-13B_clip_score
10
+ - _prompt_clip_score
11
+ clip_thr_temperature: 0.1
12
+ clip_thr: 25.0
13
+ load_text_feat: false
14
+ load_vae_feat: false
15
+ transform: default_train
16
+ type: SanaWebDatasetMS
17
+ sort_dataset: false
18
+ # model config
19
+ model:
20
+ model: SanaMS_600M_P1_D28
21
+ image_size: 512
22
+ mixed_precision: fp16
23
+ fp32_attention: true
24
+ load_from:
25
+ resume_from:
26
+ aspect_ratio_type: ASPECT_RATIO_512
27
+ multi_scale: true
28
+ #pe_interpolation: 1.
29
+ attn_type: linear
30
+ linear_head_dim: 32
31
+ ffn_type: glumbconv
32
+ mlp_acts:
33
+ - silu
34
+ - silu
35
+ - null
36
+ mlp_ratio: 2.5
37
+ use_pe: false
38
+ qk_norm: false
39
+ class_dropout_prob: 0.1
40
+ # VAE setting
41
+ vae:
42
+ vae_type: dc-ae
43
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
44
+ scale_factor: 0.41407
45
+ vae_latent_dim: 32
46
+ vae_downsample_rate: 32
47
+ sample_posterior: true
48
+ # text encoder
49
+ text_encoder:
50
+ text_encoder_name: gemma-2-2b-it
51
+ y_norm: true
52
+ y_norm_scale_factor: 0.01
53
+ model_max_length: 300
54
+ # CHI
55
+ chi_prompt:
56
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
57
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
58
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
59
+ - 'Here are examples of how to transform or refine prompts:'
60
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
61
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
62
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
63
+ - 'User Prompt: '
64
+ # Sana schedule Flow
65
+ scheduler:
66
+ predict_v: true
67
+ noise_schedule: linear_flow
68
+ pred_sigma: false
69
+ flow_shift: 1.0
70
+ # logit-normal timestep
71
+ weighting_scheme: logit_normal
72
+ logit_mean: 0.0
73
+ logit_std: 1.0
74
+ vis_sampler: flow_dpm-solver
75
+ # training setting
76
+ train:
77
+ num_workers: 10
78
+ seed: 1
79
+ train_batch_size: 64
80
+ num_epochs: 1
81
+ gradient_accumulation_steps: 1
82
+ grad_checkpointing: true
83
+ gradient_clip: 0.1
84
+ optimizer:
85
+ betas:
86
+ - 0.9
87
+ - 0.999
88
+ - 0.9999
89
+ eps:
90
+ - 1.0e-30
91
+ - 1.0e-16
92
+ lr: 0.0001
93
+ type: CAMEWrapper
94
+ weight_decay: 0.0
95
+ lr_schedule: constant
96
+ lr_schedule_args:
97
+ num_warmup_steps: 2000
98
+ local_save_vis: true # if save log image locally
99
+ visualize: true
100
+ eval_sampling_steps: 500
101
+ log_interval: 20
102
+ save_model_epochs: 5
103
+ save_model_steps: 500
104
+ work_dir: output/debug
105
+ online_metric: false
106
+ eval_metric_step: 2000
107
+ online_metric_dir: metric_helper
configs/sana_config/512ms/sample_dataset.yaml ADDED
@@ -0,0 +1,107 @@
1
+ data:
2
+ data_dir: [asset/example_data]
3
+ image_size: 512
4
+ caption_proportion:
5
+ prompt: 1
6
+ external_caption_suffixes: ['', _InternVL2-26B, _VILA1-5-13B] # json files
7
+ external_clipscore_suffixes: # json files
8
+ - _InternVL2-26B_clip_score
9
+ - _VILA1-5-13B_clip_score
10
+ - _prompt_clip_score
11
+ clip_thr_temperature: 0.1
12
+ clip_thr: 25.0
13
+ load_text_feat: false
14
+ load_vae_feat: false
15
+ transform: default_train
16
+ type: SanaImgDataset
17
+ sort_dataset: false
18
+ # model config
19
+ model:
20
+ model: SanaMS_600M_P1_D28
21
+ image_size: 512
22
+ mixed_precision: fp16
23
+ fp32_attention: true
24
+ load_from:
25
+ resume_from:
26
+ aspect_ratio_type: ASPECT_RATIO_512
27
+ multi_scale: false
28
+ #pe_interpolation: 1.
29
+ attn_type: linear
30
+ linear_head_dim: 32
31
+ ffn_type: glumbconv
32
+ mlp_acts:
33
+ - silu
34
+ - silu
35
+ - null
36
+ mlp_ratio: 2.5
37
+ use_pe: false
38
+ qk_norm: false
39
+ class_dropout_prob: 0.1
40
+ # VAE setting
41
+ vae:
42
+ vae_type: dc-ae
43
+ vae_pretrained: mit-han-lab/dc-ae-f32c32-sana-1.0
44
+ scale_factor: 0.41407
45
+ vae_latent_dim: 32
46
+ vae_downsample_rate: 32
47
+ sample_posterior: true
48
+ # text encoder
49
+ text_encoder:
50
+ text_encoder_name: gemma-2-2b-it
51
+ y_norm: true
52
+ y_norm_scale_factor: 0.01
53
+ model_max_length: 300
54
+ # CHI
55
+ chi_prompt:
56
+ - 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:'
57
+ - '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.'
58
+ - '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.'
59
+ - 'Here are examples of how to transform or refine prompts:'
60
+ - '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.'
61
+ - '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.'
62
+ - 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:'
63
+ - 'User Prompt: '
64
+ # Sana schedule Flow
65
+ scheduler:
66
+ predict_v: true
67
+ noise_schedule: linear_flow
68
+ pred_sigma: false
69
+ flow_shift: 1.0
70
+ # logit-normal timestep
71
+ weighting_scheme: logit_normal
72
+ logit_mean: 0.0
73
+ logit_std: 1.0
74
+ vis_sampler: flow_dpm-solver
75
+ # training setting
76
+ train:
77
+ num_workers: 10
78
+ seed: 1
79
+ train_batch_size: 128
80
+ num_epochs: 100
81
+ gradient_accumulation_steps: 1
82
+ grad_checkpointing: true
83
+ gradient_clip: 0.1
84
+ optimizer:
85
+ betas:
86
+ - 0.9
87
+ - 0.999
88
+ - 0.9999
89
+ eps:
90
+ - 1.0e-30
91
+ - 1.0e-16
92
+ lr: 0.0001
93
+ type: CAMEWrapper
94
+ weight_decay: 0.0
95
+ lr_schedule: constant
96
+ lr_schedule_args:
97
+ num_warmup_steps: 2000
98
+ local_save_vis: true # if save log image locally
99
+ visualize: true
100
+ eval_sampling_steps: 500
101
+ log_interval: 20
102
+ save_model_epochs: 5
103
+ save_model_steps: 500
104
+ work_dir: output/debug
105
+ online_metric: false
106
+ eval_metric_step: 2000
107
+ online_metric_dir: metric_helper
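All of these dataset blocks point to external caption and CLIP-score sidecar JSON files (external_caption_suffixes / external_clipscore_suffixes) and gate caption choice with clip_thr: 25.0 and clip_thr_temperature: 0.1. The dataset code later in this commit (diffusion/data/datasets/sana_data.py) implements the full selection, including a fallback to the highest-scoring caption; the sketch below only illustrates the temperature-weighted sampling idea with made-up scores:

```python
# Hedged sketch of CLIP-score-weighted caption-source sampling (illustration only).
import random

import numpy as np


def sample_caption_source(scores, clip_thr=25.0, temperature=0.1):
    """Pick a caption source; if all are below clip_thr, fall back to the best one."""
    labels = [k for k, v in scores.items() if v >= clip_thr]
    if not labels:
        return max(scores, key=scores.get)
    weights = np.array([scores[k] for k in labels]) ** (1.0 / max(temperature, 0.01))
    weights = weights / weights.sum()
    return random.choices(labels, weights=weights, k=1)[0]


print(sample_caption_source({"prompt": 24.0, "InternVL2-26B": 27.5, "VILA1-5-13B": 26.0}))
```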
diffusion/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from .dpm_solver import DPMS
7
+ from .flow_euler_sampler import FlowEuler
8
+ from .iddpm import Scheduler
9
+ from .sa_sampler import SASolverSampler
diffusion/data/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .datasets import *
2
+ from .transforms import get_transform
diffusion/data/builder.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import os
18
+ import time
19
+
20
+ from mmcv import Registry, build_from_cfg
21
+ from termcolor import colored
22
+ from torch.utils.data import DataLoader
23
+
24
+ from diffusion.data.transforms import get_transform
25
+ from diffusion.utils.logger import get_root_logger
26
+
27
+ DATASETS = Registry("datasets")
28
+
29
+ DATA_ROOT = "data"
30
+
31
+
32
+ def set_data_root(data_root):
33
+ global DATA_ROOT
34
+ DATA_ROOT = data_root
35
+
36
+
37
+ def get_data_path(data_dir):
38
+ if os.path.isabs(data_dir):
39
+ return data_dir
40
+ global DATA_ROOT
41
+ return os.path.join(DATA_ROOT, data_dir)
42
+
43
+
44
+ def get_data_root_and_path(data_dir):
45
+ if os.path.isabs(data_dir):
46
+ return data_dir
47
+ global DATA_ROOT
48
+ return DATA_ROOT, os.path.join(DATA_ROOT, data_dir)
49
+
50
+
51
+ def build_dataset(cfg, resolution=224, **kwargs):
52
+ logger = get_root_logger()
53
+
54
+ dataset_type = cfg.get("type")
55
+ logger.info(f"Constructing dataset {dataset_type}...")
56
+ t = time.time()
57
+ transform = cfg.pop("transform", "default_train")
58
+ transform = get_transform(transform, resolution)
59
+ dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs))
60
+ logger.info(
61
+ f"{colored(f'Dataset {dataset_type} constructed: ', 'green', attrs=['bold'])}"
62
+ f"time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}"
63
+ )
64
+ return dataset
65
+
66
+
67
+ def build_dataloader(dataset, batch_size=256, num_workers=4, shuffle=True, **kwargs):
68
+ if "batch_sampler" in kwargs:
69
+ dataloader = DataLoader(
70
+ dataset, batch_sampler=kwargs["batch_sampler"], num_workers=num_workers, pin_memory=True
71
+ )
72
+ else:
73
+ dataloader = DataLoader(
74
+ dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, **kwargs
75
+ )
76
+ return dataloader
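A minimal usage sketch for the builder above, pairing it with the sample data directory referenced by configs/sana_config/512ms/sample_dataset.yaml. The exact keyword set expected by SanaImgDataset and the contents of asset/example_data are assumptions here; treat this as a sketch rather than a verified recipe:

```python
# Hedged usage sketch for build_dataset / build_dataloader (illustration only).
from diffusion.data.builder import build_dataloader, build_dataset

cfg = dict(
    type="SanaImgDataset",
    data_dir=["asset/example_data"],
    transform="default_train",  # popped by build_dataset and resolved via get_transform
)
dataset = build_dataset(cfg, resolution=512)
loader = build_dataloader(dataset, batch_size=2, num_workers=0, shuffle=True)
batch = next(iter(loader))
```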
diffusion/data/datasets/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .sana_data import SanaImgDataset, SanaWebDataset
2
+ from .sana_data_multi_scale import DummyDatasetMS, SanaWebDatasetMS
3
+ from .utils import *
diffusion/data/datasets/sana_data.py ADDED
@@ -0,0 +1,467 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
18
+ import getpass
19
+ import json
20
+ import os
21
+ import os.path as osp
22
+ import random
23
+
24
+ import numpy as np
25
+ import torch
26
+ import torch.distributed as dist
27
+ from PIL import Image
28
+ from termcolor import colored
29
+ from torch.utils.data import Dataset
30
+
31
+ from diffusion.data.builder import DATASETS, get_data_path
32
+ from diffusion.data.wids import ShardListDataset, ShardListDatasetMulti, lru_json_load
33
+ from diffusion.utils.logger import get_root_logger
34
+
35
+
36
+ @DATASETS.register_module()
37
+ class SanaImgDataset(torch.utils.data.Dataset):
38
+ def __init__(
39
+ self,
40
+ data_dir="",
41
+ transform=None,
42
+ resolution=256,
43
+ load_vae_feat=False,
44
+ load_text_feat=False,
45
+ max_length=300,
46
+ config=None,
47
+ caption_proportion=None,
48
+ external_caption_suffixes=None,
49
+ external_clipscore_suffixes=None,
50
+ clip_thr=0.0,
51
+ clip_thr_temperature=1.0,
52
+ img_extension=".png",
53
+ **kwargs,
54
+ ):
55
+ if external_caption_suffixes is None:
56
+ external_caption_suffixes = []
57
+ if external_clipscore_suffixes is None:
58
+ external_clipscore_suffixes = []
59
+
60
+ self.logger = (
61
+ get_root_logger() if config is None else get_root_logger(osp.join(config.work_dir, "train_log.log"))
62
+ )
63
+ self.transform = transform if not load_vae_feat else None
64
+ self.load_vae_feat = load_vae_feat
65
+ self.load_text_feat = load_text_feat
66
+ self.resolution = resolution
67
+ self.max_length = max_length
68
+ self.caption_proportion = caption_proportion if caption_proportion is not None else {"prompt": 1.0}
69
+ self.external_caption_suffixes = external_caption_suffixes
70
+ self.external_clipscore_suffixes = external_clipscore_suffixes
71
+ self.clip_thr = clip_thr
72
+ self.clip_thr_temperature = clip_thr_temperature
73
+ self.default_prompt = "prompt"
74
+ self.img_extension = img_extension
75
+
76
+ self.data_dirs = data_dir if isinstance(data_dir, list) else [data_dir]
77
+ # self.meta_datas = [osp.join(data_dir, "meta_data.json") for data_dir in self.data_dirs]
78
+ self.dataset = []
79
+ for data_dir in self.data_dirs:
80
+ meta_data = json.load(open(osp.join(data_dir, "meta_data.json")))
81
+ self.dataset.extend([osp.join(data_dir, i) for i in meta_data["img_names"]])
82
+
83
+ self.dataset = self.dataset * 2000
84
+ self.logger.info(colored("Dataset is repeated 2000 times for the toy dataset", "red", attrs=["bold"]))
85
+ self.ori_imgs_nums = len(self)
86
+ self.logger.info(f"Dataset samples: {len(self.dataset)}")
87
+
88
+ self.logger.info(f"Loading external caption json from: original_filename{external_caption_suffixes}.json")
89
+ self.logger.info(f"Loading external clipscore json from: original_filename{external_clipscore_suffixes}.json")
90
+ self.logger.info(f"external caption clipscore threshold: {clip_thr}, temperature: {clip_thr_temperature}")
91
+ self.logger.info(f"T5 max token length: {self.max_length}")
92
+
93
+ def getdata(self, idx):
94
+ data = self.dataset[idx]
95
+ self.key = data.split("/")[-1]
96
+ # info = json.load(open(f"{data}.json"))[self.key]
97
+ info = {}
98
+ with open(f"{data}.txt") as f:
99
+ info[self.default_prompt] = f.readlines()[0].strip()
100
+
101
+ # external json file
102
+ for suffix in self.external_caption_suffixes:
103
+ caption_json_path = f"{data}{suffix}.json"
104
+ if os.path.exists(caption_json_path):
105
+ try:
106
+ caption_json = lru_json_load(caption_json_path)
107
+ except Exception:
108
+ caption_json = {}
109
+ if self.key in caption_json:
110
+ info.update(caption_json[self.key])
111
+
112
+ caption_type, caption_clipscore = self.weighted_sample_clipscore(data, info)
113
+ caption_type = caption_type if caption_type in info else self.default_prompt
114
+ txt_fea = "" if info[caption_type] is None else info[caption_type]
115
+
116
+ data_info = {
117
+ "img_hw": torch.tensor([self.resolution, self.resolution], dtype=torch.float32),
118
+ "aspect_ratio": torch.tensor(1.0),
119
+ }
120
+
121
+ if self.load_vae_feat:
122
+ raise ValueError("Load VAE is not supported now")
123
+ else:
124
+ img = f"{data}{self.img_extension}"
125
+ img = Image.open(img)
126
+ if self.transform:
127
+ img = self.transform(img)
128
+
129
+ attention_mask = torch.ones(1, 1, self.max_length, dtype=torch.int16) # 1x1xT
130
+ if self.load_text_feat:
131
+ npz_path = f"{self.key}.npz"
132
+ txt_info = np.load(npz_path)
133
+ txt_fea = torch.from_numpy(txt_info["caption_feature"]) # 1xTx4096
134
+ if "attention_mask" in txt_info:
135
+ attention_mask = torch.from_numpy(txt_info["attention_mask"])[None]
136
+ # make sure the feature length are the same
137
+ if txt_fea.shape[1] != self.max_length:
138
+ txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_length - txt_fea.shape[1], 1)], dim=1)
139
+ attention_mask = torch.cat(
140
+ [attention_mask, torch.zeros(1, 1, self.max_length - attention_mask.shape[-1])], dim=-1
141
+ )
142
+
143
+ return (
144
+ img,
145
+ txt_fea,
146
+ attention_mask.to(torch.int16),
147
+ data_info,
148
+ idx,
149
+ caption_type,
150
+ "",
151
+ str(caption_clipscore),
152
+ )
153
+
154
+ def __getitem__(self, idx):
155
+ for _ in range(10):
156
+ try:
157
+ data = self.getdata(idx)
158
+ return data
159
+ except Exception as e:
160
+ print(f"Error details: {str(e)}")
161
+ idx = idx + 1
162
+ raise RuntimeError("Too many bad data.")
163
+
164
+ def __len__(self):
165
+ return len(self.dataset)
166
+
167
+ def weighted_sample_fix_prob(self):
168
+ labels = list(self.caption_proportion.keys())
169
+ weights = list(self.caption_proportion.values())
170
+ sampled_label = random.choices(labels, weights=weights, k=1)[0]
171
+ return sampled_label
172
+
173
+ def weighted_sample_clipscore(self, data, info):
174
+ labels = []
175
+ weights = []
176
+ fallback_label = None
177
+ max_clip_score = float("-inf")
178
+
179
+ for suffix in self.external_clipscore_suffixes:
180
+ clipscore_json_path = f"{data}{suffix}.json"
181
+
182
+ if os.path.exists(clipscore_json_path):
183
+ try:
184
+ clipscore_json = lru_json_load(clipscore_json_path)
185
+ except Exception:
186
+ clipscore_json = {}
187
+ if self.key in clipscore_json:
188
+ clip_scores = clipscore_json[self.key]
189
+
190
+ for caption_type, clip_score in clip_scores.items():
191
+ clip_score = float(clip_score)
192
+ if caption_type in info:
193
+ if clip_score >= self.clip_thr:
194
+ labels.append(caption_type)
195
+ weights.append(clip_score)
196
+
197
+ if clip_score > max_clip_score:
198
+ max_clip_score = clip_score
199
+ fallback_label = caption_type
200
+
201
+ if not labels and fallback_label:
202
+ return fallback_label, max_clip_score
203
+
204
+ if not labels:
205
+ return self.default_prompt, 0.0
206
+
207
+ adjusted_weights = np.array(weights) ** (1.0 / max(self.clip_thr_temperature, 0.01))
208
+ normalized_weights = adjusted_weights / np.sum(adjusted_weights)
209
+ sampled_label = random.choices(labels, weights=normalized_weights, k=1)[0]
210
+ # sampled_label = random.choices(labels, weights=[1]*len(weights), k=1)[0]
211
+ index = labels.index(sampled_label)
212
+ original_weight = weights[index]
213
+
214
+ return sampled_label, original_weight
215
+
216
+
217
+ @DATASETS.register_module()
218
+ class SanaWebDataset(torch.utils.data.Dataset):
219
+ def __init__(
220
+ self,
221
+ data_dir="",
222
+ meta_path=None,
223
+ cache_dir="/cache/data/sana-webds-meta",
224
+ max_shards_to_load=None,
225
+ transform=None,
226
+ resolution=256,
227
+ load_vae_feat=False,
228
+ load_text_feat=False,
229
+ max_length=300,
230
+ config=None,
231
+ caption_proportion=None,
232
+ sort_dataset=False,
233
+ num_replicas=None,
234
+ external_caption_suffixes=None,
235
+ external_clipscore_suffixes=None,
236
+ clip_thr=0.0,
237
+ clip_thr_temperature=1.0,
238
+ **kwargs,
239
+ ):
240
+ if external_caption_suffixes is None:
241
+ external_caption_suffixes = []
242
+ if external_clipscore_suffixes is None:
243
+ external_clipscore_suffixes = []
244
+
245
+ self.logger = (
246
+ get_root_logger() if config is None else get_root_logger(osp.join(config.work_dir, "train_log.log"))
247
+ )
248
+ self.transform = transform if not load_vae_feat else None
249
+ self.load_vae_feat = load_vae_feat
250
+ self.load_text_feat = load_text_feat
251
+ self.resolution = resolution
252
+ self.max_length = max_length
253
+ self.caption_proportion = caption_proportion if caption_proportion is not None else {"prompt": 1.0}
254
+ self.external_caption_suffixes = external_caption_suffixes
255
+ self.external_clipscore_suffixes = external_clipscore_suffixes
256
+ self.clip_thr = clip_thr
257
+ self.clip_thr_temperature = clip_thr_temperature
258
+ self.default_prompt = "prompt"
259
+
260
+ data_dirs = data_dir if isinstance(data_dir, list) else [data_dir]
261
+ meta_paths = meta_path if isinstance(meta_path, list) else [meta_path] * len(data_dirs)
262
+ self.meta_paths = []
263
+ for data_path, meta_path in zip(data_dirs, meta_paths):
264
+ self.data_path = osp.expanduser(data_path)
265
+ self.meta_path = osp.expanduser(meta_path) if meta_path is not None else None
266
+
267
+ _local_meta_path = osp.join(self.data_path, "wids-meta.json")
268
+ if meta_path is None and osp.exists(_local_meta_path):
269
+ self.logger.info(f"loading from {_local_meta_path}")
270
+ self.meta_path = meta_path = _local_meta_path
271
+
272
+ if meta_path is None:
273
+ self.meta_path = osp.join(
274
+ osp.expanduser(cache_dir),
275
+ self.data_path.replace("/", "--") + f".max_shards:{max_shards_to_load}" + ".wdsmeta.json",
276
+ )
277
+
278
+ assert osp.exists(self.meta_path), f"meta path not found in [{self.meta_path}] or [{_local_meta_path}]"
279
+ self.logger.info(f"[SimplyInternal] Loading meta information {self.meta_path}")
280
+ self.meta_paths.append(self.meta_path)
281
+
282
+ self._initialize_dataset(num_replicas, sort_dataset)
283
+
284
+ self.logger.info(f"Loading external caption json from: original_filename{external_caption_suffixes}.json")
285
+ self.logger.info(f"Loading external clipscore json from: original_filename{external_clipscore_suffixes}.json")
286
+ self.logger.info(f"external caption clipscore threshold: {clip_thr}, temperature: {clip_thr_temperature}")
287
+ self.logger.info(f"T5 max token length: {self.max_length}")
288
+ self.logger.warning(f"Sort the dataset: {sort_dataset}")
289
+
290
+ def _initialize_dataset(self, num_replicas, sort_dataset):
291
+ # uuid = abs(hash(self.meta_path)) % (10 ** 8)
292
+ import hashlib
293
+
294
+ uuid = hashlib.sha256(self.meta_path.encode()).hexdigest()[:8]
295
+ if len(self.meta_paths) > 0:
296
+ self.dataset = ShardListDatasetMulti(
297
+ self.meta_paths,
298
+ cache_dir=osp.expanduser(f"~/.cache/_wids_cache/{getpass.getuser()}-{uuid}"),
299
+ sort_data_inseq=sort_dataset,
300
+ num_replicas=num_replicas or dist.get_world_size(),
301
+ )
302
+ else:
303
+ # TODO: tmp to ensure there is no bug
304
+ self.dataset = ShardListDataset(
305
+ self.meta_path,
306
+ cache_dir=osp.expanduser(f"~/.cache/_wids_cache/{getpass.getuser()}-{uuid}"),
307
+ )
308
+ self.ori_imgs_nums = len(self)
309
+ self.logger.info(f"{self.dataset.data_info}")
310
+
311
+ def getdata(self, idx):
312
+ data = self.dataset[idx]
313
+ info = data[".json"]
314
+ self.key = data["__key__"]
315
+ dataindex_info = {
316
+ "index": data["__index__"],
317
+ "shard": "/".join(data["__shard__"].rsplit("/", 2)[-2:]),
318
+ "shardindex": data["__shardindex__"],
319
+ }
320
+
321
+ # external json file
322
+ for suffix in self.external_caption_suffixes:
323
+ caption_json_path = data["__shard__"].replace(".tar", f"{suffix}.json")
324
+ if os.path.exists(caption_json_path):
325
+ try:
326
+ caption_json = lru_json_load(caption_json_path)
327
+ except:
328
+ caption_json = {}
329
+ if self.key in caption_json:
330
+ info.update(caption_json[self.key])
331
+
332
+ caption_type, caption_clipscore = self.weighted_sample_clipscore(data, info)
333
+ caption_type = caption_type if caption_type in info else self.default_prompt
334
+ txt_fea = "" if info[caption_type] is None else info[caption_type]
335
+
336
+ data_info = {
337
+ "img_hw": torch.tensor([self.resolution, self.resolution], dtype=torch.float32),
338
+ "aspect_ratio": torch.tensor(1.0),
339
+ }
340
+
341
+ if self.load_vae_feat:
342
+ img = data[".npy"]
343
+ else:
344
+ img = data[".png"] if ".png" in data else data[".jpg"]
345
+ if self.transform:
346
+ img = self.transform(img)
347
+
348
+ attention_mask = torch.ones(1, 1, self.max_length, dtype=torch.int16) # 1x1xT
349
+ if self.load_text_feat:
350
+ npz_path = f"{self.key}.npz"
351
+ txt_info = np.load(npz_path)
352
+ txt_fea = torch.from_numpy(txt_info["caption_feature"]) # 1xTx4096
353
+ if "attention_mask" in txt_info:
354
+ attention_mask = torch.from_numpy(txt_info["attention_mask"])[None]
355
+ # make sure the feature lengths are the same
356
+ if txt_fea.shape[1] != self.max_length:
357
+ txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_length - txt_fea.shape[1], 1)], dim=1)
358
+ attention_mask = torch.cat(
359
+ [attention_mask, torch.zeros(1, 1, self.max_length - attention_mask.shape[-1])], dim=-1
360
+ )
361
+
362
+ return (
363
+ img,
364
+ txt_fea,
365
+ attention_mask.to(torch.int16),
366
+ data_info,
367
+ idx,
368
+ caption_type,
369
+ dataindex_info,
370
+ str(caption_clipscore),
371
+ )
372
+
373
+ def __getitem__(self, idx):
374
+ for _ in range(10):
375
+ try:
376
+ data = self.getdata(idx)
377
+ return data
378
+ except Exception as e:
379
+ print(f"Error details: {str(e)}")
380
+ idx = idx + 1
381
+ raise RuntimeError("Too many bad data.")
382
+
383
+ def __len__(self):
384
+ return len(self.dataset)
385
+
386
+ def weighted_sample_fix_prob(self):
387
+ labels = list(self.caption_proportion.keys())
388
+ weights = list(self.caption_proportion.values())
389
+ sampled_label = random.choices(labels, weights=weights, k=1)[0]
390
+ return sampled_label
391
+
392
+ def weighted_sample_clipscore(self, data, info):
393
+ labels = []
394
+ weights = []
395
+ fallback_label = None
396
+ max_clip_score = float("-inf")
397
+
398
+ for suffix in self.external_clipscore_suffixes:
399
+ clipscore_json_path = data["__shard__"].replace(".tar", f"{suffix}.json")
400
+
401
+ if os.path.exists(clipscore_json_path):
402
+ try:
403
+ clipscore_json = lru_json_load(clipscore_json_path)
404
+ except:
405
+ clipscore_json = {}
406
+ if self.key in clipscore_json:
407
+ clip_scores = clipscore_json[self.key]
408
+
409
+ for caption_type, clip_score in clip_scores.items():
410
+ clip_score = float(clip_score)
411
+ if caption_type in info:
412
+ if clip_score >= self.clip_thr:
413
+ labels.append(caption_type)
414
+ weights.append(clip_score)
415
+
416
+ if clip_score > max_clip_score:
417
+ max_clip_score = clip_score
418
+ fallback_label = caption_type
419
+
420
+ if not labels and fallback_label:
421
+ return fallback_label, max_clip_score
422
+
423
+ if not labels:
424
+ return self.default_prompt, 0.0
425
+
426
+ adjusted_weights = np.array(weights) ** (1.0 / max(self.clip_thr_temperature, 0.01))
427
+ normalized_weights = adjusted_weights / np.sum(adjusted_weights)
428
+ sampled_label = random.choices(labels, weights=normalized_weights, k=1)[0]
429
+ # sampled_label = random.choices(labels, weights=[1]*len(weights), k=1)[0]
430
+ index = labels.index(sampled_label)
431
+ original_weight = weights[index]
432
+
433
+ return sampled_label, original_weight
434
+
435
+ def get_data_info(self, idx):
436
+ try:
437
+ data = self.dataset[idx]
438
+ info = data[".json"]
439
+ key = data["__key__"]
440
+ version = info.get("version", "others")
441
+ return {"height": info["height"], "width": info["width"], "version": version, "key": key}
442
+ except Exception as e:
443
+ print(f"Error details: {str(e)}")
444
+ return None
445
+
446
+
447
+ if __name__ == "__main__":
448
+ from torch.utils.data import DataLoader
449
+
450
+ from diffusion.data.transforms import get_transform
451
+
452
+ image_size = 1024 # 256
453
+ transform = get_transform("default_train", image_size)
454
+ train_dataset = SanaWebDataset(
455
+ data_dir="debug_data_train/vaef32c32/debug_data",
456
+ resolution=image_size,
457
+ transform=transform,
458
+ max_length=300,
459
+ load_vae_feat=True,
460
+ num_replicas=1,
461
+ )
462
+ dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False, num_workers=4)
463
+
464
+ for data in dataloader:
465
+ img, txt_fea, attention_mask, data_info = data
466
+ print(txt_fea)
467
+ break
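A minimal, self-contained sketch of the temperature-weighted caption sampling that weighted_sample_clipscore performs above; the caption types and scores here are illustrative toy values, not taken from any dataset:

import random
import numpy as np

clip_scores = {"prompt": 0.31, "vlm_caption": 0.27, "alt_text": 0.18}  # toy values
clip_thr = 0.2              # caption types scoring below this threshold are dropped
clip_thr_temperature = 0.5  # temperatures < 1.0 sharpen sampling toward high scores

labels = [k for k, v in clip_scores.items() if v >= clip_thr]
weights = [clip_scores[k] for k in labels]

# Same re-weighting as in the dataset code: w ** (1 / T), then normalize.
adjusted = np.array(weights) ** (1.0 / max(clip_thr_temperature, 0.01))
normalized = adjusted / np.sum(adjusted)
sampled_label = random.choices(labels, weights=normalized, k=1)[0]
print(sampled_label, clip_scores[sampled_label])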
diffusion/data/datasets/sana_data_multi_scale.py ADDED
@@ -0,0 +1,265 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
18
+ import os
19
+ import random
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torchvision import transforms as T
24
+ from torchvision.transforms.functional import InterpolationMode
25
+ from tqdm import tqdm
26
+
27
+ from diffusion.data.builder import DATASETS
28
+ from diffusion.data.datasets.sana_data import SanaWebDataset
29
+ from diffusion.data.datasets.utils import *
30
+ from diffusion.data.wids import lru_json_load
31
+
32
+
33
+ def get_closest_ratio(height: float, width: float, ratios: dict):
34
+ aspect_ratio = height / width
35
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
36
+ return ratios[closest_ratio], float(closest_ratio)
37
+
38
+
39
+ @DATASETS.register_module()
40
+ class SanaWebDatasetMS(SanaWebDataset):
41
+ def __init__(
42
+ self,
43
+ data_dir="",
44
+ meta_path=None,
45
+ cache_dir="/cache/data/sana-webds-meta",
46
+ max_shards_to_load=None,
47
+ transform=None,
48
+ resolution=256,
49
+ sample_subset=None,
50
+ load_vae_feat=False,
51
+ load_text_feat=False,
52
+ input_size=32,
53
+ patch_size=2,
54
+ max_length=300,
55
+ config=None,
56
+ caption_proportion=None,
57
+ sort_dataset=False,
58
+ num_replicas=None,
59
+ external_caption_suffixes=None,
60
+ external_clipscore_suffixes=None,
61
+ clip_thr=0.0,
62
+ clip_thr_temperature=1.0,
63
+ vae_downsample_rate=32,
64
+ **kwargs,
65
+ ):
66
+ super().__init__(
67
+ data_dir=data_dir,
68
+ meta_path=meta_path,
69
+ cache_dir=cache_dir,
70
+ max_shards_to_load=max_shards_to_load,
71
+ transform=transform,
72
+ resolution=resolution,
73
+ sample_subset=sample_subset,
74
+ load_vae_feat=load_vae_feat,
75
+ load_text_feat=load_text_feat,
76
+ input_size=input_size,
77
+ patch_size=patch_size,
78
+ max_length=max_length,
79
+ config=config,
80
+ caption_proportion=caption_proportion,
81
+ sort_dataset=sort_dataset,
82
+ num_replicas=num_replicas,
83
+ external_caption_suffixes=external_caption_suffixes,
84
+ external_clipscore_suffixes=external_clipscore_suffixes,
85
+ clip_thr=clip_thr,
86
+ clip_thr_temperature=clip_thr_temperature,
87
+ vae_downsample_rate=32,
88
+ **kwargs,
89
+ )
90
+ self.base_size = int(kwargs["aspect_ratio_type"].split("_")[-1])
91
+ self.aspect_ratio = eval(kwargs.pop("aspect_ratio_type")) # base aspect ratio
92
+ self.ratio_index = {}
93
+ self.ratio_nums = {}
94
+ self.interpolate_model = InterpolationMode.BICUBIC
95
+ self.interpolate_model = (
96
+ InterpolationMode.BICUBIC
97
+ if self.aspect_ratio not in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]
98
+ else InterpolationMode.LANCZOS
99
+ )
100
+
101
+ for k, v in self.aspect_ratio.items():
102
+ self.ratio_index[float(k)] = []
103
+ self.ratio_nums[float(k)] = 0
104
+
105
+ self.vae_downsample_rate = vae_downsample_rate
106
+
107
+ def __getitem__(self, idx):
108
+ for _ in range(10):
109
+ try:
110
+ data = self.getdata(idx)
111
+ return data
112
+ except Exception as e:
113
+ print(f"Error details: {str(e)}")
114
+ idx = random.choice(self.ratio_index[self.closest_ratio])
115
+ raise RuntimeError("Too many bad data.")
116
+
117
+ def getdata(self, idx):
118
+ data = self.dataset[idx]
119
+ info = data[".json"]
120
+ self.key = data["__key__"]
121
+ dataindex_info = {
122
+ "index": data["__index__"],
123
+ "shard": "/".join(data["__shard__"].rsplit("/", 2)[-2:]),
124
+ "shardindex": data["__shardindex__"],
125
+ }
126
+
127
+ # external json file
128
+ for suffix in self.external_caption_suffixes:
129
+ caption_json_path = data["__shard__"].replace(".tar", f"{suffix}.json")
130
+ if os.path.exists(caption_json_path):
131
+ try:
132
+ caption_json = lru_json_load(caption_json_path)
133
+ except:
134
+ caption_json = {}
135
+ if self.key in caption_json:
136
+ info.update(caption_json[self.key])
137
+
138
+ data_info = {}
139
+ ori_h, ori_w = info["height"], info["width"]
140
+
141
+ # Calculate the closest aspect ratio and resize & crop image[w, h]
142
+ closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, self.aspect_ratio)
143
+ closest_size = list(map(lambda x: int(x), closest_size))
144
+ self.closest_ratio = closest_ratio
145
+
146
+ data_info["img_hw"] = torch.tensor([ori_h, ori_w], dtype=torch.float32)
147
+ data_info["aspect_ratio"] = closest_ratio
148
+
149
+ caption_type, caption_clipscore = self.weighted_sample_clipscore(data, info)
150
+ caption_type = caption_type if caption_type in info else self.default_prompt
151
+ txt_fea = "" if info[caption_type] is None else info[caption_type]
152
+
153
+ if self.load_vae_feat:
154
+ img = data[".npy"]
155
+ if len(img.shape) == 4 and img.shape[0] == 1:
156
+ img = img[0]
157
+ h, w = (img.shape[1], img.shape[2])
158
+ assert h == int(closest_size[0] // self.vae_downsample_rate) and w == int(
159
+ closest_size[1] // self.vae_downsample_rate
160
+ ), f"h: {h}, w: {w}, ori_hw: {closest_size}, data_info: {dataindex_info}"
161
+ else:
162
+ img = data[".png"] if ".png" in data else data[".jpg"]
163
+ if closest_size[0] / ori_h > closest_size[1] / ori_w:
164
+ resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
165
+ else:
166
+ resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
167
+ self.transform = T.Compose(
168
+ [
169
+ T.Lambda(lambda img: img.convert("RGB")),
170
+ T.Resize(resize_size, interpolation=self.interpolate_model), # Image.BICUBIC
171
+ T.CenterCrop(closest_size),
172
+ T.ToTensor(),
173
+ T.Normalize([0.5], [0.5]),
174
+ ]
175
+ )
176
+ if idx not in self.ratio_index[closest_ratio]:
177
+ self.ratio_index[closest_ratio].append(idx)
178
+
179
+ if self.transform:
180
+ img = self.transform(img)
181
+
182
+ attention_mask = torch.ones(1, 1, self.max_length, dtype=torch.int16) # 1x1xT
183
+ if self.load_text_feat:
184
+ npz_path = f"{self.key}.npz"
185
+ txt_info = np.load(npz_path)
186
+ txt_fea = torch.from_numpy(txt_info["caption_feature"]) # 1xTx4096
187
+ if "attention_mask" in txt_info:
188
+ attention_mask = torch.from_numpy(txt_info["attention_mask"])[None]
189
+ # make sure the feature lengths are the same
190
+ if txt_fea.shape[1] != self.max_length:
191
+ txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_length - txt_fea.shape[1], 1)], dim=1)
192
+ attention_mask = torch.cat(
193
+ [attention_mask, torch.zeros(1, 1, self.max_length - attention_mask.shape[-1])], dim=-1
194
+ )
195
+
196
+ return (
197
+ img,
198
+ txt_fea,
199
+ attention_mask.to(torch.int16),
200
+ data_info,
201
+ idx,
202
+ caption_type,
203
+ dataindex_info,
204
+ str(caption_clipscore),
205
+ )
206
+
207
+ def __len__(self):
208
+ return len(self.dataset)
209
+
210
+
211
+ @DATASETS.register_module()
212
+ class DummyDatasetMS(SanaWebDatasetMS):
213
+ def __init__(self, **kwargs):
214
+ self.base_size = int(kwargs["aspect_ratio_type"].split("_")[-1])
215
+ self.aspect_ratio = eval(kwargs.pop("aspect_ratio_type")) # base aspect ratio
216
+ self.ratio_index = {}
217
+ self.ratio_nums = {}
218
+ self.interpolate_model = InterpolationMode.BICUBIC
219
+ self.interpolate_model = (
220
+ InterpolationMode.BICUBIC
221
+ if self.aspect_ratio not in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]
222
+ else InterpolationMode.LANCZOS
223
+ )
224
+
225
+ for k, v in self.aspect_ratio.items():
226
+ self.ratio_index[float(k)] = []
227
+ self.ratio_nums[float(k)] = 0
228
+
229
+ self.ori_imgs_nums = 1_000_000
230
+ self.height = 384
231
+ self.width = 672
232
+
233
+ def __getitem__(self, idx):
234
+ img = torch.randn((3, self.height, self.width))
235
+ txt_fea = "The image depicts a young woman standing in the middle of a street, leaning against a silver car. She is dressed in a stylish outfit consisting of a blue blouse and black pants. Her hair is long and dark, and she is looking directly at the camera with a confident expression. The street is lined with colorful buildings, and the trees have autumn leaves, suggesting the season is fall. The lighting is warm, with sunlight casting long shadows on the street. There are a few people in the background, and the overall atmosphere is vibrant and lively."
236
+ attention_mask = torch.ones(1, 1, 300, dtype=torch.int16) # 1x1xT
237
+ data_info = {"img_hw": torch.tensor([816.0, 1456.0]), "aspect_ratio": 0.57}
238
+ idx = 2500
239
+ caption_type = self.default_prompt
240
+ dataindex_info = {"index": 2500, "shard": "data_for_test_after_change/00000000.tar", "shardindex": 2500}
241
+ return img, txt_fea, attention_mask, data_info, idx, caption_type, dataindex_info
242
+
243
+ def __len__(self):
244
+ return self.ori_imgs_nums
245
+
246
+ def get_data_info(self, idx):
247
+ return {"height": self.height, "width": self.width, "version": "1.0", "key": "dummpy_key"}
248
+
249
+
250
+ if __name__ == "__main__":
251
+ from torch.utils.data import DataLoader
252
+
253
+ from diffusion.data.datasets.utils import ASPECT_RATIO_1024
254
+ from diffusion.data.transforms import get_transform
255
+
256
+ image_size = 256
257
+ transform = get_transform("default_train", image_size)
258
+ data_dir = ["data/debug_data_train/debug_data"]
259
+ for data_path in data_dir:
260
+ train_dataset = SanaWebDatasetMS(data_dir=data_path, resolution=image_size, transform=transform, max_length=300)
261
+ dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=4)
262
+
263
+ for data in tqdm(dataloader):
264
+ break
265
+ print(dataloader.dataset.index_info)
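A minimal sketch of the aspect-ratio bucketing that SanaWebDatasetMS applies per image: each image is mapped to the predefined bucket whose ratio is closest to its own, then resized and center-cropped to that bucket. The dictionary below is a small excerpt of ASPECT_RATIO_1024 from diffusion/data/datasets/utils.py, and the 816x1456 input size is just an example:

# A few buckets excerpted from ASPECT_RATIO_1024 (values are [height, width]).
ASPECT_RATIO_1024_SUBSET = {
    "0.52": [704.0, 1344.0],
    "0.57": [768.0, 1344.0],
    "0.6": [768.0, 1280.0],
    "1.0": [1024.0, 1024.0],
}

def get_closest_ratio(height: float, width: float, ratios: dict):
    # Mirrors the helper defined in sana_data_multi_scale.py above.
    aspect_ratio = height / width
    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
    return ratios[closest_ratio], float(closest_ratio)

size, ratio = get_closest_ratio(816, 1456, ASPECT_RATIO_1024_SUBSET)
print(size, ratio)  # [768.0, 1344.0] 0.57 -> the bucket this image is trained at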
diffusion/data/datasets/utils.py ADDED
@@ -0,0 +1,506 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
18
+ ASPECT_RATIO_4096 = {
19
+ "0.25": [2048.0, 8192.0],
20
+ "0.26": [2048.0, 7936.0],
21
+ "0.27": [2048.0, 7680.0],
22
+ "0.28": [2048.0, 7424.0],
23
+ "0.32": [2304.0, 7168.0],
24
+ "0.33": [2304.0, 6912.0],
25
+ "0.35": [2304.0, 6656.0],
26
+ "0.4": [2560.0, 6400.0],
27
+ "0.42": [2560.0, 6144.0],
28
+ "0.48": [2816.0, 5888.0],
29
+ "0.5": [2816.0, 5632.0],
30
+ "0.52": [2816.0, 5376.0],
31
+ "0.57": [3072.0, 5376.0],
32
+ "0.6": [3072.0, 5120.0],
33
+ "0.68": [3328.0, 4864.0],
34
+ "0.72": [3328.0, 4608.0],
35
+ "0.78": [3584.0, 4608.0],
36
+ "0.82": [3584.0, 4352.0],
37
+ "0.88": [3840.0, 4352.0],
38
+ "0.94": [3840.0, 4096.0],
39
+ "1.0": [4096.0, 4096.0],
40
+ "1.07": [4096.0, 3840.0],
41
+ "1.13": [4352.0, 3840.0],
42
+ "1.21": [4352.0, 3584.0],
43
+ "1.29": [4608.0, 3584.0],
44
+ "1.38": [4608.0, 3328.0],
45
+ "1.46": [4864.0, 3328.0],
46
+ "1.67": [5120.0, 3072.0],
47
+ "1.75": [5376.0, 3072.0],
48
+ "2.0": [5632.0, 2816.0],
49
+ "2.09": [5888.0, 2816.0],
50
+ "2.4": [6144.0, 2560.0],
51
+ "2.5": [6400.0, 2560.0],
52
+ "2.89": [6656.0, 2304.0],
53
+ "3.0": [6912.0, 2304.0],
54
+ "3.11": [7168.0, 2304.0],
55
+ "3.62": [7424.0, 2048.0],
56
+ "3.75": [7680.0, 2048.0],
57
+ "3.88": [7936.0, 2048.0],
58
+ "4.0": [8192.0, 2048.0],
59
+ }
60
+
61
+ ASPECT_RATIO_2880 = {
62
+ "0.25": [1408.0, 5760.0],
63
+ "0.26": [1408.0, 5568.0],
64
+ "0.27": [1408.0, 5376.0],
65
+ "0.28": [1408.0, 5184.0],
66
+ "0.32": [1600.0, 4992.0],
67
+ "0.33": [1600.0, 4800.0],
68
+ "0.34": [1600.0, 4672.0],
69
+ "0.4": [1792.0, 4480.0],
70
+ "0.42": [1792.0, 4288.0],
71
+ "0.47": [1920.0, 4096.0],
72
+ "0.49": [1920.0, 3904.0],
73
+ "0.51": [1920.0, 3776.0],
74
+ "0.55": [2112.0, 3840.0],
75
+ "0.59": [2112.0, 3584.0],
76
+ "0.68": [2304.0, 3392.0],
77
+ "0.72": [2304.0, 3200.0],
78
+ "0.78": [2496.0, 3200.0],
79
+ "0.83": [2496.0, 3008.0],
80
+ "0.89": [2688.0, 3008.0],
81
+ "0.93": [2688.0, 2880.0],
82
+ "1.0": [2880.0, 2880.0],
83
+ "1.07": [2880.0, 2688.0],
84
+ "1.12": [3008.0, 2688.0],
85
+ "1.21": [3008.0, 2496.0],
86
+ "1.28": [3200.0, 2496.0],
87
+ "1.39": [3200.0, 2304.0],
88
+ "1.47": [3392.0, 2304.0],
89
+ "1.7": [3584.0, 2112.0],
90
+ "1.82": [3840.0, 2112.0],
91
+ "2.03": [3904.0, 1920.0],
92
+ "2.13": [4096.0, 1920.0],
93
+ "2.39": [4288.0, 1792.0],
94
+ "2.5": [4480.0, 1792.0],
95
+ "2.92": [4672.0, 1600.0],
96
+ "3.0": [4800.0, 1600.0],
97
+ "3.12": [4992.0, 1600.0],
98
+ "3.68": [5184.0, 1408.0],
99
+ "3.82": [5376.0, 1408.0],
100
+ "3.95": [5568.0, 1408.0],
101
+ "4.0": [5760.0, 1408.0],
102
+ }
103
+
104
+ ASPECT_RATIO_2048 = {
105
+ "0.25": [1024.0, 4096.0],
106
+ "0.26": [1024.0, 3968.0],
107
+ "0.27": [1024.0, 3840.0],
108
+ "0.28": [1024.0, 3712.0],
109
+ "0.32": [1152.0, 3584.0],
110
+ "0.33": [1152.0, 3456.0],
111
+ "0.35": [1152.0, 3328.0],
112
+ "0.4": [1280.0, 3200.0],
113
+ "0.42": [1280.0, 3072.0],
114
+ "0.48": [1408.0, 2944.0],
115
+ "0.5": [1408.0, 2816.0],
116
+ "0.52": [1408.0, 2688.0],
117
+ "0.57": [1536.0, 2688.0],
118
+ "0.6": [1536.0, 2560.0],
119
+ "0.68": [1664.0, 2432.0],
120
+ "0.72": [1664.0, 2304.0],
121
+ "0.78": [1792.0, 2304.0],
122
+ "0.82": [1792.0, 2176.0],
123
+ "0.88": [1920.0, 2176.0],
124
+ "0.94": [1920.0, 2048.0],
125
+ "1.0": [2048.0, 2048.0],
126
+ "1.07": [2048.0, 1920.0],
127
+ "1.13": [2176.0, 1920.0],
128
+ "1.21": [2176.0, 1792.0],
129
+ "1.29": [2304.0, 1792.0],
130
+ "1.38": [2304.0, 1664.0],
131
+ "1.46": [2432.0, 1664.0],
132
+ "1.67": [2560.0, 1536.0],
133
+ "1.75": [2688.0, 1536.0],
134
+ "2.0": [2816.0, 1408.0],
135
+ "2.09": [2944.0, 1408.0],
136
+ "2.4": [3072.0, 1280.0],
137
+ "2.5": [3200.0, 1280.0],
138
+ "2.89": [3328.0, 1152.0],
139
+ "3.0": [3456.0, 1152.0],
140
+ "3.11": [3584.0, 1152.0],
141
+ "3.62": [3712.0, 1024.0],
142
+ "3.75": [3840.0, 1024.0],
143
+ "3.88": [3968.0, 1024.0],
144
+ "4.0": [4096.0, 1024.0],
145
+ }
146
+
147
+ ASPECT_RATIO_1024 = {
148
+ "0.25": [512.0, 2048.0],
149
+ "0.26": [512.0, 1984.0],
150
+ "0.27": [512.0, 1920.0],
151
+ "0.28": [512.0, 1856.0],
152
+ "0.32": [576.0, 1792.0],
153
+ "0.33": [576.0, 1728.0],
154
+ "0.35": [576.0, 1664.0],
155
+ "0.4": [640.0, 1600.0],
156
+ "0.42": [640.0, 1536.0],
157
+ "0.48": [704.0, 1472.0],
158
+ "0.5": [704.0, 1408.0],
159
+ "0.52": [704.0, 1344.0],
160
+ "0.57": [768.0, 1344.0],
161
+ "0.6": [768.0, 1280.0],
162
+ "0.68": [832.0, 1216.0],
163
+ "0.72": [832.0, 1152.0],
164
+ "0.78": [896.0, 1152.0],
165
+ "0.82": [896.0, 1088.0],
166
+ "0.88": [960.0, 1088.0],
167
+ "0.94": [960.0, 1024.0],
168
+ "1.0": [1024.0, 1024.0],
169
+ "1.07": [1024.0, 960.0],
170
+ "1.13": [1088.0, 960.0],
171
+ "1.21": [1088.0, 896.0],
172
+ "1.29": [1152.0, 896.0],
173
+ "1.38": [1152.0, 832.0],
174
+ "1.46": [1216.0, 832.0],
175
+ "1.67": [1280.0, 768.0],
176
+ "1.75": [1344.0, 768.0],
177
+ "2.0": [1408.0, 704.0],
178
+ "2.09": [1472.0, 704.0],
179
+ "2.4": [1536.0, 640.0],
180
+ "2.5": [1600.0, 640.0],
181
+ "2.89": [1664.0, 576.0],
182
+ "3.0": [1728.0, 576.0],
183
+ "3.11": [1792.0, 576.0],
184
+ "3.62": [1856.0, 512.0],
185
+ "3.75": [1920.0, 512.0],
186
+ "3.88": [1984.0, 512.0],
187
+ "4.0": [2048.0, 512.0],
188
+ }
189
+
190
+ ASPECT_RATIO_512 = {
191
+ "0.25": [256.0, 1024.0],
192
+ "0.26": [256.0, 992.0],
193
+ "0.27": [256.0, 960.0],
194
+ "0.28": [256.0, 928.0],
195
+ "0.32": [288.0, 896.0],
196
+ "0.33": [288.0, 864.0],
197
+ "0.35": [288.0, 832.0],
198
+ "0.4": [320.0, 800.0],
199
+ "0.42": [320.0, 768.0],
200
+ "0.48": [352.0, 736.0],
201
+ "0.5": [352.0, 704.0],
202
+ "0.52": [352.0, 672.0],
203
+ "0.57": [384.0, 672.0],
204
+ "0.6": [384.0, 640.0],
205
+ "0.68": [416.0, 608.0],
206
+ "0.72": [416.0, 576.0],
207
+ "0.78": [448.0, 576.0],
208
+ "0.82": [448.0, 544.0],
209
+ "0.88": [480.0, 544.0],
210
+ "0.94": [480.0, 512.0],
211
+ "1.0": [512.0, 512.0],
212
+ "1.07": [512.0, 480.0],
213
+ "1.13": [544.0, 480.0],
214
+ "1.21": [544.0, 448.0],
215
+ "1.29": [576.0, 448.0],
216
+ "1.38": [576.0, 416.0],
217
+ "1.46": [608.0, 416.0],
218
+ "1.67": [640.0, 384.0],
219
+ "1.75": [672.0, 384.0],
220
+ "2.0": [704.0, 352.0],
221
+ "2.09": [736.0, 352.0],
222
+ "2.4": [768.0, 320.0],
223
+ "2.5": [800.0, 320.0],
224
+ "2.89": [832.0, 288.0],
225
+ "3.0": [864.0, 288.0],
226
+ "3.11": [896.0, 288.0],
227
+ "3.62": [928.0, 256.0],
228
+ "3.75": [960.0, 256.0],
229
+ "3.88": [992.0, 256.0],
230
+ "4.0": [1024.0, 256.0],
231
+ }
232
+
233
+ ASPECT_RATIO_256 = {
234
+ "0.25": [128.0, 512.0],
235
+ "0.26": [128.0, 496.0],
236
+ "0.27": [128.0, 480.0],
237
+ "0.28": [128.0, 464.0],
238
+ "0.32": [144.0, 448.0],
239
+ "0.33": [144.0, 432.0],
240
+ "0.35": [144.0, 416.0],
241
+ "0.4": [160.0, 400.0],
242
+ "0.42": [160.0, 384.0],
243
+ "0.48": [176.0, 368.0],
244
+ "0.5": [176.0, 352.0],
245
+ "0.52": [176.0, 336.0],
246
+ "0.57": [192.0, 336.0],
247
+ "0.6": [192.0, 320.0],
248
+ "0.68": [208.0, 304.0],
249
+ "0.72": [208.0, 288.0],
250
+ "0.78": [224.0, 288.0],
251
+ "0.82": [224.0, 272.0],
252
+ "0.88": [240.0, 272.0],
253
+ "0.94": [240.0, 256.0],
254
+ "1.0": [256.0, 256.0],
255
+ "1.07": [256.0, 240.0],
256
+ "1.13": [272.0, 240.0],
257
+ "1.21": [272.0, 224.0],
258
+ "1.29": [288.0, 224.0],
259
+ "1.38": [288.0, 208.0],
260
+ "1.46": [304.0, 208.0],
261
+ "1.67": [320.0, 192.0],
262
+ "1.75": [336.0, 192.0],
263
+ "2.0": [352.0, 176.0],
264
+ "2.09": [368.0, 176.0],
265
+ "2.4": [384.0, 160.0],
266
+ "2.5": [400.0, 160.0],
267
+ "2.89": [416.0, 144.0],
268
+ "3.0": [432.0, 144.0],
269
+ "3.11": [448.0, 144.0],
270
+ "3.62": [464.0, 128.0],
271
+ "3.75": [480.0, 128.0],
272
+ "3.88": [496.0, 128.0],
273
+ "4.0": [512.0, 128.0],
274
+ }
275
+
276
+ ASPECT_RATIO_256_TEST = {
277
+ "0.25": [128.0, 512.0],
278
+ "0.28": [128.0, 464.0],
279
+ "0.32": [144.0, 448.0],
280
+ "0.33": [144.0, 432.0],
281
+ "0.35": [144.0, 416.0],
282
+ "0.4": [160.0, 400.0],
283
+ "0.42": [160.0, 384.0],
284
+ "0.48": [176.0, 368.0],
285
+ "0.5": [176.0, 352.0],
286
+ "0.52": [176.0, 336.0],
287
+ "0.57": [192.0, 336.0],
288
+ "0.6": [192.0, 320.0],
289
+ "0.68": [208.0, 304.0],
290
+ "0.72": [208.0, 288.0],
291
+ "0.78": [224.0, 288.0],
292
+ "0.82": [224.0, 272.0],
293
+ "0.88": [240.0, 272.0],
294
+ "0.94": [240.0, 256.0],
295
+ "1.0": [256.0, 256.0],
296
+ "1.07": [256.0, 240.0],
297
+ "1.13": [272.0, 240.0],
298
+ "1.21": [272.0, 224.0],
299
+ "1.29": [288.0, 224.0],
300
+ "1.38": [288.0, 208.0],
301
+ "1.46": [304.0, 208.0],
302
+ "1.67": [320.0, 192.0],
303
+ "1.75": [336.0, 192.0],
304
+ "2.0": [352.0, 176.0],
305
+ "2.09": [368.0, 176.0],
306
+ "2.4": [384.0, 160.0],
307
+ "2.5": [400.0, 160.0],
308
+ "3.0": [432.0, 144.0],
309
+ "4.0": [512.0, 128.0],
310
+ }
311
+
312
+ ASPECT_RATIO_512_TEST = {
313
+ "0.25": [256.0, 1024.0],
314
+ "0.28": [256.0, 928.0],
315
+ "0.32": [288.0, 896.0],
316
+ "0.33": [288.0, 864.0],
317
+ "0.35": [288.0, 832.0],
318
+ "0.4": [320.0, 800.0],
319
+ "0.42": [320.0, 768.0],
320
+ "0.48": [352.0, 736.0],
321
+ "0.5": [352.0, 704.0],
322
+ "0.52": [352.0, 672.0],
323
+ "0.57": [384.0, 672.0],
324
+ "0.6": [384.0, 640.0],
325
+ "0.68": [416.0, 608.0],
326
+ "0.72": [416.0, 576.0],
327
+ "0.78": [448.0, 576.0],
328
+ "0.82": [448.0, 544.0],
329
+ "0.88": [480.0, 544.0],
330
+ "0.94": [480.0, 512.0],
331
+ "1.0": [512.0, 512.0],
332
+ "1.07": [512.0, 480.0],
333
+ "1.13": [544.0, 480.0],
334
+ "1.21": [544.0, 448.0],
335
+ "1.29": [576.0, 448.0],
336
+ "1.38": [576.0, 416.0],
337
+ "1.46": [608.0, 416.0],
338
+ "1.67": [640.0, 384.0],
339
+ "1.75": [672.0, 384.0],
340
+ "2.0": [704.0, 352.0],
341
+ "2.09": [736.0, 352.0],
342
+ "2.4": [768.0, 320.0],
343
+ "2.5": [800.0, 320.0],
344
+ "3.0": [864.0, 288.0],
345
+ "4.0": [1024.0, 256.0],
346
+ }
347
+
348
+ ASPECT_RATIO_1024_TEST = {
349
+ "0.25": [512.0, 2048.0],
350
+ "0.28": [512.0, 1856.0],
351
+ "0.32": [576.0, 1792.0],
352
+ "0.33": [576.0, 1728.0],
353
+ "0.35": [576.0, 1664.0],
354
+ "0.4": [640.0, 1600.0],
355
+ "0.42": [640.0, 1536.0],
356
+ "0.48": [704.0, 1472.0],
357
+ "0.5": [704.0, 1408.0],
358
+ "0.52": [704.0, 1344.0],
359
+ "0.57": [768.0, 1344.0],
360
+ "0.6": [768.0, 1280.0],
361
+ "0.68": [832.0, 1216.0],
362
+ "0.72": [832.0, 1152.0],
363
+ "0.78": [896.0, 1152.0],
364
+ "0.82": [896.0, 1088.0],
365
+ "0.88": [960.0, 1088.0],
366
+ "0.94": [960.0, 1024.0],
367
+ "1.0": [1024.0, 1024.0],
368
+ "1.07": [1024.0, 960.0],
369
+ "1.13": [1088.0, 960.0],
370
+ "1.21": [1088.0, 896.0],
371
+ "1.29": [1152.0, 896.0],
372
+ "1.38": [1152.0, 832.0],
373
+ "1.46": [1216.0, 832.0],
374
+ "1.67": [1280.0, 768.0],
375
+ "1.75": [1344.0, 768.0],
376
+ "2.0": [1408.0, 704.0],
377
+ "2.09": [1472.0, 704.0],
378
+ "2.4": [1536.0, 640.0],
379
+ "2.5": [1600.0, 640.0],
380
+ "3.0": [1728.0, 576.0],
381
+ "4.0": [2048.0, 512.0],
382
+ }
383
+
384
+ ASPECT_RATIO_2048_TEST = {
385
+ "0.25": [1024.0, 4096.0],
386
+ "0.26": [1024.0, 3968.0],
387
+ "0.32": [1152.0, 3584.0],
388
+ "0.33": [1152.0, 3456.0],
389
+ "0.35": [1152.0, 3328.0],
390
+ "0.4": [1280.0, 3200.0],
391
+ "0.42": [1280.0, 3072.0],
392
+ "0.48": [1408.0, 2944.0],
393
+ "0.5": [1408.0, 2816.0],
394
+ "0.52": [1408.0, 2688.0],
395
+ "0.57": [1536.0, 2688.0],
396
+ "0.6": [1536.0, 2560.0],
397
+ "0.68": [1664.0, 2432.0],
398
+ "0.72": [1664.0, 2304.0],
399
+ "0.78": [1792.0, 2304.0],
400
+ "0.82": [1792.0, 2176.0],
401
+ "0.88": [1920.0, 2176.0],
402
+ "0.94": [1920.0, 2048.0],
403
+ "1.0": [2048.0, 2048.0],
404
+ "1.07": [2048.0, 1920.0],
405
+ "1.13": [2176.0, 1920.0],
406
+ "1.21": [2176.0, 1792.0],
407
+ "1.29": [2304.0, 1792.0],
408
+ "1.38": [2304.0, 1664.0],
409
+ "1.46": [2432.0, 1664.0],
410
+ "1.67": [2560.0, 1536.0],
411
+ "1.75": [2688.0, 1536.0],
412
+ "2.0": [2816.0, 1408.0],
413
+ "2.09": [2944.0, 1408.0],
414
+ "2.4": [3072.0, 1280.0],
415
+ "2.5": [3200.0, 1280.0],
416
+ "3.0": [3456.0, 1152.0],
417
+ "4.0": [4096.0, 1024.0],
418
+ }
419
+
420
+ ASPECT_RATIO_2880_TEST = {
421
+ "0.25": [2048.0, 8192.0],
422
+ "0.26": [2048.0, 7936.0],
423
+ "0.32": [2304.0, 7168.0],
424
+ "0.33": [2304.0, 6912.0],
425
+ "0.35": [2304.0, 6656.0],
426
+ "0.4": [2560.0, 6400.0],
427
+ "0.42": [2560.0, 6144.0],
428
+ "0.48": [2816.0, 5888.0],
429
+ "0.5": [2816.0, 5632.0],
430
+ "0.52": [2816.0, 5376.0],
431
+ "0.57": [3072.0, 5376.0],
432
+ "0.6": [3072.0, 5120.0],
433
+ "0.68": [3328.0, 4864.0],
434
+ "0.72": [3328.0, 4608.0],
435
+ "0.78": [3584.0, 4608.0],
436
+ "0.82": [3584.0, 4352.0],
437
+ "0.88": [3840.0, 4352.0],
438
+ "0.94": [3840.0, 4096.0],
439
+ "1.0": [4096.0, 4096.0],
440
+ "1.07": [4096.0, 3840.0],
441
+ "1.13": [4352.0, 3840.0],
442
+ "1.21": [4352.0, 3584.0],
443
+ "1.29": [4608.0, 3584.0],
444
+ "1.38": [4608.0, 3328.0],
445
+ "1.46": [4864.0, 3328.0],
446
+ "1.67": [5120.0, 3072.0],
447
+ "1.75": [5376.0, 3072.0],
448
+ "2.0": [5632.0, 2816.0],
449
+ "2.09": [5888.0, 2816.0],
450
+ "2.4": [6144.0, 2560.0],
451
+ "2.5": [6400.0, 2560.0],
452
+ "3.0": [6912.0, 2304.0],
453
+ "4.0": [8192.0, 2048.0],
454
+ }
455
+
456
+ ASPECT_RATIO_4096_TEST = {
457
+ "0.25": [2048.0, 8192.0],
458
+ "0.26": [2048.0, 7936.0],
459
+ "0.27": [2048.0, 7680.0],
460
+ "0.28": [2048.0, 7424.0],
461
+ "0.32": [2304.0, 7168.0],
462
+ "0.33": [2304.0, 6912.0],
463
+ "0.35": [2304.0, 6656.0],
464
+ "0.4": [2560.0, 6400.0],
465
+ "0.42": [2560.0, 6144.0],
466
+ "0.48": [2816.0, 5888.0],
467
+ "0.5": [2816.0, 5632.0],
468
+ "0.52": [2816.0, 5376.0],
469
+ "0.57": [3072.0, 5376.0],
470
+ "0.6": [3072.0, 5120.0],
471
+ "0.68": [3328.0, 4864.0],
472
+ "0.72": [3328.0, 4608.0],
473
+ "0.78": [3584.0, 4608.0],
474
+ "0.82": [3584.0, 4352.0],
475
+ "0.88": [3840.0, 4352.0],
476
+ "0.94": [3840.0, 4096.0],
477
+ "1.0": [4096.0, 4096.0],
478
+ "1.07": [4096.0, 3840.0],
479
+ "1.13": [4352.0, 3840.0],
480
+ "1.21": [4352.0, 3584.0],
481
+ "1.29": [4608.0, 3584.0],
482
+ "1.38": [4608.0, 3328.0],
483
+ "1.46": [4864.0, 3328.0],
484
+ "1.67": [5120.0, 3072.0],
485
+ "1.75": [5376.0, 3072.0],
486
+ "2.0": [5632.0, 2816.0],
487
+ "2.09": [5888.0, 2816.0],
488
+ "2.4": [6144.0, 2560.0],
489
+ "2.5": [6400.0, 2560.0],
490
+ "2.89": [6656.0, 2304.0],
491
+ "3.0": [6912.0, 2304.0],
492
+ "3.11": [7168.0, 2304.0],
493
+ "3.62": [7424.0, 2048.0],
494
+ "3.75": [7680.0, 2048.0],
495
+ "3.88": [7936.0, 2048.0],
496
+ "4.0": [8192.0, 2048.0],
497
+ }
498
+
499
+ ASPECT_RATIO_1280_TEST = {"1.0": [1280.0, 1280.0]}
500
+ ASPECT_RATIO_1536_TEST = {"1.0": [1536.0, 1536.0]}
501
+ ASPECT_RATIO_768_TEST = {"1.0": [768.0, 768.0]}
502
+
503
+
504
+ def get_chunks(lst, n):
505
+ for i in range(0, len(lst), n):
506
+ yield lst[i : i + n]
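A short usage sketch for get_chunks, which splits a list into fixed-size batches; the prompt list here is illustrative:

def get_chunks(lst, n):  # copied from above so the example is self-contained
    for i in range(0, len(lst), n):
        yield lst[i : i + n]

prompts = ["a red apple", "a wooden chair", "a city at night", "a snowy mountain", "a sailboat"]
print(list(get_chunks(prompts, 2)))
# [['a red apple', 'a wooden chair'], ['a city at night', 'a snowy mountain'], ['a sailboat']]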
diffusion/data/transforms.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torchvision.transforms as T
18
+
19
+ TRANSFORMS = dict()
20
+
21
+
22
+ def register_transform(transform):
23
+ name = transform.__name__
24
+ if name in TRANSFORMS:
25
+ raise RuntimeError(f"Transform {name} has already registered.")
26
+ TRANSFORMS.update({name: transform})
27
+
28
+
29
+ def get_transform(type, resolution):
30
+ transform = TRANSFORMS[type](resolution)
31
+ transform = T.Compose(transform)
32
+ transform.image_size = resolution
33
+ return transform
34
+
35
+
36
+ @register_transform
37
+ def default_train(n_px):
38
+ transform = [
39
+ T.Lambda(lambda img: img.convert("RGB")),
40
+ T.Resize(n_px), # Image.BICUBIC
41
+ T.CenterCrop(n_px),
42
+ # T.RandomHorizontalFlip(),
43
+ T.ToTensor(),
44
+ T.Normalize([0.5], [0.5]),
45
+ ]
46
+ return transform
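A short sketch of how the registry above can be used, assuming the package is importable; the flipped variant registered here is hypothetical and not part of this commit:

import torchvision.transforms as T

from diffusion.data.transforms import get_transform, register_transform

@register_transform
def default_train_hflip(n_px):  # hypothetical variant that adds horizontal flips
    return [
        T.Lambda(lambda img: img.convert("RGB")),
        T.Resize(n_px),
        T.CenterCrop(n_px),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ]

# get_transform looks the name up in TRANSFORMS and wraps the result in T.Compose.
transform = get_transform("default_train_hflip", 512)
print(transform.image_size)  # 512

Note that register_transform stores the function in TRANSFORMS but returns None, so the decorated module-level name itself becomes None; transforms should always be fetched through get_transform.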
diffusion/data/wids/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
2
+ # This file is part of the WebDataset library.
3
+ # See the LICENSE file for licensing terms (BSD-style).
4
+ #
5
+ # flake8: noqa
6
+
7
+ from .wids import (
8
+ ChunkedSampler,
9
+ DistributedChunkedSampler,
10
+ DistributedLocalSampler,
11
+ DistributedRangedSampler,
12
+ ShardedSampler,
13
+ ShardListDataset,
14
+ ShardListDatasetMulti,
15
+ lru_json_load,
16
+ )
diffusion/data/wids/wids.py ADDED
@@ -0,0 +1,1051 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/NVlabs/VILA/tree/main/llava/wids
18
+ import base64
19
+ import gzip
20
+ import hashlib
21
+ import io
22
+ import json
23
+ import math
24
+ import os
25
+ import os.path as osp
26
+ import random
27
+ import re
28
+ import sqlite3
29
+ import sys
30
+ import tempfile
31
+ import uuid
32
+ import warnings
33
+ from functools import lru_cache, partial
34
+ from typing import Any, BinaryIO, Dict, Optional, TypeVar, Union
35
+ from urllib.parse import quote, urlparse
36
+
37
+ import numpy as np
38
+ import torch
39
+ import torch.distributed as dist
40
+ from torch.utils.data.distributed import DistributedSampler
41
+
42
+ from .wids_dl import download_and_open
43
+ from .wids_lru import LRUCache
44
+ from .wids_mmtar import MMIndexedTar
45
+ from .wids_specs import load_dsdesc_and_resolve, urldir
46
+ from .wids_tar import TarFileReader, find_index_file
47
+
48
+ try:
49
+ from torch.utils.data import Dataset, Sampler
50
+ except ImportError:
51
+
52
+ class Dataset:
53
+ pass
54
+
55
+ class Sampler:
56
+ pass
57
+
58
+
59
+ T = TypeVar("T")
60
+
61
+ T_co = TypeVar("T_co", covariant=True)
62
+
63
+
64
+ def compute_file_md5sum(fname: Union[str, BinaryIO], chunksize: int = 1000000) -> str:
65
+ """Compute the md5sum of a file in chunks.
66
+
67
+ Parameters
68
+ ----------
69
+ fname : Union[str, BinaryIO]
70
+ Filename or file object
71
+ chunksize : int, optional
72
+ Chunk size in bytes, by default 1000000
73
+
74
+ Returns
75
+ -------
76
+ str
77
+ MD5 sum of the file
78
+
79
+ Examples
80
+ --------
81
+ >>> compute_file_md5sum("test.txt")
82
+ 'd41d8cd98f00b204e9800998ecf8427e'
83
+ """
84
+ md5 = hashlib.md5()
85
+ if isinstance(fname, str):
86
+ with open(fname, "rb") as f:
87
+ for chunk in iter(lambda: f.read(chunksize), b""):
88
+ md5.update(chunk)
89
+ else:
90
+ fname.seek(0)
91
+ for chunk in iter(lambda: fname.read(chunksize), b""):
92
+ md5.update(chunk)
93
+ return md5.hexdigest()
94
+
95
+
96
+ def compute_file_md5sum(fname: Union[str, BinaryIO], chunksize: int = 1000000) -> str:
97
+ """Compute the md5sum of a file in chunks."""
98
+ md5 = hashlib.md5()
99
+ if isinstance(fname, str):
100
+ with open(fname, "rb") as f:
101
+ for chunk in iter(lambda: f.read(chunksize), b""):
102
+ md5.update(chunk)
103
+ else:
104
+ fname.seek(0)
105
+ for chunk in iter(lambda: fname.read(chunksize), b""):
106
+ md5.update(chunk)
107
+ return md5.hexdigest()
108
+
109
+
110
+ def compute_num_samples(fname):
111
+ ds = IndexedTarSamples(fname)
112
+ return len(ds)
113
+
114
+
115
+ def splitname(fname):
116
+ """Returns the basename and extension of a filename"""
117
+ assert "." in fname, "Filename must have an extension"
118
+ # basename, extension = re.match(r"^((?:.*/)?.*?)(\..*)$", fname).groups()
119
+ basename, extension = os.path.splitext(fname)
120
+ return basename, extension
121
+
122
+
123
+ # NOTE(ligeng): changed to an ordered mapping for a more flexible dict
124
+ # TODO(ligeng): submit a PR to fix the mapping issue.
125
+ def group_by_key(names):
126
+ """Group the file names by key.
127
+
128
+ Args:
129
+ names: A list of file names.
130
+
131
+ Returns:
132
+ A list of lists of indices, where each sublist contains indices of files
133
+ with the same key.
134
+ """
135
+ groups = []
136
+ kmaps = {}
137
+ for i, fname in enumerate(names):
138
+ # Ignore files without an extension.
139
+ if "." not in fname:
140
+ print(f"Warning: Ignoring file {fname} (no '.')")
141
+ continue
142
+ if fname == ".":
143
+ print(f"Warning: Ignoring the '.' file.")
144
+ continue
145
+ key, ext = splitname(fname)
146
+ if key not in kmaps:
147
+ kmaps[key] = []
148
+ kmaps[key].append(i)
149
+ for k, v in kmaps.items():
150
+ groups.append(v)
151
+ return groups
152
+
153
+
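A tiny worked example of how group_by_key (together with splitname) turns tar member names into per-sample index groups; the file names are illustrative, and the import assumes the module is reachable as diffusion.data.wids.wids:

from diffusion.data.wids.wids import group_by_key

names = ["sample1.jpg", "sample1.json", "sample2.jpg", "sample2.json"]
print(group_by_key(names))  # [[0, 1], [2, 3]] -> indices of the files belonging to each key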
154
+ def default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True):
155
+ """A default decoder for webdataset.
156
+
157
+ This handles common file extensions: .txt, .cls, .cls2,
158
+ .jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl.
159
+ These are the most common extensions used in webdataset.
160
+ For other extensions, users can provide their own decoder.
161
+
162
+ Args:
163
+ sample: sample, modified in place
164
+ """
165
+ sample = dict(sample)
166
+ for key, stream in sample.items():
167
+ extensions = key.split(".")
168
+ if len(extensions) < 1:
169
+ continue
170
+ extension = extensions[-1]
171
+ if extension in ["gz"]:
172
+ decompressed = gzip.decompress(stream.read())
173
+ stream = io.BytesIO(decompressed)
174
+ if len(extensions) < 2:
175
+ sample[key] = stream
176
+ continue
177
+ extension = extensions[-2]
178
+ if key.startswith("__"):
179
+ continue
180
+ elif extension in ["txt", "text"]:
181
+ value = stream.read()
182
+ sample[key] = value.decode("utf-8")
183
+ elif extension in ["cls", "cls2"]:
184
+ value = stream.read()
185
+ sample[key] = int(value.decode("utf-8"))
186
+ elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]:
187
+ if format == "PIL":
188
+ import PIL.Image
189
+
190
+ sample[key] = PIL.Image.open(stream)
191
+ elif format == "numpy":
192
+ import numpy as np
193
+
194
+ sample[key] = np.asarray(PIL.Image.open(stream))
195
+ else:
196
+ raise ValueError(f"Unknown format: {format}")
197
+ elif extension == "json":
198
+ import json
199
+
200
+ value = stream.read()
201
+ sample[key] = json.loads(value)
202
+ elif extension == "npy":
203
+ import numpy as np
204
+
205
+ sample[key] = np.load(stream)
206
+ elif extension == "mp":
207
+ import msgpack
208
+
209
+ value = stream.read()
210
+ sample[key] = msgpack.unpackb(value, raw=False)
211
+ elif extension in ["pt", "pth"]:
212
+ import torch
213
+
214
+ sample[key] = torch.load(stream)
215
+ elif extension in ["pickle", "pkl"]:
216
+ import pickle
217
+
218
+ sample[key] = pickle.load(stream)
219
+ elif extension == "mp4":
220
+ # Write stream to a temporary file
221
+ # with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmpfile:
222
+ # tmpfile.write(stream.read())
223
+ # tmpfile_path = tmpfile.name
224
+
225
+ # sample[key] = tmpfile_path
226
+ sample[key] = io.BytesIO(stream.read())
227
+ return sample
228
+
229
+
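A small sketch of what default_decoder does to a raw sample whose values are byte streams; the keys follow the WebDataset convention used throughout this file, and the import path assumes the module is reachable as diffusion.data.wids.wids:

import io
import json

from diffusion.data.wids.wids import default_decoder

raw = {
    "__key__": "sample1",  # keys starting with "__" are passed through untouched
    ".txt": io.BytesIO(b"a photo of a cat"),
    ".json": io.BytesIO(json.dumps({"height": 512, "width": 512}).encode()),
}
decoded = default_decoder(raw, format="PIL")
print(decoded[".txt"])   # 'a photo of a cat'
print(decoded[".json"])  # {'height': 512, 'width': 512}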
230
+ def update_dict_with_extend(original_dict, update_dict):
231
+ for key, value in update_dict.items():
232
+ if key in original_dict and isinstance(original_dict[key], list) and isinstance(value, list):
233
+ original_dict[key].extend(value)
234
+ else:
235
+ original_dict[key] = value
236
+
237
+
238
+ open_itfs = {}
239
+
240
+
241
+ class IndexedTarSamples:
242
+ """A class that accesses samples in a tar file. The tar file must follow
243
+ WebDataset conventions. The tar file is indexed when the IndexedTarSamples
244
+ object is created. The samples are accessed by index using the __getitem__
245
+ method. The __getitem__ method returns a dictionary containing the files
246
+ for the sample. The key for each file is the extension of the file name.
247
+ The key "__key__" is reserved for the key of the sample (the basename of
248
+ each file without the extension). For example, if the tar file contains
249
+ the files "sample1.jpg" and "sample1.txt", then the sample with key
250
+ "sample1" will be returned as the dictionary {"jpg": ..., "txt": ...}.
251
+ """
252
+
253
+ def __init__(
254
+ self,
255
+ *,
256
+ path=None,
257
+ stream=None,
258
+ md5sum=None,
259
+ expected_size=None,
260
+ use_mmap=True,
261
+ index_file=find_index_file,
262
+ ):
263
+ assert path is not None or stream is not None
264
+
265
+ # Create TarFileReader object to read from tar_file
266
+ self.path = path
267
+ stream = self.stream = stream or open(path, "rb")
268
+
269
+ # verify the MD5 sum
270
+ if md5sum is not None:
271
+ stream.seek(0)
272
+ got = compute_file_md5sum(stream)
273
+ assert got == md5sum, f"MD5 sum mismatch: expected {md5sum}, got {got}"
274
+ stream.seek(0)
275
+
276
+ # use either the mmap or the stream based implementation
277
+ # NOTE(ligeng): https://stackoverflow.com/questions/11072705/twitter-trends-api-unicodedecodeerror-utf8-codec-cant-decode-byte-0x8b-in-po
278
+ # import gzip
279
+ # print("convert to gzip IO stream")
280
+ # stream = gzip.GzipFile(fileobj=stream)
281
+
282
+ if use_mmap:
283
+ self.reader = MMIndexedTar(stream)
284
+ else:
285
+ self.reader = TarFileReader(stream, index_file=index_file)
286
+
287
+ # Get list of all files in stream
288
+ all_files = self.reader.names()
289
+
290
+ # Group files by key into samples
291
+ self.samples = group_by_key(all_files)
292
+ # print("DEBUG:", list(all_files)[:20])
293
+ # print("DEBUG:", self.samples[:20])
294
+
295
+ # check that the number of samples is correct
296
+ if expected_size is not None:
297
+ assert len(self) == expected_size, f"Expected {expected_size} samples, got {len(self)}"
298
+
299
+ self.uuid = str(uuid.uuid4())
300
+
301
+ def close(self):
302
+ self.reader.close()
303
+ if not self.stream.closed:
304
+ self.stream.close()
305
+
306
+ def __len__(self):
307
+ return len(self.samples)
308
+
309
+ def __getitem__(self, idx):
310
+ # Get indexes of files for the sample at index idx
311
+ try:
312
+ indexes = self.samples[idx]
313
+ except IndexError as e:
314
+ print(f"[wids-debug] curr idx: {idx}, total sample length: {len(self.samples)} {e}")
315
+ raise e
316
+ sample = {}
317
+ key = None
318
+ for i in indexes:
319
+ # Get filename and data for the file at index i
320
+ fname, data = self.reader.get_file(i)
321
+ # Split filename into key and extension
322
+ k, ext = splitname(fname)
323
+ # Make sure all files in sample have same key
324
+ key = key or k
325
+ assert key == k
326
+ sample[ext] = data
327
+ # Add key to sample
328
+ sample["__key__"] = key
329
+ return sample
330
+
331
+ def __str__(self):
332
+ return f"<IndexedTarSamples-{id(self)} {self.path}>"
333
+
334
+ def __repr__(self):
335
+ return str(self)
336
+
337
+
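A brief usage sketch for IndexedTarSamples; the shard path is hypothetical, and any WebDataset-style tar whose members share basenames (e.g. 000.jpg / 000.json) will do:

from diffusion.data.wids.wids import IndexedTarSamples

its = IndexedTarSamples(path="shards/00000000.tar")  # hypothetical local shard
print(len(its))           # number of grouped samples in the shard
sample = its[0]
print(sample["__key__"])  # basename shared by the files of this sample
print([k for k in sample if not k.startswith("__")])  # e.g. ['.jpg', '.json']; values are raw bytes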
338
+ def hash_localname(dldir="/tmp/_wids_cache"):
339
+ os.makedirs(dldir, exist_ok=True)
340
+
341
+ connection = sqlite3.connect(os.path.join(dldir, "cache.db"))
342
+ cursor = connection.cursor()
343
+ cursor.execute("CREATE TABLE IF NOT EXISTS cache (url TEXT PRIMARY KEY, path TEXT, checksum TEXT)")
344
+ connection.commit()
345
+
346
+ def f(shard):
347
+ """Given a URL, return a local name for the shard."""
348
+ if shard.startswith("pipe:"):
349
+ # uuencode the entire URL string
350
+ hex32 = base64.urlsafe_b64encode(hashlib.sha256(shard.encode()).digest())[:32].decode()
351
+ return os.path.join(dldir, "pipe__" + hex32)
352
+ else:
353
+ # we hash the host and directory components into a 16 character string
354
+ dirname = urldir(shard)
355
+ hex16 = base64.urlsafe_b64encode(hashlib.sha256(dirname.encode()).digest())[:16].decode()
356
+ # the cache name is the concatenation of the hex16 string and the file name component of the URL
357
+ cachename = "data__" + hex16 + "__" + os.path.basename(urlparse(shard).path)
358
+ checksum = None
359
+ cursor.execute(
360
+ "INSERT OR REPLACE INTO cache VALUES (?, ?, ?)",
361
+ (shard, cachename, checksum),
362
+ )
363
+ connection.commit()
364
+ return os.path.join(dldir, cachename)
365
+
366
+ return f
367
+
368
+
369
+ def cache_localname(cachedir):
370
+ os.makedirs(cachedir, exist_ok=True)
371
+
372
+ def f(shard):
373
+ """Given a URL, return a local name for the shard."""
374
+ path = urlparse(shard).path
375
+ fname = os.path.basename(path)
376
+ return os.path.join(cachedir, fname)
377
+
378
+ return f
379
+
380
+
381
+ def default_localname(dldir="/tmp/_wids_cache"):
382
+ os.makedirs(dldir, exist_ok=True)
383
+
384
+ def f(shard):
385
+ """Given a URL, return a local name for the shard."""
386
+ cachename = quote(shard, safe="+-")
387
+ return os.path.join(dldir, cachename)
388
+
389
+ return f
390
+
391
+
392
+ class LRUShards:
393
+ """A class that manages a cache of shards. The cache is a LRU cache that
394
+ stores the local names of the shards as keys and the downloaded paths as
395
+ values. The shards are downloaded to a directory specified by dldir.
396
+ The local name of a shard is computed by the localname function, which
397
+ takes the shard URL as an argument. If keep is True, the downloaded files
398
+ are not deleted when they are no longer needed.
399
+ """
400
+
401
+ def __init__(self, lru_size, keep=False, localname=default_localname()):
402
+ self.localname = localname
403
+ # the cache contains the local name as the key and the downloaded path as the value
404
+ self.lru = LRUCache(lru_size, release_handler=self.release_handler)
405
+ # keep statistics
406
+ self.reset_stats()
407
+
408
+ def reset_stats(self):
409
+ self.accesses = 0
410
+ self.misses = 0
411
+
412
+ def __len__(self):
413
+ return len(self.lru)
414
+
415
+ def release_handler(self, key, value):
416
+ value.close()
417
+
418
+ def clear(self):
419
+ self.lru.clear()
420
+
421
+ def get_shard(self, url):
422
+ assert isinstance(url, str)
423
+ self.accesses += 1
424
+ if url not in self.lru:
425
+ local = self.localname(url)
426
+ with download_and_open(url, local) as stream:
427
+ itf = IndexedTarSamples(path=local, stream=stream)
428
+ self.lru[url] = itf
429
+ self.misses += 1
430
+ self.last_missed = True
431
+ else:
432
+ self.last_missed = False
433
+ return self.lru[url]
434
+
435
+
436
+ def interpret_transformations(transformations):
437
+ """Interpret the transformations argument.
438
+
439
+ This takes care of transformations specified as string shortcuts
440
+ and returns a list of callables.
441
+ """
442
+ if not isinstance(transformations, list):
443
+ transformations = [transformations]
444
+
445
+ result = []
446
+
447
+ for transformation in transformations:
448
+ if transformation == "PIL":
449
+ transformation = partial(default_decoder, format="PIL")
450
+ elif transformation == "numpy":
451
+ transformation = partial(default_decoder, format="numpy")
452
+ else:
453
+ assert callable(transformation)
454
+ result.append(transformation)
455
+
456
+ return result
457
+
458
+
459
+ def hash_dataset_name(input_string):
460
+ """Compute a hash of the input string and return the first 16 characters of the hash."""
461
+ # Compute SHA256 hash of the input string
462
+ hash_object = hashlib.sha256(input_string.encode())
463
+ hash_digest = hash_object.digest()
464
+
465
+ # Encode the hash in base64
466
+ base64_encoded_hash = base64.urlsafe_b64encode(hash_digest)
467
+
468
+ # Return the first 16 characters of the base64-encoded hash
469
+ return base64_encoded_hash[:16].decode("ascii")
470
+
471
+
472
+ @lru_cache(maxsize=16)
473
+ def lru_json_load(fpath):
474
+ with open(fpath) as fp:
475
+ return json.load(fp)
476
+
477
+
478
+ class ShardListDataset(Dataset[T]):
479
+ """An indexable dataset based on a list of shards.
480
+
481
+ The dataset is either given as a list of shards with optional options and name,
482
+ or as a URL pointing to a JSON descriptor file.
483
+
484
+ Datasets can reference other datasets via `source_url`.
485
+
486
+ Shard references within a dataset are resolve relative to an explicitly
487
+ given `base` property, or relative to the URL from which the dataset
488
+ descriptor was loaded.
489
+ """
490
+
491
+ def __init__(
492
+ self,
493
+ shards,
494
+ *,
495
+ cache_size=int(1e12),
496
+ cache_dir=None,
497
+ lru_size=10,
498
+ dataset_name=None,
499
+ localname=None,
500
+ transformations="PIL",
501
+ keep=False,
502
+ base=None,
503
+ options=None,
504
+ ):
505
+ """Create a ShardListDataset.
506
+
507
+ Args:
508
+ shards: a list of (filename, length) pairs or a URL pointing to a JSON descriptor file
509
+ cache_size: the number of shards to keep in the cache
510
+ lru_size: the number of shards to keep in the LRU cache
511
+ localname: a function that maps URLs to local filenames
512
+
513
+ Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
514
+ """
515
+ if options is None:
516
+ options = {}
517
+ super().__init__()
518
+ # shards is a list of (filename, length) pairs. We'll need to
519
+ # keep track of the lengths and cumulative lengths to know how
520
+ # to map indices to shards and indices within shards.
521
+ if isinstance(shards, (str, io.IOBase)):
522
+ if base is None and isinstance(shards, str):
523
+ shards = osp.expanduser(shards)
524
+ base = urldir(shards)
525
+ self.base = base
526
+ self.spec = load_dsdesc_and_resolve(shards, options=options, base=base)
527
+ self.shards = self.spec.get("shardlist", [])
528
+ self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards))
529
+ else:
530
+ raise NotImplementedError("Only support taking path/url to JSON descriptor file.")
531
+ self.base = None
532
+ self.spec = options
533
+ self.shards = shards
534
+ self.dataset_name = dataset_name or hash_dataset_name(str(shards))
535
+
536
+ self.lengths = [shard["nsamples"] for shard in self.shards]
537
+ self.cum_lengths = np.cumsum(self.lengths)
538
+ self.total_length = self.cum_lengths[-1]
539
+
540
+ if cache_dir is not None:
541
+ # when a cache dir is explicitly given, we download files into
542
+ # that directory without any changes
543
+ self.cache_dir = cache_dir
544
+ self.localname = cache_localname(cache_dir)
545
+ elif localname is not None:
546
+ # when a localname function is given, we use that
547
+ self.cache_dir = None
548
+ self.localname = localname
549
+ else:
550
+ import getpass
551
+
552
+ # when no cache dir or localname are given, use the cache from the environment
553
+ self.cache_dir = os.environ.get("WIDS_CACHE", f"~/.cache/_wids_cache")
554
+ self.cache_dir = osp.expanduser(self.cache_dir)
555
+ self.localname = default_localname(self.cache_dir)
556
+
557
+ self.data_info = (
558
+ f"[WebShardedList] {str(shards)}, base: {self.base,}, name: {self.spec.get('name')}, "
559
+ f"nfiles: {str(len(self.shards))}"
560
+ )
561
+ if True or int(os.environ.get("WIDS_VERBOSE", 0)):
562
+ nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
563
+ nsamples = sum(shard["nsamples"] for shard in self.shards)
564
+ self.data_info += f"nbytes: {str(nbytes)}, samples: {str(nsamples),}, cache: {self.cache_dir} "
565
+ # print(
566
+ # "[WebShardedList]",
567
+ # str(shards),
568
+ # "base:",
569
+ # self.base,
570
+ # "name:",
571
+ # self.spec.get("name"),
572
+ # "nfiles:",
573
+ # len(self.shards),
574
+ # "nbytes:",
575
+ # nbytes,
576
+ # "samples:",
577
+ # nsamples,
578
+ # "cache:",
579
+ # self.cache_dir,
580
+ # file=sys.stderr,
581
+ # )
582
+ self.transformations = interpret_transformations(transformations)
583
+
584
+ if lru_size > 200:
585
+ warnings.warn("LRU size is very large; consider reducing it to avoid running out of file descriptors")
586
+ self.cache = LRUShards(lru_size, localname=self.localname, keep=keep)
587
+
588
+ def add_transform(self, transform):
589
+ """Add a transformation to the dataset."""
590
+ self.transformations.append(transform)
591
+ return self
592
+
593
+ def __len__(self):
594
+ """Return the total number of samples in the dataset."""
595
+ return self.total_length
596
+
597
+ def get_stats(self):
598
+ """Return the number of cache accesses and misses."""
599
+ return self.cache.accesses, self.cache.misses
600
+
601
+ def check_cache_misses(self):
602
+ """Check if the cache miss rate is too high."""
603
+ accesses, misses = self.get_stats()
604
+ if accesses > 100 and misses / accesses > 0.3:
605
+ # output a warning only once
606
+ self.check_cache_misses = lambda: None
607
+ print(f"Warning: ShardListDataset has a cache miss rate of {misses * 100.0 / accesses:.1%}%")
608
+
609
+ def get_shard(self, index):
610
+ """Get the shard and index within the shard corresponding to the given index."""
611
+ # Find the shard corresponding to the given index.
612
+ shard_idx = np.searchsorted(self.cum_lengths, index, side="right")
613
+
614
+ # Figure out which index within the shard corresponds to the
615
+ # given index.
616
+ if shard_idx == 0:
617
+ inner_idx = index
618
+ else:
619
+ inner_idx = index - self.cum_lengths[shard_idx - 1]
620
+
621
+ # Get the shard and return the corresponding element.
622
+ desc = self.shards[shard_idx]
623
+ url = desc["url"]
624
+ if url.startswith(("https://", "http://", "gs://", "/", "~")):
625
+ # absolute path or url path
626
+ url = url
627
+ else:
628
+ # concat relative path
629
+ if self.base is None and "base_path" not in self.spec:
630
+ raise FileNotFoundError("passing a relative path in shardlist but no base found.")
631
+ base_path = self.spec["base_path"] if "base_path" in self.spec else self.base
632
+ url = osp.abspath(osp.join(osp.expanduser(base_path), url))
633
+
634
+ desc["url"] = url
635
+ try:
636
+ shard = self.cache.get_shard(url)
637
+ except UnicodeDecodeError as e:
638
+ print("UnicodeDecodeError:", desc)
639
+ raise e
640
+ return shard, inner_idx, desc
641
+
642
+ def __getitem__(self, index):
643
+ """Return the sample corresponding to the given index."""
644
+ shard, inner_idx, desc = self.get_shard(index)
645
+ sample = shard[inner_idx]
646
+
647
+ # Check if we're missing the cache too often.
648
+ self.check_cache_misses()
649
+
650
+ sample["__dataset__"] = desc.get("dataset")
651
+ sample["__index__"] = index
652
+ sample["__shard__"] = desc["url"]
653
+ sample["__shardindex__"] = inner_idx
654
+
655
+ # Apply transformations
656
+ for transform in self.transformations:
657
+ sample = transform(sample)
658
+
659
+ return sample
660
+
661
+ def close(self):
662
+ """Close the dataset."""
663
+ self.cache.clear()
664
+
665
+
666
+ class ShardListDatasetMulti(ShardListDataset):
667
+ """An indexable dataset based on a list of shards.
668
+
669
+ The dataset is either given as a list of shards with optional options and name,
670
+ or as a URL pointing to a JSON descriptor file.
671
+
672
+ Datasets can reference other datasets via `source_url`.
673
+
674
+ Shard references within a dataset are resolved relative to an explicitly
675
+ given `base` property, or relative to the URL from which the dataset
676
+ descriptor was loaded.
677
+ """
678
+
679
+ def __init__(
680
+ self,
681
+ shards,
682
+ *,
683
+ cache_size=int(1e12),
684
+ cache_dir=None,
685
+ lru_size=10,
686
+ dataset_name=None,
687
+ localname=None,
688
+ transformations="PIL",
689
+ keep=False,
690
+ base=None,
691
+ options=None,
692
+ sort_data_inseq=False,
693
+ num_replicas=None,
694
+ ):
695
+ """Create a ShardListDataset.
696
+
697
+ Args:
698
+ shards: a list of (filename, length) pairs or a URL pointing to a JSON descriptor file
699
+ cache_size: the number of shards to keep in the cache
700
+ lru_size: the number of shards to keep in the LRU cache
701
+ localname: a function that maps URLs to local filenames
702
+
703
+ Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
704
+ """
705
+ if options is None:
706
+ options = {}
707
+ # shards is a list of (filename, length) pairs. We'll need to
708
+ # keep track of the lengths and cumulative lengths to know how
709
+ # to map indices to shards and indices within shards.
710
+ shards_lists = shards if isinstance(shards, list) else [shards]
711
+ bases = base if isinstance(base, list) else [base] * len(shards_lists)
712
+ self.spec = {}
713
+ self.shards = []
714
+ self.num_per_dir = {}
715
+ for base, shards in zip(bases, shards_lists):
716
+ if isinstance(shards, (str, io.IOBase)):
717
+ if base is None and isinstance(shards, str):
718
+ shards = osp.expanduser(shards)
719
+ base = urldir(shards)
720
+ self.base = base
721
+ _spec = load_dsdesc_and_resolve(shards, options=options, base=base)
722
+ update_dict_with_extend(self.spec, _spec)
723
+ self.num_per_dir[os.path.basename(os.path.dirname(shards))] = sum(
724
+ [shard["nsamples"] for shard in _spec.get("shardlist", [])]
725
+ )
726
+ else:
727
+ raise NotImplementedError("Only support taking path/url to JSON descriptor file.")
728
+ self.base = None
729
+ self.spec = options
730
+ self.shards = shards
731
+ self.dataset_name = dataset_name or hash_dataset_name(str(shards))
732
+
733
+ if sort_data_inseq and len(self.spec.get("shardlist", [])) > 0:
734
+ num_replicas = num_replicas or dist.get_world_size()
735
+ self.spec["shardlist"] = split_and_recombine(self.spec["shardlist"], num_replicas)
736
+
737
+ self.shards.extend(self.spec.get("shardlist", []))
738
+ self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards))
739
+
740
+ self.lengths = [shard["nsamples"] for shard in self.shards]
741
+ self.cum_lengths = np.cumsum(self.lengths)
742
+ self.total_length = self.cum_lengths[-1]
743
+
744
+ if cache_dir is not None:
745
+ # when a cache dir is explicitly given, we download files into
746
+ # that directory without any changes
747
+ self.cache_dir = cache_dir
748
+ self.localname = cache_localname(cache_dir)
749
+ elif localname is not None:
750
+ # when a localname function is given, we use that
751
+ self.cache_dir = None
752
+ self.localname = localname
753
+ else:
754
+ import getpass
755
+
756
+ # when no cache dir or localname are given, use the cache from the environment
757
+ self.cache_dir = os.environ.get("WIDS_CACHE", f"~/.cache/_wids_cache")
758
+ self.cache_dir = osp.expanduser(self.cache_dir)
759
+ self.localname = default_localname(self.cache_dir)
760
+
761
+ self.data_info = (
762
+ f"[WebShardedList] {str(shards)}, base: {self.base,}, name: {self.spec.get('name')}, "
763
+ f"nfiles: {str(len(self.shards))}"
764
+ )
765
+ if True or int(os.environ.get("WIDS_VERBOSE", 0)):
766
+ nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
767
+ nsamples = sum(shard["nsamples"] for shard in self.shards)
768
+ self.data_info += f"nbytes: {str(nbytes)}, samples: {str(nsamples),}, cache: {self.cache_dir} "
769
+ self.transformations = interpret_transformations(transformations)
770
+
771
+ if lru_size > 200:
772
+ warnings.warn("LRU size is very large; consider reducing it to avoid running out of file descriptors")
773
+ self.cache = LRUShards(lru_size, localname=self.localname, keep=keep)
774
+
775
+
776
+ def split_and_recombine(lst, n):
777
+ from collections import OrderedDict
778
+
779
+ def extract_prefix(i):
780
+ return i["url"].split("/")[-2]
781
+
782
+ unique_parts = list(OrderedDict((extract_prefix(item), None) for item in lst).keys())
783
+ split_dict = {part: [] for part in unique_parts}
784
+
785
+ for part in unique_parts:
786
+ part_list = [item for item in lst if extract_prefix(item) == part]
787
+ chunk_size = max(1, len(part_list) // n)  # ensure chunk_size is at least 1
788
+ chunks = [part_list[i * chunk_size : (i + 1) * chunk_size] for i in range(n)]
789
+
790
+ # if the split is uneven, append the remaining elements to the last chunk
791
+ if len(part_list) % n != 0:
792
+ chunks[-1].extend(part_list[n * chunk_size :])
793
+
794
+ split_dict[part] = chunks
795
+
796
+ recombined_list = []
797
+ for i in range(n):
798
+ for part in unique_parts:
799
+ recombined_list.extend(split_dict[part][i])
800
+
801
+ return recombined_list
802
+
803
+
804
+ def lengths_to_ranges(lengths):
805
+ """Convert a list of lengths to a list of ranges."""
806
+ ranges = []
807
+ start = 0
808
+ for length in lengths:
809
+ ranges.append((start, start + length))
810
+ start += length
811
+ return ranges
812
+
813
+
814
+ def intersect_range(a, b):
815
+ """Return the intersection of the two half-open integer intervals."""
816
+ result = max(a[0], b[0]), min(a[1], b[1])
817
+ if result[0] >= result[1]:
818
+ return None
819
+ return result
820
+
821
+
822
+ def intersect_ranges(rangelist, r):
823
+ """Return the intersection of the half-open integer interval r with the list of half-open integer intervals."""
824
+ result = []
825
+ for a in rangelist:
826
+ x = intersect_range(a, r)
827
+ if x is not None:
828
+ result.append(x)
829
+ return result
830
+
831
+
832
+ def iterate_ranges(ranges, rng, indexshuffle=True, shardshuffle=True):
833
+ """Iterate over the ranges in a random order."""
834
+ shard_indexes = list(range(len(ranges)))
835
+ if shardshuffle:
836
+ rng.shuffle(shard_indexes)
837
+ for i in shard_indexes:
838
+ lo, hi = ranges[i]
839
+ sample_indexes = list(range(lo, hi))
840
+ if indexshuffle:
841
+ rng.shuffle(sample_indexes)
842
+ yield from sample_indexes
843
+
844
+
845
+ class ShardListSampler(Sampler):
846
+ """A sampler that samples consistent with a ShardListDataset.
847
+
848
+ This sampler is used to sample from a ShardListDataset in a way that
849
+ preserves locality.
850
+
851
+ This returns a permutation of the indexes by shard, then a permutation of
852
+ indexes within each shard. This ensures that the data is accessed in a
853
+ way that preserves locality.
854
+
855
+ Note that how this ends up splitting data between multiple workers depends
856
+ on the details of the DataLoader. Generally, it will likely load samples from the
857
+ same shard in each worker.
858
+
859
+ Other more sophisticated shard-aware samplers are possible and will likely
860
+ be added.
861
+ """
862
+
863
+ def __init__(self, dataset, *, lengths=None, seed=0, shufflefirst=False):
864
+ if lengths is None:
865
+ lengths = list(dataset.lengths)
866
+ self.ranges = lengths_to_ranges(lengths)
867
+ self.seed = seed
868
+ self.shufflefirst = shufflefirst
869
+ self.epoch = 0
870
+
871
+ def __iter__(self):
872
+ self.rng = random.Random(self.seed + 1289738273 * self.epoch)
873
+ shardshuffle = self.shufflefirst or self.epoch > 0
874
+ yield from iterate_ranges(self.ranges, self.rng, shardshuffle=shardshuffle)
875
+ self.epoch += 1
876
+
877
+
878
+ ShardedSampler = ShardListSampler
879
+
880
+
881
+ class ChunkedSampler(Sampler):
882
+ """A sampler that samples in chunks and then shuffles the samples within each chunk.
883
+
884
+ This preserves locality of reference while still shuffling the data.
885
+ """
886
+
887
+ def __init__(
888
+ self,
889
+ dataset,
890
+ *,
891
+ num_samples=None,
892
+ chunksize=2000,
893
+ seed=0,
894
+ shuffle=False,
895
+ shufflefirst=False,
896
+ ):
897
+ if isinstance(num_samples, int):
898
+ lo, hi = 0, num_samples
899
+ elif num_samples is None:
900
+ lo, hi = 0, len(dataset)
901
+ else:
902
+ lo, hi = num_samples
903
+ self.ranges = [(i, min(i + chunksize, hi)) for i in range(lo, hi, chunksize)]
904
+ self.seed = seed
905
+ self.shuffle = shuffle
906
+ self.shufflefirst = shufflefirst
907
+ self.epoch = 0
908
+
909
+ def set_epoch(self, epoch):
910
+ self.epoch = epoch
911
+
912
+ def __iter__(self):
913
+ self.rng = random.Random(self.seed + 1289738273 * self.epoch)
914
+ shardshuffle = self.shufflefirst or self.epoch > 0
915
+ yield from iterate_ranges(
916
+ self.ranges,
917
+ self.rng,
918
+ indexshuffle=self.shuffle,
919
+ shardshuffle=(self.shuffle and shardshuffle),
920
+ )
921
+ self.epoch += 1
922
+
923
+ def __len__(self):
924
+ return len(self.ranges)
925
+
926
+
927
+ def DistributedChunkedSampler(
928
+ dataset: Dataset,
929
+ *,
930
+ num_replicas: Optional[int] = None,
931
+ num_samples: Optional[int] = None,
932
+ rank: Optional[int] = None,
933
+ shuffle: bool = True,
934
+ shufflefirst: bool = False,
935
+ seed: int = 0,
936
+ drop_last: bool = None,
937
+ chunksize: int = 1000000,
938
+ ) -> ChunkedSampler:
939
+ """Return a ChunkedSampler for the current worker in distributed training.
940
+
941
+ Reverts to a simple ChunkedSampler if not running in distributed mode.
942
+
943
+ Since the split among workers takes place before the chunk shuffle,
944
+ workers end up with a fixed set of shards they need to download. The
945
+ more workers, the fewer shards are used by each worker.
946
+ """
947
+ if drop_last is not None:
948
+ warnings.warn("DistributedChunkedSampler does not support drop_last, thus it will be ignored")
949
+ if not dist.is_initialized():
950
+ warnings.warn("DistributedChunkedSampler is called without distributed initialized; assuming single process")
951
+ num_replicas = 1
952
+ rank = 0
953
+ else:
954
+ num_replicas = num_replicas or dist.get_world_size()
955
+ rank = rank or dist.get_rank()
956
+ assert rank >= 0 and rank < num_replicas
957
+
958
+ num_samples = num_samples or len(dataset)
959
+ worker_chunk = (num_samples + num_replicas - 1) // num_replicas
960
+ worker_start = rank * worker_chunk
961
+ worker_end = min(worker_start + worker_chunk, num_samples)
962
+ return ChunkedSampler(
963
+ dataset,
964
+ num_samples=(worker_start, worker_end),
965
+ chunksize=chunksize,
966
+ seed=seed,
967
+ shuffle=shuffle,
968
+ shufflefirst=shufflefirst,
969
+ )
970
+
971
+
972
+ class DistributedRangedSampler(Sampler):
973
+ """A sampler that samples in chunks and then shuffles the samples within each chunk.
974
+
975
+ This preserves locality of reference while still shuffling the data.
976
+ """
977
+
978
+ def __init__(
979
+ self,
980
+ dataset: Dataset,
981
+ num_replicas: Optional[int] = None,
982
+ num_samples: Optional[int] = None,
983
+ rank: Optional[int] = None,
984
+ drop_last: bool = None,
985
+ ):
986
+ if drop_last is not None:
987
+ warnings.warn("DistributedChunkedSampler does not support drop_last, thus it will be ignored")
988
+ if not dist.is_initialized():
989
+ warnings.warn(
990
+ "DistributedChunkedSampler is called without distributed initialized; assuming single process"
991
+ )
992
+ num_replicas = 1
993
+ rank = 0
994
+ else:
995
+ num_replicas = num_replicas or dist.get_world_size()
996
+ rank = rank or dist.get_rank()
997
+ assert rank >= 0 and rank < num_replicas
998
+ num_samples = num_samples or len(dataset)
999
+ self.worker_chunk = num_samples // num_replicas
1000
+ self.worker_start = rank * self.worker_chunk
1001
+ self.worker_end = min((rank + 1) * self.worker_chunk, num_samples)
1002
+ self.ranges = range(self.worker_start, self.worker_end)
1003
+ self.epoch = 0
1004
+ self.step_start = 0
1005
+
1006
+ def set_epoch(self, epoch):
1007
+ self.epoch = epoch
1008
+
1009
+ def __len__(self):
1010
+ return len(self.ranges)
1011
+
1012
+ def set_start(self, start):
1013
+ self.step_start = start
1014
+
1015
+ def __iter__(self):
1016
+ yield from self.ranges[self.step_start :]
1017
+ self.epoch += 1
1018
+
1019
+
1020
+ class DistributedLocalSampler(DistributedSampler):
1021
+ def __iter__(self):
1022
+ if self.shuffle:
1023
+ # deterministically shuffle based on epoch and seed
1024
+ g = torch.Generator()
1025
+ g.manual_seed(self.seed + self.epoch)
1026
+ indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
1027
+ else:
1028
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
1029
+
1030
+ if not self.drop_last:
1031
+ # add extra samples to make it evenly divisible
1032
+ padding_size = self.total_size - len(indices)
1033
+ if padding_size <= len(indices):
1034
+ indices += indices[:padding_size]
1035
+ else:
1036
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
1037
+ else:
1038
+ # remove tail of data to make it evenly divisible.
1039
+ indices = indices[: self.total_size]
1040
+ assert len(indices) == self.total_size
1041
+
1042
+ # subsample
1043
+ # indices = indices[self.rank:self.total_size:self.num_replicas]
1044
+ chunk_size = self.total_size // self.num_replicas
1045
+ begin_idx = chunk_size * self.rank
1046
+ stop_idx = chunk_size * (self.rank + 1)
1047
+ indices = indices[begin_idx:stop_idx]
1048
+
1049
+ # print("[SamplerIndices: ]", indices)
1050
+ assert len(indices) == self.num_samples
1051
+ return iter(indices)
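The classes above form the indexable sharded-dataset API added by this commit. Below is a minimal usage sketch; the descriptor path, cache directory, and the `diffusion.data.wids` import path are assumptions for illustration, not code from this commit.

```python
# Hedged sketch: the JSON path, cache dir, and import location are illustrative.
from diffusion.data.wids import ShardListDataset, ChunkedSampler  # assumed re-exports

dataset = ShardListDataset(
    "data/dataset.json",          # wids descriptor: {"wids_version": 1, "shardlist": [...]}
    cache_dir="/tmp/wids_cache",  # downloaded shards are kept in this directory
    lru_size=8,                   # number of shards simultaneously kept open in memory
)

# ChunkedSampler shuffles within fixed-size chunks, so consecutive samples tend to
# come from the same shard and both the on-disk cache and the LRU cache stay warm.
sampler = ChunkedSampler(dataset, chunksize=2000, shuffle=True)
for idx in sampler:
    sample = dataset[idx]
    # `sample` is a dict of decoded fields plus bookkeeping keys such as
    # "__shard__", "__shardindex__", and "__index__".
    break
```

In distributed training, `DistributedChunkedSampler` plays the same role but first splits the index range across ranks, so each worker only ever touches a fixed subset of shards.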
diffusion/data/wids/wids_dl.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
18
+ import fcntl
19
+ import os
20
+ import shutil
21
+ import sys
22
+ import time
23
+ from collections import deque
24
+ from datetime import datetime
25
+ from urllib.parse import urlparse
26
+
27
+ recent_downloads = deque(maxlen=1000)
28
+
29
+ open_objects = {}
30
+ max_open_objects = 100
31
+
32
+
33
+ class ULockFile:
34
+ """A simple locking class. We don't need any of the third
35
+ party libraries since we rely on POSIX semantics for linking
36
+ below anyway."""
37
+
38
+ def __init__(self, path):
39
+ self.lockfile_path = path
40
+ self.lockfile = None
41
+
42
+ def __enter__(self):
43
+ self.lockfile = open(self.lockfile_path, "w")
44
+ fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX)
45
+ return self
46
+
47
+ def __exit__(self, exc_type, exc_val, exc_tb):
48
+ fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
49
+ self.lockfile.close()
50
+ self.lockfile = None
51
+ try:
52
+ os.unlink(self.lockfile_path)
53
+ except FileNotFoundError:
54
+ pass
55
+
56
+
57
+ def pipe_download(remote, local):
58
+ """Perform a download for a pipe: url."""
59
+ assert remote.startswith("pipe:")
60
+ cmd = remote[5:]
61
+ cmd = cmd.format(local=local)
62
+ assert os.system(cmd) == 0, "Command failed: %s" % cmd
63
+
64
+
65
+ def copy_file(remote, local):
66
+ remote = urlparse(remote)
67
+ assert remote.scheme in ["file", ""]
68
+ # use absolute path
69
+ remote = os.path.abspath(remote.path)
70
+ local = urlparse(local)
71
+ assert local.scheme in ["file", ""]
72
+ local = os.path.abspath(local.path)
73
+ if remote == local:
74
+ return
75
+ # check if the local file exists
76
+ shutil.copyfile(remote, local)
77
+
78
+
79
+ verbose_cmd = int(os.environ.get("WIDS_VERBOSE_CMD", "0"))
80
+
81
+
82
+ def vcmd(flag, verbose_flag=""):
83
+ return verbose_flag if verbose_cmd else flag
84
+
85
+
86
+ default_cmds = {
87
+ "posixpath": copy_file,
88
+ "file": copy_file,
89
+ "pipe": pipe_download,
90
+ "http": "curl " + vcmd("-s") + " -L {url} -o {local}",
91
+ "https": "curl " + vcmd("-s") + " -L {url} -o {local}",
92
+ "ftp": "curl " + vcmd("-s") + " -L {url} -o {local}",
93
+ "ftps": "curl " + vcmd("-s") + " -L {url} -o {local}",
94
+ "gs": "gsutil " + vcmd("-q") + " cp {url} {local}",
95
+ "s3": "aws s3 cp {url} {local}",
96
+ }
97
+
98
+
99
+ # TODO(ligeng): change HTTPS download to python requests library
100
+
101
+
102
+ def download_file_no_log(remote, local, handlers=default_cmds):
103
+ """Download a file from a remote url to a local path.
104
+ The remote url can be a pipe: url, in which case the remainder of
105
+ the url is treated as a command template that is executed to perform the download.
106
+ """
107
+
108
+ if remote.startswith("pipe:"):
109
+ schema = "pipe"
110
+ else:
111
+ schema = urlparse(remote).scheme
112
+ if schema is None or schema == "":
113
+ schema = "posixpath"
114
+ # get the handler
115
+ handler = handlers.get(schema)
116
+ if handler is None:
117
+ raise ValueError("Unknown schema: %s" % schema)
118
+ # call the handler
119
+ if callable(handler):
120
+ handler(remote, local)
121
+ else:
122
+ assert isinstance(handler, str)
123
+ cmd = handler.format(url=remote, local=local)
124
+ assert os.system(cmd) == 0, "Command failed: %s" % cmd
125
+ return local
126
+
127
+
128
+ def download_file(remote, local, handlers=default_cmds, verbose=False):
129
+ start = time.time()
130
+ try:
131
+ return download_file_no_log(remote, local, handlers=handlers)
132
+ finally:
133
+ recent_downloads.append((remote, local, time.time(), time.time() - start))
134
+ if verbose:
135
+ print(
136
+ "downloaded",
137
+ remote,
138
+ "to",
139
+ local,
140
+ "in",
141
+ time.time() - start,
142
+ "seconds",
143
+ file=sys.stderr,
144
+ )
145
+
146
+
147
+ def download_and_open(remote, local, mode="rb", handlers=default_cmds, verbose=False):
148
+ with ULockFile(local + ".lock"):
149
+ if os.path.exists(remote):
150
+ # print("enter1", remote, local, mode)
151
+ result = open(remote, mode)
152
+ else:
153
+ # print("enter2", remote, local, mode)
154
+ if not os.path.exists(local):
155
+ if verbose:
156
+ print("downloading", remote, "to", local, file=sys.stderr)
157
+ download_file(remote, local, handlers=handlers)
158
+ else:
159
+ if verbose:
160
+ print("using cached", local, file=sys.stderr)
161
+ result = open(local, mode)
162
+
163
+ # input()
164
+
165
+ if open_objects is not None:
166
+ for k, v in list(open_objects.items()):
167
+ if v.closed:
168
+ del open_objects[k]
169
+ if len(open_objects) > max_open_objects:
170
+ raise RuntimeError("Too many open objects")
171
+ current_time = datetime.now().strftime("%Y%m%d%H%M%S")
172
+ key = tuple(str(x) for x in [remote, local, mode, current_time])
173
+ open_objects[key] = result
174
+ return result
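A small sketch of how the download helpers in this file fit together; the URL and local path are placeholders, not paths used by the repository.

```python
# Hedged sketch: URL and local path are placeholders.
url = "https://example.com/shards/shard-000000.tar"
local = "/tmp/wids_cache/shard-000000.tar"

# download_file dispatches on the URL scheme (curl for http/https/ftp, gsutil for gs://,
# a plain copy for local files, or a shell template for pipe: URLs).
download_file(url, local, verbose=True)

# download_and_open additionally serializes concurrent downloads with a .lock file,
# reuses an existing local copy when present, and returns an open file object.
with download_and_open(url, local, mode="rb") as f:
    first_block = f.read(512)
```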
diffusion/data/wids/wids_lru.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
18
+ from collections import OrderedDict
19
+
20
+
21
+ class LRUCache:
22
+ def __init__(self, capacity: int, release_handler=None):
23
+ """Initialize a new LRU cache with the given capacity."""
24
+ self.capacity = capacity
25
+ self.cache = OrderedDict()
26
+ self.release_handler = release_handler
27
+
28
+ def __getitem__(self, key):
29
+ """Return the value associated with the given key, or None."""
30
+ if key not in self.cache:
31
+ return None
32
+ self.cache.move_to_end(key)
33
+ return self.cache[key]
34
+
35
+ def __setitem__(self, key, value):
36
+ """Associate the given value with the given key."""
37
+ if key in self.cache:
38
+ self.cache.move_to_end(key)
39
+ self.cache[key] = value
40
+ if len(self.cache) > self.capacity:
41
+ key, value = self.cache.popitem(last=False)
42
+ if self.release_handler is not None:
43
+ self.release_handler(key, value)
44
+
45
+ def __delitem__(self, key):
46
+ """Remove the given key from the cache."""
47
+ if key in self.cache:
48
+ if self.release_handler is not None:
49
+ value = self.cache[key]
50
+ self.release_handler(key, value)
51
+ del self.cache[key]
52
+
53
+ def __len__(self):
54
+ """Return the number of entries in the cache."""
55
+ return len(self.cache)
56
+
57
+ def __contains__(self, key):
58
+ """Return whether the cache contains the given key."""
59
+ return key in self.cache
60
+
61
+ def items(self):
62
+ """Return an iterator over the keys of the cache."""
63
+ return self.cache.items()
64
+
65
+ def keys(self):
66
+ """Return an iterator over the keys of the cache."""
67
+ return self.cache.keys()
68
+
69
+ def values(self):
70
+ """Return an iterator over the values of the cache."""
71
+ return self.cache.values()
72
+
73
+ def clear(self):
74
+ for key in list(self.keys()):
75
+ value = self.cache[key]
76
+ if self.release_handler is not None:
77
+ self.release_handler(key, value)
78
+ del self[key]
79
+
80
+ def __del__(self):
81
+ self.clear()
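A minimal sketch of the LRUCache behavior above, including the eviction callback; the keys and values are illustrative.

```python
# Hedged sketch of LRUCache eviction; keys and values are illustrative.
def on_evict(key, value):
    print("releasing", key)

cache = LRUCache(2, release_handler=on_evict)
cache["a"] = 1
cache["b"] = 2
cache["a"]        # touching "a" marks it most recently used
cache["c"] = 3    # capacity exceeded: "b" is evicted and on_evict("b", 2) runs
assert "a" in cache and "b" not in cache
```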
diffusion/data/wids/wids_mmtar.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
18
+ import collections
19
+ import fcntl
20
+ import io
21
+ import mmap
22
+ import os
23
+ import struct
24
+
25
+ TarHeader = collections.namedtuple(
26
+ "TarHeader",
27
+ [
28
+ "name",
29
+ "mode",
30
+ "uid",
31
+ "gid",
32
+ "size",
33
+ "mtime",
34
+ "chksum",
35
+ "typeflag",
36
+ "linkname",
37
+ "magic",
38
+ "version",
39
+ "uname",
40
+ "gname",
41
+ "devmajor",
42
+ "devminor",
43
+ "prefix",
44
+ ],
45
+ )
46
+
47
+
48
+ def parse_tar_header(header_bytes):
49
+ header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
50
+ return TarHeader(*header)
51
+
52
+
53
+ def next_header(offset, header):
54
+ block_size = 512
55
+ size = header.size.decode("utf-8").strip("\x00")
56
+ if size == "":
57
+ return -1
58
+ size = int(size, 8)
59
+ # compute the file size rounded up to the next block size if it is a partial block
60
+ padded_file_size = (size + block_size - 1) // block_size * block_size
61
+ return offset + block_size + padded_file_size
62
+
63
+
64
+ # TODO(ligeng): support gzip stream
65
+ class MMIndexedTar:
66
+ def __init__(self, fname, index_file=None, verbose=True, cleanup_callback=None):
67
+ self.verbose = verbose
68
+ self.cleanup_callback = cleanup_callback
69
+ if isinstance(fname, str):
70
+ self.stream = open(fname, "rb")
71
+ self.fname = fname
72
+ elif isinstance(fname, io.IOBase):
73
+ self.stream = fname
74
+ self.fname = None
75
+ self.mmapped_file = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)
76
+ if cleanup_callback:
77
+ cleanup_callback(fname, self.stream.fileno(), "start")
78
+ self._build_index()
79
+
80
+ def close(self, dispose=False):
81
+ if self.cleanup_callback:
82
+ self.cleanup_callback(self.fname, self.stream.fileno(), "end")
83
+ self.mmapped_file.close()
84
+ self.stream.close()
85
+
86
+ def _build_index(self):
87
+ self.by_name = {}
88
+ self.by_index = []
89
+ offset = 0
90
+ while offset >= 0 and offset < len(self.mmapped_file):
91
+ header = parse_tar_header(self.mmapped_file[offset : offset + 500])
92
+ name = header.name.decode("utf-8").strip("\x00")
93
+ typeflag = header.typeflag.decode("utf-8").strip("\x00")
94
+ if name != "" and name != "././@PaxHeader" and typeflag in ["0", ""]:
95
+ try:
96
+ size = int(header.size.decode("utf-8")[:-1], 8)
97
+ except ValueError as exn:
98
+ print(header)
99
+ raise exn
100
+ self.by_name[name] = offset
101
+ self.by_index.append((name, offset, size))
102
+ offset = next_header(offset, header)
103
+
104
+ def names(self):
105
+ return self.by_name.keys()
106
+
107
+ def get_at_offset(self, offset):
108
+ header = parse_tar_header(self.mmapped_file[offset : offset + 500])
109
+ name = header.name.decode("utf-8").strip("\x00")
110
+ start = offset + 512
111
+ end = start + int(header.size.decode("utf-8")[:-1], 8)
112
+ return name, self.mmapped_file[start:end]
113
+
114
+ def get_at_index(self, index):
115
+ name, offset, size = self.by_index[index]
116
+ return self.get_at_offset(offset)
117
+
118
+ def get_by_name(self, name):
119
+ offset = self.by_name[name]
120
+ return self.get_at_offset(offset)
121
+
122
+ def __iter__(self):
123
+ for name, offset, size in self.by_index:
124
+ yield name, self.mmapped_file[offset + 512 : offset + 512 + size]
125
+
126
+ def __getitem__(self, key):
127
+ if isinstance(key, int):
128
+ return self.get_at_index(key)
129
+ else:
130
+ return self.get_by_name(key)
131
+
132
+ def __len__(self):
133
+ return len(self.by_index)
134
+
135
+ def get_file(self, i):
136
+ fname, data = self.get_at_index(i)
137
+ return fname, io.BytesIO(data)
138
+
139
+
140
+ def keep_while_reading(fname, fd, phase, delay=0.0):
141
+ """This is a possible cleanup callback for cleanup_callback of MIndexedTar.
142
+
143
+ It assumes that as long as there are some readers for a file,
144
+ more readers may be trying to open it.
145
+
146
+ Note that on Linux, unlinking the file doesn't matter after
147
+ it has been mmapped. The contents will only be deleted when
148
+ all readers close the file. The unlinking merely makes the file
149
+ unavailable to new readers, since the downloader checks first
150
+ whether the file exists.
151
+ """
152
+ assert delay == 0.0, "delay not implemented"
153
+ if fd < 0 or fname is None:
154
+ return
155
+ if phase == "start":
156
+ fcntl.flock(fd, fcntl.LOCK_SH)
157
+ elif phase == "end":
158
+ try:
159
+ fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
160
+ os.unlink(fname)
161
+ except FileNotFoundError:
162
+ # someone else deleted it already
163
+ pass
164
+ except BlockingIOError:
165
+ # we couldn't get an exclusive lock, so someone else is still reading
166
+ pass
167
+ else:
168
+ raise ValueError(f"Unknown phase {phase}")
diffusion/data/wids/wids_specs.py ADDED
@@ -0,0 +1,192 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
18
+ import io
19
+ import json
20
+ import os
21
+ import tempfile
22
+ from urllib.parse import urlparse, urlunparse
23
+
24
+ from .wids_dl import download_and_open
25
+
26
+
27
+ def urldir(url):
28
+ """Return the directory part of a url."""
29
+ parsed_url = urlparse(url)
30
+ path = parsed_url.path
31
+ directory = os.path.dirname(path)
32
+ return parsed_url._replace(path=directory).geturl()
33
+
34
+
35
+ def urlmerge(base, url):
36
+ """Merge a base URL and a relative URL.
37
+
38
+ The function fills in any missing part of the url from the base,
39
+ except for params, query, and fragment, which are taken only from the 'url'.
40
+ For the pathname component, it merges the paths like os.path.join:
41
+ an absolute path in 'url' overrides the base path, otherwise the paths are merged.
42
+
43
+ Parameters:
44
+ base (str): The base URL.
45
+ url (str): The URL to merge with the base.
46
+
47
+ Returns:
48
+ str: The merged URL.
49
+ """
50
+ # Parse the base and the relative URL
51
+ parsed_base = urlparse(base)
52
+ parsed_url = urlparse(url)
53
+
54
+ # Merge paths using os.path.join
55
+ # If the url path is absolute, it overrides the base path
56
+ if parsed_url.path.startswith("/"):
57
+ merged_path = parsed_url.path
58
+ else:
59
+ merged_path = os.path.normpath(os.path.join(parsed_base.path, parsed_url.path))
60
+
61
+ # Construct the merged URL
62
+ merged_url = urlunparse(
63
+ (
64
+ parsed_url.scheme or parsed_base.scheme,
65
+ parsed_url.netloc or parsed_base.netloc,
66
+ merged_path,
67
+ parsed_url.params, # Use params from the url only
68
+ parsed_url.query, # Use query from the url only
69
+ parsed_url.fragment, # Use fragment from the url only
70
+ )
71
+ )
72
+
73
+ return merged_url
74
+
75
+
76
+ def check_shards(l):
77
+ """Check that a list of shards is well-formed.
78
+
79
+ This checks that the list is a list of dictionaries, and that
80
+ each dictionary has a "url" and a "nsamples" key.
81
+ """
82
+ assert isinstance(l, list)
83
+ for shard in l:
84
+ assert isinstance(shard, dict)
85
+ assert "url" in shard
86
+ assert "nsamples" in shard
87
+ return l
88
+
89
+
90
+ def set_all(l, k, v):
91
+ """Set a key to a value in a list of dictionaries."""
92
+ if v is None:
93
+ return
94
+ for x in l:
95
+ if k not in x:
96
+ x[k] = v
97
+
98
+
99
+ def load_remote_dsdesc_raw(source):
100
+ """Load a remote or local dataset description in JSON format."""
101
+ if isinstance(source, str):
102
+ with tempfile.TemporaryDirectory() as tmpdir:
103
+ dlname = os.path.join(tmpdir, "dataset.json")
104
+ with download_and_open(source, dlname) as f:
105
+ dsdesc = json.load(f)
106
+ elif isinstance(source, io.IOBase):
107
+ dsdesc = json.load(source)
108
+ else:
109
+ # FIXME: use gopen
110
+ import requests
111
+
112
+ jsondata = requests.get(source).text
113
+ dsdesc = json.loads(jsondata)
114
+ return dsdesc
115
+
116
+
117
+ def rebase_shardlist(shardlist, base):
118
+ """Rebase the URLs in a shardlist."""
119
+ if base is None:
120
+ return shardlist
121
+ for shard in shardlist:
122
+ shard["url"] = urlmerge(base, shard["url"])
123
+ return shardlist
124
+
125
+
126
+ def resolve_dsdesc(dsdesc, *, options=None, base=None):
127
+ """Resolve a dataset description.
128
+
129
+ This rebases the shards as necessary and loads any remote references.
130
+
131
+ Dataset descriptions are JSON files. They must have the following format;
132
+
133
+ {
134
+ "wids_version": 1,
135
+ # optional immediate shardlist
136
+ "shardlist": [
137
+ {"url": "http://example.com/file.tar", "nsamples": 1000},
138
+ ...
139
+ ],
140
+ # sub-datasets
141
+ "datasets": [
142
+ {"source_url": "http://example.com/dataset.json"},
143
+ {"shardlist": [
144
+ {"url": "http://example.com/file.tar", "nsamples": 1000},
145
+ ...
146
+ ]}
147
+ ...
148
+ ]
149
+ }
150
+ """
151
+ if options is None:
152
+ options = {}
153
+ assert isinstance(dsdesc, dict)
154
+ dsdesc = dict(dsdesc, **options)
155
+ shardlist = rebase_shardlist(dsdesc.get("shardlist", []), base)
156
+ assert shardlist is not None
157
+ set_all(shardlist, "weight", dsdesc.get("weight"))
158
+ set_all(shardlist, "name", dsdesc.get("name"))
159
+ check_shards(shardlist)
160
+ assert "wids_version" in dsdesc, "No wids_version in dataset description"
161
+ assert dsdesc["wids_version"] == 1, "Unknown wids_version"
162
+ for component in dsdesc.get("datasets", []):
163
+ # we use the weight from the reference to the dataset,
164
+ # regardless of remote loading
165
+ weight = component.get("weight")
166
+ # follow any source_url dsdescs through remote loading
167
+ source_url = None
168
+ if "source_url" in component:
169
+ source_url = component["source_url"]
170
+ component = load_remote_dsdesc_raw(source_url)
171
+ assert "source_url" not in component, "double indirection in dataset description"
172
+ assert "shardlist" in component, "no shardlist in dataset description"
173
+ # if the component has a base, use it to rebase the shardlist
174
+ # otherwise use the base from the source_url, if any
175
+ subbase = component.get("base", urldir(source_url) if source_url else None)
176
+ if subbase is not None:
177
+ rebase_shardlist(component["shardlist"], subbase)
178
+ l = check_shards(component["shardlist"])
179
+ set_all(l, "weight", weight)
180
+ set_all(l, "source_url", source_url)
181
+ set_all(l, "dataset", component.get("name"))
182
+ shardlist.extend(l)
183
+ assert len(shardlist) > 0, "No shards found"
184
+ dsdesc["shardlist"] = shardlist
185
+ return dsdesc
186
+
187
+
188
+ def load_dsdesc_and_resolve(source, *, options=None, base=None):
189
+ if options is None:
190
+ options = {}
191
+ dsdesc = load_remote_dsdesc_raw(source)
192
+ return resolve_dsdesc(dsdesc, base=base, options=options)
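A minimal descriptor matching the format documented in `resolve_dsdesc`, together with how it is loaded; the filenames and base URL are placeholders.

```python
import json

# Hedged sketch: filenames and the base URL are placeholders.
desc = {
    "wids_version": 1,
    "shardlist": [
        {"url": "shards/shard-000000.tar", "nsamples": 1000},
        {"url": "shards/shard-000001.tar", "nsamples": 1000},
    ],
}
with open("dataset.json", "w") as f:
    json.dump(desc, f)

# Relative shard URLs are rebased against `base` (or, in ShardListDataset,
# against the directory the descriptor was loaded from).
spec = load_dsdesc_and_resolve("dataset.json", base="https://example.com/data/")
print(spec["shardlist"][0]["url"])  # -> https://example.com/data/shards/shard-000000.tar
```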
diffusion/data/wids/wids_tar.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
18
+ import io
19
+ import os
20
+ import os.path
21
+ import pickle
22
+ import re
23
+ import tarfile
24
+
25
+ import numpy as np
26
+
27
+
28
+ def find_index_file(file):
29
+ prefix, last_ext = os.path.splitext(file)
30
+ if re.match("._[0-9]+_$", last_ext):
31
+ return prefix + ".index"
32
+ else:
33
+ return file + ".index"
34
+
35
+
36
+ class TarFileReader:
37
+ def __init__(self, file, index_file=find_index_file, verbose=True):
38
+ self.verbose = verbose
39
+ if callable(index_file):
40
+ index_file = index_file(file)
41
+ self.index_file = index_file
42
+
43
+ # Open the tar file and keep it open
44
+ if isinstance(file, str):
45
+ self.tar_file = tarfile.open(file, "r")
46
+ else:
47
+ self.tar_file = tarfile.open(fileobj=file, mode="r")
48
+
49
+ # Create the index
50
+ self._create_tar_index()
51
+
52
+ def _create_tar_index(self):
53
+ if self.index_file is not None and os.path.exists(self.index_file):
54
+ if self.verbose:
55
+ print("Loading tar index from", self.index_file)
56
+ with open(self.index_file, "rb") as stream:
57
+ self.fnames, self.index = pickle.load(stream)
58
+ return
59
+ # Create an empty list for the index
60
+ self.fnames = []
61
+ self.index = []
62
+
63
+ if self.verbose:
64
+ print("Creating tar index for", self.tar_file.name, "at", self.index_file)
65
+ # Iterate over the members of the tar file
66
+ for member in self.tar_file:
67
+ # If the member is a file, add it to the index
68
+ if member.isfile():
69
+ # Get the file's offset
70
+ offset = self.tar_file.fileobj.tell()
71
+ self.fnames.append(member.name)
72
+ self.index.append([offset, member.size])
73
+ if self.verbose:
74
+ print("Done creating tar index for", self.tar_file.name, "at", self.index_file)
75
+ self.index = np.array(self.index)
76
+ if self.index_file is not None:
77
+ if os.path.exists(self.index_file + ".temp"):
78
+ os.unlink(self.index_file + ".temp")
79
+ with open(self.index_file + ".temp", "wb") as stream:
80
+ pickle.dump((self.fnames, self.index), stream)
81
+ os.rename(self.index_file + ".temp", self.index_file)
82
+
83
+ def names(self):
84
+ return self.fnames
85
+
86
+ def __len__(self):
87
+ return len(self.index)
88
+
89
+ def get_file(self, i):
90
+ name = self.fnames[i]
91
+ offset, size = self.index[i]
92
+ self.tar_file.fileobj.seek(offset)
93
+ file_bytes = self.tar_file.fileobj.read(size)
94
+ return name, io.BytesIO(file_bytes)
95
+
96
+ def close(self):
97
+ # Close the tar file
98
+ self.tar_file.close()
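A usage sketch for the tarfile-based reader above, which caches its index next to the archive; the filename is a placeholder.

```python
# Hedged sketch: the tar filename is a placeholder.
reader = TarFileReader("shard-000000.tar")  # builds and pickles shard-000000.tar.index on first use
print(reader.names()[:3])                   # member names in archive order
name, stream = reader.get_file(0)           # (member name, io.BytesIO with the member's bytes)
reader.close()
```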
diffusion/dpm_solver.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch
18
+
19
+ from .model import gaussian_diffusion as gd
20
+ from .model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper
21
+
22
+
23
+ def DPMS(
24
+ model,
25
+ condition,
26
+ uncondition,
27
+ cfg_scale,
28
+ pag_scale=1.0,
29
+ pag_applied_layers=None,
30
+ model_type="noise", # or "x_start" or "v" or "score", "flow"
31
+ noise_schedule="linear",
32
+ guidance_type="classifier-free",
33
+ model_kwargs=None,
34
+ diffusion_steps=1000,
35
+ schedule="VP",
36
+ interval_guidance=None,
37
+ ):
38
+ if pag_applied_layers is None:
39
+ pag_applied_layers = []
40
+ if model_kwargs is None:
41
+ model_kwargs = {}
42
+ if interval_guidance is None:
43
+ interval_guidance = [0, 1.0]
44
+ betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
45
+
46
+ ## 1. Define the noise schedule.
47
+ if schedule == "VP":
48
+ noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas)
49
+ elif schedule == "FLOW":
50
+ noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")
51
+
52
+ ## 2. Convert your discrete-time `model` to the continuous-time
53
+ ## noise prediction model. Here is an example for a diffusion model
54
+ ## `model` with the noise prediction type ("noise") .
55
+ model_fn = model_wrapper(
56
+ model,
57
+ noise_schedule,
58
+ model_type=model_type,
59
+ model_kwargs=model_kwargs,
60
+ guidance_type=guidance_type,
61
+ pag_scale=pag_scale,
62
+ pag_applied_layers=pag_applied_layers,
63
+ condition=condition,
64
+ unconditional_condition=uncondition,
65
+ guidance_scale=cfg_scale,
66
+ interval_guidance=interval_guidance,
67
+ )
68
+ ## 3. Define dpm-solver and sample by multistep DPM-Solver.
69
+ return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
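A hedged sketch of driving this wrapper at inference time. The model entry point, the prompt embeddings, the starting latents, and the `DPM_Solver.sample(...)` keyword arguments are assumptions based on the upstream multistep DPM-Solver API, not code from this file.

```python
# Assumed to exist elsewhere: `model` (the diffusion transformer), `y` / `null_y`
# (conditional / unconditional text embeddings), and initial noise `latents`.
dpm_solver = DPMS(
    model.forward_with_dpmsolver,    # assumed noise-prediction entry point
    condition=y,
    uncondition=null_y,
    cfg_scale=4.5,
    model_kwargs={"data_info": {}},  # illustrative only; pass model_type="flow", schedule="FLOW"
)                                    # for a flow-matching checkpoint
samples = dpm_solver.sample(         # signature assumed from the dpm-solver library
    latents,
    steps=20,
    order=2,
    skip_type="time_uniform",
    method="multistep",
)
```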
diffusion/flow_euler_sampler.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import os
18
+
19
+ import torch
20
+ from diffusers import FlowMatchEulerDiscreteScheduler
21
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
22
+ from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
23
+ from tqdm import tqdm
24
+
25
+
26
+ class FlowEuler:
27
+ def __init__(self, model_fn, condition, uncondition, cfg_scale, model_kwargs):
28
+ self.model = model_fn
29
+ self.condition = condition
30
+ self.uncondition = uncondition
31
+ self.cfg_scale = cfg_scale
32
+ self.model_kwargs = model_kwargs
33
+ # repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
34
+ self.scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0)
35
+
36
+ def sample(self, latents, steps=28):
37
+ device = self.condition.device
38
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, steps, device, None)
39
+ do_classifier_free_guidance = True
40
+
41
+ prompt_embeds = self.condition
42
+ if do_classifier_free_guidance:
43
+ prompt_embeds = torch.cat([self.uncondition, self.condition], dim=0)
44
+
45
+ for i, t in tqdm(list(enumerate(timesteps)), disable=os.getenv("DPM_TQDM", "False") == "True"):
46
+
47
+ # expand the latents if we are doing classifier free guidance
48
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
49
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
50
+ timestep = t.expand(latent_model_input.shape[0])
51
+
52
+ noise_pred = self.model(
53
+ latent_model_input,
54
+ timestep,
55
+ prompt_embeds,
56
+ **self.model_kwargs,
57
+ )
58
+
59
+ if isinstance(noise_pred, Transformer2DModelOutput):
60
+ noise_pred = noise_pred[0]
61
+
62
+ # perform guidance
63
+ if do_classifier_free_guidance:
64
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
65
+ noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
66
+
67
+ # compute the previous noisy sample x_t -> x_t-1
68
+ latents_dtype = latents.dtype
69
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
70
+
71
+ if latents.dtype != latents_dtype:
72
+ latents = latents.to(latents_dtype)
73
+
74
+ return latents
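A short sketch of the Flow-Euler sampling loop above; the model callable, the embeddings, and the latent shape are assumptions for illustration.

```python
import torch

# Assumed to exist: `model` (callable as model(latent, timestep, prompt_embeds, **kwargs))
# and `y` / `null_y` (conditional / unconditional prompt embeddings).
flow_solver = FlowEuler(
    model,
    condition=y,
    uncondition=null_y,
    cfg_scale=4.5,
    model_kwargs={},
)
latents = torch.randn(1, 32, 32, 32, device=y.device)  # latent shape is illustrative
samples = flow_solver.sample(latents, steps=28)
```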
diffusion/iddpm.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # Modified from OpenAI's diffusion repos
18
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
19
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
20
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
21
+ from diffusion.model.respace import SpacedDiffusion, space_timesteps
22
+
23
+ from .model import gaussian_diffusion as gd
24
+
25
+
26
+ def Scheduler(
27
+ timestep_respacing,
28
+ noise_schedule="linear",
29
+ use_kl=False,
30
+ sigma_small=False,
31
+ predict_xstart=False,
32
+ predict_v=False,
33
+ learn_sigma=True,
34
+ pred_sigma=True,
35
+ rescale_learned_sigmas=False,
36
+ diffusion_steps=1000,
37
+ snr=False,
38
+ return_startx=False,
39
+ flow_shift=1.0,
40
+ ):
41
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
42
+ if use_kl:
43
+ loss_type = gd.LossType.RESCALED_KL
44
+ elif rescale_learned_sigmas:
45
+ loss_type = gd.LossType.RESCALED_MSE
46
+ else:
47
+ loss_type = gd.LossType.MSE
48
+ if timestep_respacing is None or timestep_respacing == "":
49
+ timestep_respacing = [diffusion_steps]
50
+ if predict_xstart:
51
+ model_mean_type = gd.ModelMeanType.START_X
52
+ elif predict_v:
53
+ model_mean_type = gd.ModelMeanType.VELOCITY
54
+ else:
55
+ model_mean_type = gd.ModelMeanType.EPSILON
56
+ return SpacedDiffusion(
57
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
58
+ betas=betas,
59
+ model_mean_type=model_mean_type,
60
+ model_var_type=(
61
+ (
62
+ (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
63
+ if not learn_sigma
64
+ else gd.ModelVarType.LEARNED_RANGE
65
+ )
66
+ if pred_sigma
67
+ else None
68
+ ),
69
+ loss_type=loss_type,
70
+ snr=snr,
71
+ return_startx=return_startx,
72
+ # rescale_timesteps=rescale_timesteps,
73
+ flow="flow" in noise_schedule,
74
+ flow_shift=flow_shift,
75
+ diffusion_steps=diffusion_steps,
76
+ )
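A construction sketch for the scheduler factory above; the argument values are illustrative rather than the repository's training configuration.

```python
# Hedged sketch: argument values are illustrative.
# Full 1000-step diffusion for training (empty respacing means "use every step").
train_diffusion = Scheduler(
    timestep_respacing="",
    noise_schedule="linear",
    learn_sigma=True,
    pred_sigma=True,
)

# Respaced 50-step diffusion for evaluation/sampling.
eval_diffusion = Scheduler(timestep_respacing="50", noise_schedule="linear")
```

Both calls return a `SpacedDiffusion` object; flow-matching behavior is selected by passing a noise schedule whose name contains "flow" together with `flow_shift`.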
diffusion/lcm_scheduler.py ADDED
@@ -0,0 +1,457 @@
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ from diffusers import ConfigMixin, SchedulerMixin
25
+ from diffusers.configuration_utils import register_to_config
26
+ from diffusers.utils import BaseOutput
27
+
28
+
29
+ @dataclass
30
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
31
+ class LCMSchedulerOutput(BaseOutput):
32
+ """
33
+ Output class for the scheduler's `step` function output.
34
+ Args:
35
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
36
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
37
+ denoising loop.
38
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
39
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
40
+ `pred_original_sample` can be used to preview progress or for guidance.
41
+ """
42
+
43
+ prev_sample: torch.FloatTensor
44
+ denoised: Optional[torch.FloatTensor] = None
45
+
46
+
47
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
48
+ def betas_for_alpha_bar(
49
+ num_diffusion_timesteps,
50
+ max_beta=0.999,
51
+ alpha_transform_type="cosine",
52
+ ):
53
+ """
54
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
55
+ (1-beta) over time from t = [0,1].
56
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
57
+ to that part of the diffusion process.
58
+ Args:
59
+ num_diffusion_timesteps (`int`): the number of betas to produce.
60
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
61
+ prevent singularities.
62
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
63
+ Choose from `cosine` or `exp`
64
+ Returns:
65
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
66
+ """
67
+ if alpha_transform_type == "cosine":
68
+
69
+ def alpha_bar_fn(t):
70
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
71
+
72
+ elif alpha_transform_type == "exp":
73
+
74
+ def alpha_bar_fn(t):
75
+ return math.exp(t * -12.0)
76
+
77
+ else:
78
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
79
+
80
+ betas = []
81
+ for i in range(num_diffusion_timesteps):
82
+ t1 = i / num_diffusion_timesteps
83
+ t2 = (i + 1) / num_diffusion_timesteps
84
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
85
+ return torch.tensor(betas, dtype=torch.float32)
86
+
87
+
88
+ def rescale_zero_terminal_snr(betas):
89
+ """
90
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
91
+ Args:
92
+ betas (`torch.FloatTensor`):
93
+ the betas that the scheduler is being initialized with.
94
+ Returns:
95
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
96
+ """
97
+ # Convert betas to alphas_bar_sqrt
98
+ alphas = 1.0 - betas
99
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
100
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
101
+
102
+ # Store old values.
103
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
104
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
105
+
106
+ # Shift so the last timestep is zero.
107
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
108
+
109
+ # Scale so the first timestep is back to the old value.
110
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
111
+
112
+ # Convert alphas_bar_sqrt to betas
113
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
114
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
115
+ alphas = torch.cat([alphas_bar[0:1], alphas])
116
+ betas = 1 - alphas
117
+
118
+ return betas
119
+
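+ # Sketch of the intended effect (assuming a simple linear schedule for illustration):
+ # after rescaling, the cumulative alpha at the final timestep is zero, i.e. the last
+ # step carries pure noise (zero terminal SNR).
+ #
+ #   betas = torch.linspace(0.0001, 0.02, 1000)
+ #   alphas_cumprod = torch.cumprod(1.0 - rescale_zero_terminal_snr(betas), dim=0)
+ #   assert abs(float(alphas_cumprod[-1])) < 1e-6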
120
+
121
+ class LCMScheduler(SchedulerMixin, ConfigMixin):
122
+ """
123
+ `LCMScheduler` implements the one-step and multi-step sampling procedure of Latent Consistency Models, extending
124
+ the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance.
125
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
126
+ methods the library implements for all schedulers such as loading and saving.
127
+ Args:
128
+ num_train_timesteps (`int`, defaults to 1000):
129
+ The number of diffusion steps to train the model.
130
+ beta_start (`float`, defaults to 0.0001):
131
+ The starting `beta` value of inference.
132
+ beta_end (`float`, defaults to 0.02):
133
+ The final `beta` value.
134
+ beta_schedule (`str`, defaults to `"linear"`):
135
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
136
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
137
+ trained_betas (`np.ndarray`, *optional*):
138
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
139
+ clip_sample (`bool`, defaults to `True`):
140
+ Clip the predicted sample for numerical stability.
141
+ clip_sample_range (`float`, defaults to 1.0):
142
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
143
+ set_alpha_to_one (`bool`, defaults to `True`):
144
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
145
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
146
+ otherwise it uses the alpha value at step 0.
147
+ steps_offset (`int`, defaults to 0):
148
+ An offset added to the inference steps. You can use a combination of `offset=1` and
149
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
150
+ Diffusion.
151
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
152
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
153
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
154
+ Video](https://imagen.research.google/video/paper.pdf) paper).
155
+ thresholding (`bool`, defaults to `False`):
156
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
157
+ as Stable Diffusion.
158
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
159
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
160
+ sample_max_value (`float`, defaults to 1.0):
161
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
162
+ timestep_spacing (`str`, defaults to `"leading"`):
163
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
164
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
165
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
166
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
167
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
168
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
169
+ """
170
+
171
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
172
+ order = 1
173
+
174
+ @register_to_config
175
+ def __init__(
176
+ self,
177
+ num_train_timesteps: int = 1000,
178
+ beta_start: float = 0.0001,
179
+ beta_end: float = 0.02,
180
+ beta_schedule: str = "linear",
181
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
182
+ clip_sample: bool = True,
183
+ set_alpha_to_one: bool = True,
184
+ steps_offset: int = 0,
185
+ prediction_type: str = "epsilon",
186
+ thresholding: bool = False,
187
+ dynamic_thresholding_ratio: float = 0.995,
188
+ clip_sample_range: float = 1.0,
189
+ sample_max_value: float = 1.0,
190
+ timestep_spacing: str = "leading",
191
+ rescale_betas_zero_snr: bool = False,
192
+ ):
193
+ if trained_betas is not None:
194
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
195
+ elif beta_schedule == "linear":
196
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
197
+ elif beta_schedule == "scaled_linear":
198
+ # this schedule is very specific to the latent diffusion model.
199
+ self.betas = (
200
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
201
+ )
202
+ elif beta_schedule == "squaredcos_cap_v2":
203
+ # Glide cosine schedule
204
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
205
+ else:
206
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
207
+
208
+ # Rescale for zero SNR
209
+ if rescale_betas_zero_snr:
210
+ self.betas = rescale_zero_terminal_snr(self.betas)
211
+
212
+ self.alphas = 1.0 - self.betas
213
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
214
+
215
+ # At every step in ddim, we are looking into the previous alphas_cumprod
216
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
217
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
218
+ # whether we use the final alpha of the "non-previous" one.
219
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
220
+
221
+ # standard deviation of the initial noise distribution
222
+ self.init_noise_sigma = 1.0
223
+
224
+ # setable values
225
+ self.num_inference_steps = None
226
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
227
+
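+ # A construction sketch (hypothetical values; the configs shipped with this repo may differ):
+ #
+ #   scheduler = LCMScheduler(
+ #       num_train_timesteps=1000,
+ #       beta_start=0.00085,
+ #       beta_end=0.012,
+ #       beta_schedule="scaled_linear",
+ #       prediction_type="epsilon",
+ #   )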
228
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
229
+ """
230
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
231
+ current timestep.
232
+ Args:
233
+ sample (`torch.FloatTensor`):
234
+ The input sample.
235
+ timestep (`int`, *optional*):
236
+ The current timestep in the diffusion chain.
237
+ Returns:
238
+ `torch.FloatTensor`:
239
+ A scaled input sample.
240
+ """
241
+ return sample
242
+
243
+ def _get_variance(self, timestep, prev_timestep):
244
+ alpha_prod_t = self.alphas_cumprod[timestep]
245
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
246
+ beta_prod_t = 1 - alpha_prod_t
247
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
248
+
249
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
250
+
251
+ return variance
252
+
253
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
254
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
255
+ """
256
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
257
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
258
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
259
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
260
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
261
+ https://arxiv.org/abs/2205.11487
262
+ """
263
+ dtype = sample.dtype
264
+ batch_size, channels, height, width = sample.shape
265
+
266
+ if dtype not in (torch.float32, torch.float64):
267
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
268
+
269
+ # Flatten sample for doing quantile calculation along each image
270
+ sample = sample.reshape(batch_size, channels * height * width)
271
+
272
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
273
+
274
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
275
+ s = torch.clamp(
276
+ s, min=1, max=self.config.sample_max_value
277
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
278
+
279
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
280
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
281
+
282
+ sample = sample.reshape(batch_size, channels, height, width)
283
+ sample = sample.to(dtype)
284
+
285
+ return sample
286
+
287
+ def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):
288
+ """
289
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
290
+ Args:
291
+ num_inference_steps (`int`):
292
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ lcm_origin_steps (`int`):
+ The number of timesteps in the LCM training (distillation) schedule, from which the inference
+ schedule is subsampled.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps are moved.
293
+ """
294
+
295
+ if num_inference_steps > self.config.num_train_timesteps:
296
+ raise ValueError(
297
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
298
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
299
+ f" maximal {self.config.num_train_timesteps} timesteps."
300
+ )
301
+
302
+ self.num_inference_steps = num_inference_steps
303
+
304
+ # LCM Timesteps Setting: # Linear Spacing
305
+ c = self.config.num_train_timesteps // lcm_origin_steps
306
+ lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 # LCM Training Steps Schedule
307
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
308
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
309
+
310
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
311
+
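+ # Example of the resulting schedule (with the defaults above): num_train_timesteps=1000 and
+ # lcm_origin_steps=50 give the training schedule [19, 39, ..., 999]; num_inference_steps=4
+ # subsamples it from the end as [999, 759, 519, 279].
+ #
+ #   scheduler.set_timesteps(num_inference_steps=4, lcm_origin_steps=50)
+ #   print(scheduler.timesteps)  # tensor([999, 759, 519, 279])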
312
+ def get_scalings_for_boundary_condition_discrete(self, t):
313
+ self.sigma_data = 0.5 # Default: 0.5
314
+
315
+ # Dividing t by 0.1 makes c_skip act almost like a delta function at t=0.
316
+ c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
317
+ c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
318
+ return c_skip, c_out
319
+
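+ # Boundary-condition sketch: as t -> 0, c_skip -> 1 and c_out -> 0, so the output reduces to
+ # the input sample; for large t, c_skip ~ 0 and c_out ~ 1, so the predicted x_0 dominates.
+ #
+ #   c_skip, c_out = scheduler.get_scalings_for_boundary_condition_discrete(0)
+ #   # c_skip == 1.0, c_out == 0.0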
320
+ def step(
321
+ self,
322
+ model_output: torch.FloatTensor,
323
+ timeindex: int,
324
+ timestep: int,
325
+ sample: torch.FloatTensor,
326
+ eta: float = 0.0,
327
+ use_clipped_model_output: bool = False,
328
+ generator=None,
329
+ variance_noise: Optional[torch.FloatTensor] = None,
330
+ return_dict: bool = True,
331
+ ) -> Union[LCMSchedulerOutput, Tuple]:
332
+ """
333
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
334
+ process from the learned model outputs (most often the predicted noise).
335
+ Args:
336
+ model_output (`torch.FloatTensor`):
337
+ The direct output from learned diffusion model.
338
+ timeindex (`int`):
+ The index of the current timestep in the scheduler's `timesteps` schedule.
340
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
340
+ sample (`torch.FloatTensor`):
341
+ A current instance of a sample created by the diffusion process.
342
+ eta (`float`):
343
+ The weight of noise for added noise in diffusion step.
344
+ use_clipped_model_output (`bool`, defaults to `False`):
345
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
346
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
347
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
348
+ `use_clipped_model_output` has no effect.
349
+ generator (`torch.Generator`, *optional*):
350
+ A random number generator.
351
+ variance_noise (`torch.FloatTensor`):
352
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
353
+ itself. Useful for methods such as [`CycleDiffusion`].
354
+ return_dict (`bool`, *optional*, defaults to `True`):
355
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
356
+ Returns:
357
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
358
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
359
+ tuple is returned where the first element is the sample tensor.
360
+ """
361
+ if self.num_inference_steps is None:
362
+ raise ValueError(
363
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
364
+ )
365
+
366
+ # 1. get previous step value
367
+ prev_timeindex = timeindex + 1
368
+ if prev_timeindex < len(self.timesteps):
369
+ prev_timestep = self.timesteps[prev_timeindex]
370
+ else:
371
+ prev_timestep = timestep
372
+
373
+ # 2. compute alphas, betas
374
+ alpha_prod_t = self.alphas_cumprod[timestep]
375
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
376
+
377
+ beta_prod_t = 1 - alpha_prod_t
378
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
379
+
380
+ # 3. Get scalings for boundary conditions
381
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
382
+
383
+ # 4. Different Parameterization:
384
+ parameterization = self.config.prediction_type
385
+
386
+ if parameterization == "epsilon": # noise-prediction
387
+ pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
388
+
389
+ elif parameterization == "sample": # x-prediction
390
+ pred_x0 = model_output
391
+
392
+ elif parameterization == "v_prediction": # v-prediction
393
+ pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
394
+
395
+ # 5. Denoise model output using boundary conditions
396
+ denoised = c_out * pred_x0 + c_skip * sample
397
+
398
+ # 6. Sample z ~ N(0, I) for multi-step inference
399
+ # Noise is not used for one-step sampling.
400
+ if len(self.timesteps) > 1:
401
+ noise = torch.randn(model_output.shape).to(model_output.device)
402
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
403
+ else:
404
+ prev_sample = denoised
405
+
406
+ if not return_dict:
407
+ return (prev_sample, denoised)
408
+
409
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
410
+
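+ # A minimal multi-step sampling loop sketch (`denoiser` is a placeholder for the actual
+ # model used in this repo; its call signature is assumed, not prescribed):
+ #
+ #   scheduler.set_timesteps(num_inference_steps=4, lcm_origin_steps=50, device="cuda")
+ #   latents = torch.randn(1, 4, 64, 64, device="cuda")
+ #   for i, t in enumerate(scheduler.timesteps):
+ #       noise_pred = denoiser(latents, t)  # hypothetical denoiser call
+ #       latents, denoised = scheduler.step(noise_pred, i, t, latents, return_dict=False)
+ #   # `denoised` is the final prediction; decode it with the VAE in practice.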
411
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
412
+ def add_noise(
413
+ self,
414
+ original_samples: torch.FloatTensor,
415
+ noise: torch.FloatTensor,
416
+ timesteps: torch.IntTensor,
417
+ ) -> torch.FloatTensor:
418
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
419
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
420
+ timesteps = timesteps.to(original_samples.device)
421
+
422
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
423
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
424
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
425
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
426
+
427
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
428
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
429
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
430
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
431
+
432
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
433
+ return noisy_samples
434
+
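+ # Closed form implemented above: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
+ # e.g. (sketch): noisy = scheduler.add_noise(x0, torch.randn_like(x0), torch.tensor([999]))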
435
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
436
+ def get_velocity(
437
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
438
+ ) -> torch.FloatTensor:
439
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
440
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
441
+ timesteps = timesteps.to(sample.device)
442
+
443
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
444
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
445
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
446
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
447
+
448
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
449
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
450
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
451
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
452
+
453
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
454
+ return velocity
455
+
456
+ def __len__(self):
457
+ return self.config.num_train_timesteps
diffusion/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .nets import *