update demo

Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +2 -0
- .gitignore +4 -0
- .idea/.gitignore +8 -0
- .idea/SEED.iml +12 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- .project-root +0 -0
- Dockerfile +3 -12
- License.txt +470 -0
- README-SEED-2.md +184 -0
- SEED-1.md +93 -0
- configs/llm/seed_llama_14b.yaml +5 -0
- configs/llm/seed_llama_14b_8bit.yaml +5 -0
- configs/llm/seed_llama_8b.yaml +5 -0
- configs/llm/seed_llama_8b_8bit.yaml +5 -0
- configs/tokenizer/seed_llama_tokenizer.yaml +4 -0
- configs/tokenizer/seed_llama_tokenizer_hf.yaml +6 -0
- configs/transform/clip_transform.yaml +4 -0
- gradio_demo/conversation.py +190 -0
- gradio_demo/seed_llama_flask.py +230 -0
- gradio_demo/seed_llama_gradio.py +497 -0
- gradio_demo/utils.py +82 -0
- images/cat.jpg +3 -0
- images/demo_example1.jpg +3 -0
- images/demo_example2.jpg +3 -0
- images/demo_example3.jpg +3 -0
- images/demo_example4.jpg +3 -0
- images/demo_example5.jpg +3 -0
- images/demo_example6.jpg +3 -0
- images/demo_example7.jpg +3 -0
- images/dogs_4.jpg +3 -0
- images/eagle.jpg +3 -0
- images/flower.png +3 -0
- images/spongebob.png +3 -0
- images/star.jpg +3 -0
- models/__init__.py +0 -0
- models/llama_xformer.py +906 -0
- models/model_tools.py +18 -0
- models/pipeline_stable_unclip_img2img.py +794 -0
- models/seed_llama_tokenizer.py +213 -0
- models/seed_qformer/blip2.py +186 -0
- models/seed_qformer/clip_vit.py +257 -0
- models/seed_qformer/eva_vit.py +486 -0
- models/seed_qformer/qformer_causual.py +1169 -0
- models/seed_qformer/qformer_quantizer.py +375 -0
- models/seed_qformer/utils.py +138 -0
- models/seed_qformer/vit.py +395 -0
- models/transforms.py +21 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
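(Editor's note: the two added rules match what `git lfs track "*.jpg"` and `git lfs track "*.png"` would append, so the JPG/PNG demo images added under `images/` in this commit are stored through Git LFS rather than in the regular object store.)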
.gitignore
ADDED
@@ -0,0 +1,4 @@
+pretrained/*
+!pretrained/.gitkeep
+**/__pycache__/**
+log/
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
.idea/SEED.iml
ADDED
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="format" value="GOOGLE" />
+    <option name="myDocStringFormat" value="Google" />
+  </component>
+</module>
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
+</project>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/SEED.iml" filepath="$PROJECT_DIR$/.idea/SEED.iml" />
+    </modules>
+  </component>
+</project>
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
.project-root
ADDED
File without changes (empty placeholder file)
Dockerfile
CHANGED
@@ -4,8 +4,7 @@ FROM python:3.11
 # Set the working directory to /code
 WORKDIR /code
 
-RUN apt-get update && apt-get install -y
-RUN git lfs install
+RUN apt-get update && apt-get install -y git git-lfs
 
 # Copy the current directory contents into the container at /code
 # COPY ./requirements.txt /code/requirements.txt
@@ -29,16 +28,8 @@ WORKDIR $HOME/app
 # Copy the current directory contents into the container at $HOME/app setting the owner to the user
 COPY --chown=user . $HOME/app
 
-RUN git
-
-RUN mv SEED/* . && rm -rf SEED
+RUN git lfs install
 
 RUN pip install -r requirements.txt
 
-
-
-# RUN mv SEED/* pretrained/ && rm -rf SEED
-
-RUN chmod +x start.sh
-
-CMD ["./start.sh"]
+CMD ["python", 'start.py']
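(Editor's note: the new `CMD ["python", 'start.py']` mixes double and single quotes; Docker's exec form must be valid JSON, so with single quotes the instruction will most likely be treated as shell form and warned about at build time. Writing it as `CMD ["python", "start.py"]` would keep the intended exec form, assuming `start.py` is indeed the new entrypoint that replaces the removed `start.sh`.)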
License.txt
ADDED
@@ -0,0 +1,470 @@
1 |
+
This license applies to the source code that is open-sourced in connection with the AI Lab research paper open-source release of SEED.
|
2 |
+
|
3 |
+
Copyright (C) 2023 THL A29 Limited, a Tencent company.
|
4 |
+
|
5 |
+
Apache License
|
6 |
+
Version 2.0, January 2004
|
7 |
+
http://www.apache.org/licenses/
|
8 |
+
|
9 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
10 |
+
|
11 |
+
1. Definitions.
|
12 |
+
|
13 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
14 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
15 |
+
|
16 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
17 |
+
the copyright owner that is granting the License.
|
18 |
+
|
19 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
20 |
+
other entities that control, are controlled by, or are under common
|
21 |
+
control with that entity. For the purposes of this definition,
|
22 |
+
"control" means (i) the power, direct or indirect, to cause the
|
23 |
+
direction or management of such entity, whether by contract or
|
24 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
25 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
26 |
+
|
27 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
28 |
+
exercising permissions granted by this License.
|
29 |
+
|
30 |
+
"Source" form shall mean the preferred form for making modifications,
|
31 |
+
including but not limited to software source code, documentation
|
32 |
+
source, and configuration files.
|
33 |
+
|
34 |
+
"Object" form shall mean any form resulting from mechanical
|
35 |
+
transformation or translation of a Source form, including but
|
36 |
+
not limited to compiled object code, generated documentation,
|
37 |
+
and conversions to other media types.
|
38 |
+
|
39 |
+
"Work" shall mean the work of authorship, whether in Source or
|
40 |
+
Object form, made available under the License, as indicated by a
|
41 |
+
copyright notice that is included in or attached to the work
|
42 |
+
(an example is provided in the Appendix below).
|
43 |
+
|
44 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
45 |
+
form, that is based on (or derived from) the Work and for which the
|
46 |
+
editorial revisions, annotations, elaborations, or other modifications
|
47 |
+
represent, as a whole, an original work of authorship. For the purposes
|
48 |
+
of this License, Derivative Works shall not include works that remain
|
49 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
50 |
+
the Work and Derivative Works thereof.
|
51 |
+
|
52 |
+
"Contribution" shall mean any work of authorship, including
|
53 |
+
the original version of the Work and any modifications or additions
|
54 |
+
to that Work or Derivative Works thereof, that is intentionally
|
55 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
56 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
57 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
58 |
+
means any form of electronic, verbal, or written communication sent
|
59 |
+
to the Licensor or its representatives, including but not limited to
|
60 |
+
communication on electronic mailing lists, source code control systems,
|
61 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
62 |
+
Licensor for the purpose of discussing and improving the Work, but
|
63 |
+
excluding communication that is conspicuously marked or otherwise
|
64 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
65 |
+
|
66 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
67 |
+
on behalf of whom a Contribution has been received by Licensor and
|
68 |
+
subsequently incorporated within the Work.
|
69 |
+
|
70 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
71 |
+
this License, each Contributor hereby grants to You a perpetual,
|
72 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
73 |
+
copyright license to reproduce, prepare Derivative Works of,
|
74 |
+
publicly display, publicly perform, sublicense, and distribute the
|
75 |
+
Work and such Derivative Works in Source or Object form.
|
76 |
+
|
77 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
78 |
+
this License, each Contributor hereby grants to You a perpetual,
|
79 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
80 |
+
(except as stated in this section) patent license to make, have made,
|
81 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
82 |
+
where such license applies only to those patent claims licensable
|
83 |
+
by such Contributor that are necessarily infringed by their
|
84 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
85 |
+
with the Work to which such Contribution(s) was submitted. If You
|
86 |
+
institute patent litigation against any entity (including a
|
87 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
88 |
+
or a Contribution incorporated within the Work constitutes direct
|
89 |
+
or contributory patent infringement, then any patent licenses
|
90 |
+
granted to You under this License for that Work shall terminate
|
91 |
+
as of the date such litigation is filed.
|
92 |
+
|
93 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
94 |
+
Work or Derivative Works thereof in any medium, with or without
|
95 |
+
modifications, and in Source or Object form, provided that You
|
96 |
+
meet the following conditions:
|
97 |
+
|
98 |
+
(a) You must give any other recipients of the Work or
|
99 |
+
Derivative Works a copy of this License; and
|
100 |
+
|
101 |
+
(b) You must cause any modified files to carry prominent notices
|
102 |
+
stating that You changed the files; and
|
103 |
+
|
104 |
+
(c) You must retain, in the Source form of any Derivative Works
|
105 |
+
that You distribute, all copyright, patent, trademark, and
|
106 |
+
attribution notices from the Source form of the Work,
|
107 |
+
excluding those notices that do not pertain to any part of
|
108 |
+
the Derivative Works; and
|
109 |
+
|
110 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
111 |
+
distribution, then any Derivative Works that You distribute must
|
112 |
+
include a readable copy of the attribution notices contained
|
113 |
+
within such NOTICE file, excluding those notices that do not
|
114 |
+
pertain to any part of the Derivative Works, in at least one
|
115 |
+
of the following places: within a NOTICE text file distributed
|
116 |
+
as part of the Derivative Works; within the Source form or
|
117 |
+
documentation, if provided along with the Derivative Works; or,
|
118 |
+
within a display generated by the Derivative Works, if and
|
119 |
+
wherever such third-party notices normally appear. The contents
|
120 |
+
of the NOTICE file are for informational purposes only and
|
121 |
+
do not modify the License. You may add Your own attribution
|
122 |
+
notices within Derivative Works that You distribute, alongside
|
123 |
+
or as an addendum to the NOTICE text from the Work, provided
|
124 |
+
that such additional attribution notices cannot be construed
|
125 |
+
as modifying the License.
|
126 |
+
|
127 |
+
You may add Your own copyright statement to Your modifications and
|
128 |
+
may provide additional or different license terms and conditions
|
129 |
+
for use, reproduction, or distribution of Your modifications, or
|
130 |
+
for any such Derivative Works as a whole, provided Your use,
|
131 |
+
reproduction, and distribution of the Work otherwise complies with
|
132 |
+
the conditions stated in this License.
|
133 |
+
|
134 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
135 |
+
any Contribution intentionally submitted for inclusion in the Work
|
136 |
+
by You to the Licensor shall be under the terms and conditions of
|
137 |
+
this License, without any additional terms or conditions.
|
138 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
139 |
+
the terms of any separate license agreement you may have executed
|
140 |
+
with Licensor regarding such Contributions.
|
141 |
+
|
142 |
+
6. Trademarks. This License does not grant permission to use the trade
|
143 |
+
names, trademarks, service marks, or product names of the Licensor,
|
144 |
+
except as required for reasonable and customary use in describing the
|
145 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
146 |
+
|
147 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
148 |
+
agreed to in writing, Licensor provides the Work (and each
|
149 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
150 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
151 |
+
implied, including, without limitation, any warranties or conditions
|
152 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
153 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
154 |
+
appropriateness of using or redistributing the Work and assume any
|
155 |
+
risks associated with Your exercise of permissions under this License.
|
156 |
+
|
157 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
158 |
+
whether in tort (including negligence), contract, or otherwise,
|
159 |
+
unless required by applicable law (such as deliberate and grossly
|
160 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
161 |
+
liable to You for damages, including any direct, indirect, special,
|
162 |
+
incidental, or consequential damages of any character arising as a
|
163 |
+
result of this License or out of the use or inability to use the
|
164 |
+
Work (including but not limited to damages for loss of goodwill,
|
165 |
+
work stoppage, computer failure or malfunction, or any and all
|
166 |
+
other commercial damages or losses), even if such Contributor
|
167 |
+
has been advised of the possibility of such damages.
|
168 |
+
|
169 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
170 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
171 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
172 |
+
or other liability obligations and/or rights consistent with this
|
173 |
+
License. However, in accepting such obligations, You may act only
|
174 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
175 |
+
of any other Contributor, and only if You agree to indemnify,
|
176 |
+
defend, and hold each Contributor harmless for any liability
|
177 |
+
incurred by, or claims asserted against, such Contributor by reason
|
178 |
+
of your accepting any such warranty or additional liability.
|
179 |
+
|
180 |
+
10. This code is provided for research purposes only and is
|
181 |
+
not to be used for any commercial purposes. By using this code,
|
182 |
+
you agree that it will be used solely for academic research, scholarly work,
|
183 |
+
and non-commercial activities. Any use of this code for commercial purposes,
|
184 |
+
including but not limited to, selling, distributing, or incorporating it into
|
185 |
+
commercial products or services, is strictly prohibited. Violation of this
|
186 |
+
clause may result in legal actions and penalties.
|
187 |
+
|
188 |
+
END OF TERMS AND CONDITIONS
|
189 |
+
|
190 |
+
APPENDIX: How to apply the Apache License to your work.
|
191 |
+
|
192 |
+
To apply the Apache License to your work, attach the following
|
193 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
194 |
+
replaced with your own identifying information. (Don't include
|
195 |
+
the brackets!) The text should be enclosed in the appropriate
|
196 |
+
comment syntax for the file format. We also recommend that a
|
197 |
+
file or class name and description of purpose be included on the
|
198 |
+
same "printed page" as the copyright notice for easier
|
199 |
+
identification within third-party archives.
|
200 |
+
|
201 |
+
Copyright [yyyy] [name of copyright owner]
|
202 |
+
|
203 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
204 |
+
you may not use this file except in compliance with the License.
|
205 |
+
You may obtain a copy of the License at
|
206 |
+
|
207 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
208 |
+
|
209 |
+
Unless required by applicable law or agreed to in writing, software
|
210 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
211 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
212 |
+
See the License for the specific language governing permissions and
|
213 |
+
limitations under the License.
|
214 |
+
|
215 |
+
|
216 |
+
Other dependencies and licenses (if such optional components are used):
|
217 |
+
|
218 |
+
|
219 |
+
Components under BSD 3-Clause License:
|
220 |
+
------------------------------------------------
|
221 |
+
1. numpy
|
222 |
+
Copyright (c) 2005-2022, NumPy Developers.
|
223 |
+
All rights reserved.
|
224 |
+
|
225 |
+
2. pytorch
|
226 |
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
227 |
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
228 |
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
229 |
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
230 |
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
231 |
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
232 |
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
233 |
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
234 |
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
235 |
+
|
236 |
+
3. torchvision
|
237 |
+
Copyright (c) Soumith Chintala 2016,
|
238 |
+
All rights reserved.
|
239 |
+
|
240 |
+
Redistribution and use in source and binary forms, with or without
|
241 |
+
modification, are permitted provided that the following conditions are met:
|
242 |
+
|
243 |
+
* Redistributions of source code must retain the above copyright notice, this
|
244 |
+
list of conditions and the following disclaimer.
|
245 |
+
|
246 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
247 |
+
this list of conditions and the following disclaimer in the documentation
|
248 |
+
and/or other materials provided with the distribution.
|
249 |
+
|
250 |
+
* Neither the name of the copyright holder nor the names of its
|
251 |
+
contributors may be used to endorse or promote products derived from
|
252 |
+
this software without specific prior written permission.
|
253 |
+
|
254 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
255 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
256 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
257 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
258 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
259 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
260 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
261 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
262 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
263 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
264 |
+
|
265 |
+
Component under Apache v2 License:
|
266 |
+
-----------------------------------------------------
|
267 |
+
1. timm
|
268 |
+
Copyright 2019 Ross Wightman
|
269 |
+
|
270 |
+
Apache License
|
271 |
+
Version 2.0, January 2004
|
272 |
+
http://www.apache.org/licenses/
|
273 |
+
|
274 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
275 |
+
|
276 |
+
1. Definitions.
|
277 |
+
|
278 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
279 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
280 |
+
|
281 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
282 |
+
the copyright owner that is granting the License.
|
283 |
+
|
284 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
285 |
+
other entities that control, are controlled by, or are under common
|
286 |
+
control with that entity. For the purposes of this definition,
|
287 |
+
"control" means (i) the power, direct or indirect, to cause the
|
288 |
+
direction or management of such entity, whether by contract or
|
289 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
290 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
291 |
+
|
292 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
293 |
+
exercising permissions granted by this License.
|
294 |
+
|
295 |
+
"Source" form shall mean the preferred form for making modifications,
|
296 |
+
including but not limited to software source code, documentation
|
297 |
+
source, and configuration files.
|
298 |
+
|
299 |
+
"Object" form shall mean any form resulting from mechanical
|
300 |
+
transformation or translation of a Source form, including but
|
301 |
+
not limited to compiled object code, generated documentation,
|
302 |
+
and conversions to other media types.
|
303 |
+
|
304 |
+
"Work" shall mean the work of authorship, whether in Source or
|
305 |
+
Object form, made available under the License, as indicated by a
|
306 |
+
copyright notice that is included in or attached to the work
|
307 |
+
(an example is provided in the Appendix below).
|
308 |
+
|
309 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
310 |
+
form, that is based on (or derived from) the Work and for which the
|
311 |
+
editorial revisions, annotations, elaborations, or other modifications
|
312 |
+
represent, as a whole, an original work of authorship. For the purposes
|
313 |
+
of this License, Derivative Works shall not include works that remain
|
314 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
315 |
+
the Work and Derivative Works thereof.
|
316 |
+
|
317 |
+
"Contribution" shall mean any work of authorship, including
|
318 |
+
the original version of the Work and any modifications or additions
|
319 |
+
to that Work or Derivative Works thereof, that is intentionally
|
320 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
321 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
322 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
323 |
+
means any form of electronic, verbal, or written communication sent
|
324 |
+
to the Licensor or its representatives, including but not limited to
|
325 |
+
communication on electronic mailing lists, source code control systems,
|
326 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
327 |
+
Licensor for the purpose of discussing and improving the Work, but
|
328 |
+
excluding communication that is conspicuously marked or otherwise
|
329 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
330 |
+
|
331 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
332 |
+
on behalf of whom a Contribution has been received by Licensor and
|
333 |
+
subsequently incorporated within the Work.
|
334 |
+
|
335 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
336 |
+
this License, each Contributor hereby grants to You a perpetual,
|
337 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
338 |
+
copyright license to reproduce, prepare Derivative Works of,
|
339 |
+
publicly display, publicly perform, sublicense, and distribute the
|
340 |
+
Work and such Derivative Works in Source or Object form.
|
341 |
+
|
342 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
343 |
+
this License, each Contributor hereby grants to You a perpetual,
|
344 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
345 |
+
(except as stated in this section) patent license to make, have made,
|
346 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
347 |
+
where such license applies only to those patent claims licensable
|
348 |
+
by such Contributor that are necessarily infringed by their
|
349 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
350 |
+
with the Work to which such Contribution(s) was submitted. If You
|
351 |
+
institute patent litigation against any entity (including a
|
352 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
353 |
+
or a Contribution incorporated within the Work constitutes direct
|
354 |
+
or contributory patent infringement, then any patent licenses
|
355 |
+
granted to You under this License for that Work shall terminate
|
356 |
+
as of the date such litigation is filed.
|
357 |
+
|
358 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
359 |
+
Work or Derivative Works thereof in any medium, with or without
|
360 |
+
modifications, and in Source or Object form, provided that You
|
361 |
+
meet the following conditions:
|
362 |
+
|
363 |
+
(a) You must give any other recipients of the Work or
|
364 |
+
Derivative Works a copy of this License; and
|
365 |
+
|
366 |
+
(b) You must cause any modified files to carry prominent notices
|
367 |
+
stating that You changed the files; and
|
368 |
+
|
369 |
+
(c) You must retain, in the Source form of any Derivative Works
|
370 |
+
that You distribute, all copyright, patent, trademark, and
|
371 |
+
attribution notices from the Source form of the Work,
|
372 |
+
excluding those notices that do not pertain to any part of
|
373 |
+
the Derivative Works; and
|
374 |
+
|
375 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
376 |
+
distribution, then any Derivative Works that You distribute must
|
377 |
+
include a readable copy of the attribution notices contained
|
378 |
+
within such NOTICE file, excluding those notices that do not
|
379 |
+
pertain to any part of the Derivative Works, in at least one
|
380 |
+
of the following places: within a NOTICE text file distributed
|
381 |
+
as part of the Derivative Works; within the Source form or
|
382 |
+
documentation, if provided along with the Derivative Works; or,
|
383 |
+
within a display generated by the Derivative Works, if and
|
384 |
+
wherever such third-party notices normally appear. The contents
|
385 |
+
of the NOTICE file are for informational purposes only and
|
386 |
+
do not modify the License. You may add Your own attribution
|
387 |
+
notices within Derivative Works that You distribute, alongside
|
388 |
+
or as an addendum to the NOTICE text from the Work, provided
|
389 |
+
that such additional attribution notices cannot be construed
|
390 |
+
as modifying the License.
|
391 |
+
|
392 |
+
You may add Your own copyright statement to Your modifications and
|
393 |
+
may provide additional or different license terms and conditions
|
394 |
+
for use, reproduction, or distribution of Your modifications, or
|
395 |
+
for any such Derivative Works as a whole, provided Your use,
|
396 |
+
reproduction, and distribution of the Work otherwise complies with
|
397 |
+
the conditions stated in this License.
|
398 |
+
|
399 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
400 |
+
any Contribution intentionally submitted for inclusion in the Work
|
401 |
+
by You to the Licensor shall be under the terms and conditions of
|
402 |
+
this License, without any additional terms or conditions.
|
403 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
404 |
+
the terms of any separate license agreement you may have executed
|
405 |
+
with Licensor regarding such Contributions.
|
406 |
+
|
407 |
+
6. Trademarks. This License does not grant permission to use the trade
|
408 |
+
names, trademarks, service marks, or product names of the Licensor,
|
409 |
+
except as required for reasonable and customary use in describing the
|
410 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
411 |
+
|
412 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
413 |
+
agreed to in writing, Licensor provides the Work (and each
|
414 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
415 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
416 |
+
implied, including, without limitation, any warranties or conditions
|
417 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
418 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
419 |
+
appropriateness of using or redistributing the Work and assume any
|
420 |
+
risks associated with Your exercise of permissions under this License.
|
421 |
+
|
422 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
423 |
+
whether in tort (including negligence), contract, or otherwise,
|
424 |
+
unless required by applicable law (such as deliberate and grossly
|
425 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
426 |
+
liable to You for damages, including any direct, indirect, special,
|
427 |
+
incidental, or consequential damages of any character arising as a
|
428 |
+
result of this License or out of the use or inability to use the
|
429 |
+
Work (including but not limited to damages for loss of goodwill,
|
430 |
+
work stoppage, computer failure or malfunction, or any and all
|
431 |
+
other commercial damages or losses), even if such Contributor
|
432 |
+
has been advised of the possibility of such damages.
|
433 |
+
|
434 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
435 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
436 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
437 |
+
or other liability obligations and/or rights consistent with this
|
438 |
+
License. However, in accepting such obligations, You may act only
|
439 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
440 |
+
of any other Contributor, and only if You agree to indemnify,
|
441 |
+
defend, and hold each Contributor harmless for any liability
|
442 |
+
incurred by, or claims asserted against, such Contributor by reason
|
443 |
+
of your accepting any such warranty or additional liability.
|
444 |
+
|
445 |
+
END OF TERMS AND CONDITIONS
|
446 |
+
|
447 |
+
APPENDIX: How to apply the Apache License to your work.
|
448 |
+
|
449 |
+
To apply the Apache License to your work, attach the following
|
450 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
451 |
+
replaced with your own identifying information. (Don't include
|
452 |
+
the brackets!) The text should be enclosed in the appropriate
|
453 |
+
comment syntax for the file format. We also recommend that a
|
454 |
+
file or class name and description of purpose be included on the
|
455 |
+
same "printed page" as the copyright notice for easier
|
456 |
+
identification within third-party archives.
|
457 |
+
|
458 |
+
Copyright [yyyy] [name of copyright owner]
|
459 |
+
|
460 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
461 |
+
you may not use this file except in compliance with the License.
|
462 |
+
You may obtain a copy of the License at
|
463 |
+
|
464 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
465 |
+
|
466 |
+
Unless required by applicable law or agreed to in writing, software
|
467 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
468 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
469 |
+
See the License for the specific language governing permissions and
|
470 |
+
limitations under the License.
|
README-SEED-2.md
ADDED
@@ -0,0 +1,184 @@
1 |
+
# :chestnut: SEED Multimodal
|
2 |
+
|
3 |
+
[![Project Homepage](https://img.shields.io/badge/Project-Homepage-green)](https://ailab-cvc.github.io/seed/)
|
4 |
+
[![arXiv](https://img.shields.io/badge/arXiv-2307.08041-b31b1b.svg)](https://arxiv.org/abs/2307.08041)
|
5 |
+
[![arXiv](https://img.shields.io/badge/arXiv-2310.01218-b31b1b.svg)](https://arxiv.org/abs/2310.01218)
|
6 |
+
[![Static Badge](https://img.shields.io/badge/Model-Huggingface-yellow)](https://huggingface.co/AILab-CVC/SEED/tree/main)
|
7 |
+
[![Demo](https://img.shields.io/badge/Gradio-Demo-orange)](https://10a4e7976e6fc2032c.gradio.live/)
|
8 |
+
|
9 |
+
|
10 |
+
**Powered by [CV Center, Tencent AI Lab](https://ailab-cvc.github.io), and [ARC Lab, Tencent PCG](https://github.com/TencentARC).**
|
11 |
+
|
12 |
+
![image](https://github.com/AILab-CVC/SEED/blob/main/paper_images/milestone.jpg)
|
13 |
+
|
14 |
+
The repository provides the official implementation of [SEED](https://ailab-cvc.github.io/seed/seed.html), [SEED-LLaMA](https://ailab-cvc.github.io/seed/seed_llama.html). For any inquiries, please email [seed-x@googlegroups.com](mailto:seed-x@googlegroups.com).
|
15 |
+
|
16 |
+
|
17 |
+
## News
|
18 |
+
|
19 |
+
**:beers: We are actively looking for self-motivated interns. Please feel free to reach out if you are interested. :beers:**
|
20 |
+
|
21 |
+
- [x] **2023-10-23** :hugs: We have optimized the memory overhead. Through 8-bit quantization and dynamic loading, SEED-LLaMA 8B/14B can run on a single **16GB/24GB** GPU.
|
22 |
+
- [x] **2023-10-23** :hugs: All model weights will be **downloaded automatically** when starting the demo.
|
23 |
+
- [x] **2023-10-20** :hugs: We release the [checkpoints](https://huggingface.co/AILab-CVC/SEED/tree/main) and code of the SEED-2 tokenizer, and SEED-LLaMA-8B/14B.
|
24 |
+
- [x] **2023-10-20** :space_invader: We release an online [gradio demo](https://10a4e7976e6fc2032c.gradio.live/); feel free to try it yourself.
|
25 |
+
- [x] **2023-10-02** :paperclip: We release the technical report of SEED-LLaMA on [arXiv](https://arxiv.org/abs/2310.01218), which is empowered by the improved SEED-2 tokenizer.
|
26 |
+
- [x] **2023-07-29** :octocat: We release the checkpoint of the SEED tokenizer and its inference code. Check it out via [SEED-1](./SEED-1.md).
|
27 |
+
- [x] **2023-07-16** :paperclip: We release the technical report of SEED on [arXiv](https://arxiv.org/abs/2307.08041).
|
28 |
+
|
29 |
+
Stay tuned for the updates!
|
30 |
+
|
31 |
+
## Brief Introduction
|
32 |
+
|
33 |
+
It is recommended to check out our [papers](#citation) for technical details.
|
34 |
+
|
35 |
+
### :speech_balloon: What can SEED-LLaMA do?
|
36 |
+
|
37 |
+
![image](https://github.com/AILab-CVC/SEED/blob/main/paper_images/v2/teaser.jpg)
|
38 |
+
|
39 |
+
**SEED-LLaMA** is capable of both multimodal comprehension and generation, exhibiting compositional emergent abilities such as multi-turn in-context multimodal generation, acting like your AI assistant. [[Compare to SOTA]](https://ailab-cvc.github.io/seed/seed_llama_compare.html) [[More examples on X]](https://twitter.com/ge_yixiao/status/1710509538238157069?s=20)
|
40 |
+
|
41 |
+
<!-- We present **SEED-LLaMA** by large-scale pretraining and instruction tuning on the interleaved textual and visual data, which demonstrates impressive performance on a broad range of multimodal comprehension and generation tasks. More importantly, SEED-LLaMA has exhibited **compositional emergent abilities** such as multi-turn in-context multimodal generation, acting like your **AI assistant**. -->
|
42 |
+
|
43 |
+
### :bulb: How does SEED-LLaMA achieve it?
|
44 |
+
|
45 |
+
![image](https://github.com/AILab-CVC/SEED/blob/main/paper_images/seed_overview.jpg)
|
46 |
+
|
47 |
+
The core of SEED-LLaMA is the tailored **SEED** tokenizer, which quantizes visual signals into discrete visual tokens that capture the necessary semantics while being produced under a 1D causal dependency. [[SEED-2 vs. SEED-1]](https://ailab-cvc.github.io/seed/seed_llama.html)
|
48 |
+
|
49 |
+
<!-- ### Compositional Emergent Ability
|
50 |
+
**Multi-turn in-context image and text generation.**
|
51 |
+
![image](paper_images/v2/multi_turn1.jpg)
|
52 |
+
![image](paper_images/v2/multi_turn2.jpg)
|
53 |
+
|
54 |
+
**Compositional image generation.**
|
55 |
+
![image](paper_images/v2/results.jpg) -->
|
56 |
+
|
57 |
+
<!-- ### SEED Tokenizer v2
|
58 |
+
In SEED tokenizer v2, the generation embedding is aligned with the **image embedding** (1 token) of [unCLIP SD](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip), and can be decoded to realistic images with the unCLIP-SD-UNet. In SEED tokenizer v1, we train a visual tokenizer through aligning the **generation embeddings** with the text embeddings (77 tokens) of [SD](https://github.com/CompVis/stable-diffusion), and the generation embeddings can be decoded to images with the SD-UNet. The below figure shows the visual comparison of the reconstructed images between SEED tokenizer v2 (the third row) and SEED tokenizer v1 (the second row). We can observe that the images reconstructed by SEED tokenizer v2 can better preserve the visual information of the original images. The semantic representations of texts can not fully preserve the rich visual information of images.
|
59 |
+
![image](paper_images/v2/seed_comparison.jpg) -->
|
60 |
+
|
61 |
+
<!-- ### Pretraining
|
62 |
+
We perform multimodal autoregressive pretraining on interleaved visual and textual data for SEED-LLaMA. Visual inputs are pre-processed into discrete tokens to conserve computational resources. Given the multimodal discrete sequence, a unified next-word-prediction objective is employed. During inference, visual codes are decoded into a realistic image by SEED De-Tokenization.
|
63 |
+
![image](paper_images/v2/method_page.jpg) -->
|
64 |
+
|
65 |
+
## Usage
|
66 |
+
|
67 |
+
### Dependencies
|
68 |
+
- Python >= 3.8 ([Anaconda](https://www.anaconda.com/download/#linux) is recommended)
|
69 |
+
- [PyTorch >= 1.11.0](https://pytorch.org/)
|
70 |
+
- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
|
71 |
+
|
72 |
+
### Installation
|
73 |
+
Clone the repo and install dependent packages
|
74 |
+
|
75 |
+
```bash
|
76 |
+
git clone https://github.com/AILab-CVC/SEED.git
|
77 |
+
cd SEED
|
78 |
+
pip install -r requirements.txt
|
79 |
+
```
|
80 |
+
|
81 |
+
|
82 |
+
### Model Weights
|
83 |
+
We release the pretrained SEED Tokenizer and De-Tokenizer, pretrained and instruction tuned SEED-LLaMA-8B and SEED-LLaMA-14B in [SEED Hugging Face](https://huggingface.co/AILab-CVC/SEED).
|
84 |
+
|
85 |
+
- Check the SEED tokenizer weights in [AILab-CVC/seed-tokenizer-2](https://huggingface.co/AILab-CVC/seed-tokenizer-2)
|
86 |
+
- Check the SEED LLaMA(8B) weights in [AILab-CVC/seed-llama-8b-sft](https://huggingface.co/AILab-CVC/seed-llama-8b-sft)
|
87 |
+
- Check the SEED LLaMA(14B) weights in [AILab-CVC/seed-llama-14b-sft](https://huggingface.co/AILab-CVC/seed-llama-14b-sft)
|
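For illustration, a minimal sketch (not from the repo) of pre-fetching these checkpoints with `huggingface_hub` instead of relying on the automatic download at demo start-up; the repo ids are the ones listed above:

```python
# Optional pre-download of the released checkpoints; by default they are
# cached under ~/.cache/huggingface and reused by the demo.
from huggingface_hub import snapshot_download

for repo_id in ("AILab-CVC/seed-tokenizer-2",
                "AILab-CVC/seed-llama-8b-sft",
                "AILab-CVC/seed-llama-14b-sft"):
    snapshot_download(repo_id=repo_id)
```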
88 |
+
|
89 |
+
<!-- Please download the checkpoints and save under the folder `./pretrained`.
|
90 |
+
|
91 |
+
```bash
|
92 |
+
cd pretrained # SEED/pretrained
|
93 |
+
git lfs install
|
94 |
+
git clone https://huggingface.co/AILab-CVC/SEED
|
95 |
+
mv SEED/* ./
|
96 |
+
``` -->
|
97 |
+
|
98 |
+
The model weights of the unCLIP SD-UNet, which is used to reconstruct images, will be downloaded automatically.
|
99 |
+
|
100 |
+
<!-- To reconstruct the image from the SEED visual codes using unCLIP SD-UNet, please download the pretrained [unCLIP SD](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip). -->
|
101 |
+
|
102 |
+
<!-- To reconstruct the image from the SEED visual codes using unCLIP SD-UNet, please download the pretrained [unCLIP SD](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip).
|
103 |
+
Rename the checkpoint directory to **"diffusion_model"** and create a soft link to the "pretrained/seed_tokenizer" directory.
|
104 |
+
|
105 |
+
```bash
|
106 |
+
# SEED/pretrained
|
107 |
+
git lfs install
|
108 |
+
git clone https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip
|
109 |
+
mv stable-diffusion-2-1-unclip seed_tokenizer/diffusion_model
|
110 |
+
``` -->
|
111 |
+
|
112 |
+
|
113 |
+
### Inference for visual tokenization and de-tokenization
|
114 |
+
To discretize an image to 1D visual codes with causal dependency, and reconstruct the image from the visual codes using the off-the-shelf unCLIP SD-UNet:
|
115 |
+
|
116 |
+
```bash
|
117 |
+
cd .. # SEED/
|
118 |
+
python scripts/seed_tokenizer_inference.py
|
119 |
+
```
|
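As a rough companion to that script, here is a minimal sketch of the tokenize-and-reconstruct round trip it performs. This is NOT the contents of `scripts/seed_tokenizer_inference.py`; the `from_pretrained` kwargs mirror `configs/tokenizer/seed_llama_tokenizer_hf.yaml`, and the method names `encode_image` and `decode` are assumptions, so check the real API in `models/seed_llama_tokenizer.py`.

```python
# Sketch only: discretize an image to 1D visual codes, then reconstruct it.
from PIL import Image
from models.seed_llama_tokenizer import SeedLlamaTokenizer

tokenizer = SeedLlamaTokenizer.from_pretrained(
    pretrained_model_name_or_path="AILab-CVC/seed-tokenizer-2",
    fp16=True,
    load_diffusion=True,  # needed to decode codes back to pixels via unCLIP SD-UNet
    encoder_url="https://huggingface.co/AILab-CVC/seed-tokenizer-2/resolve/main/seed_quantizer.pt",
    diffusion_path="stabilityai/stable-diffusion-2-1-unclip",
)

image = Image.open("images/cat.jpg").convert("RGB")
codes = tokenizer.encode_image(image)   # 1D discrete visual codes (assumed method name)
recon = tokenizer.decode(codes)         # reconstructed image (assumed method name)
```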
120 |
+
### Inference for SEED-LLaMA
|
121 |
+
Since SEED-LLaMA-8B is based on Vicuna-7B and SEED-LLaMA-14B is based on LLaMA2-Chat-13B, we use Vicuna-7B's ("USER:", "ASSISTANT:") prompt and LLaMA2-Chat-13B's ([INST] [/INST]) prompt for the respective instruction tuning.
|
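For illustration, a minimal sketch of the two prompt styles; the exact system prompts and separators used by the demo live in `gradio_demo/conversation.py`, so treat this as an approximation:

```python
# Approximate single-turn prompts for the two instruction-tuned models.
question = "What is in the image? <image>"

# SEED-LLaMA-8B (Vicuna-7B style)
vicuna_prompt = f"USER: {question} ASSISTANT:"

# SEED-LLaMA-14B (LLaMA2-Chat-13B style)
llama2_prompt = f"[INST] {question} [/INST]"
```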
122 |
+
|
123 |
+
```bash
|
124 |
+
# Inference for SEED-LLaMA-8B
|
125 |
+
python scripts/seed_llama_inference_8B.py
|
126 |
+
```
|
127 |
+
|
128 |
+
```bash
|
129 |
+
# Inference for SEED-LLaMA-14B
|
130 |
+
python scripts/seed_llama_inference_14B.py
|
131 |
+
```
|
132 |
+
|
133 |
+
|
134 |
+
### Launching Gradio Demo of SEED-LLaMA-14B Locally
|
135 |
+
1. Building the local demo of SEED-LLaMA-14B currently requires a **single 24GB** GPU.
|
136 |
+
|
137 |
+
```bash
|
138 |
+
# SEED/
|
139 |
+
# in first terminal
|
140 |
+
bash scripts/start_backend_14b.sh
|
141 |
+
# in second terminal
|
142 |
+
bash scripts/start_frontend_14b.sh
|
143 |
+
```
|
144 |
+
|
145 |
+
2. Building the local demo of SEED-LLaMA-8B currently requires a **single 16GB** GPU.
|
146 |
+
|
147 |
+
```bash
|
148 |
+
# SEED/
|
149 |
+
# in first terminal
|
150 |
+
bash scripts/start_backend_8b.sh
|
151 |
+
# in second terminal
|
152 |
+
bash scripts/start_frontend_8b.sh
|
153 |
+
```
|
154 |
+
|
155 |
+
The demo can then be accessed at http://127.0.0.1:80.
|
156 |
+
|
157 |
+
## Citation
|
158 |
+
If you find the work helpful, please consider citing:
|
159 |
+
```bibtex
|
160 |
+
@article{ge2023making,
|
161 |
+
title={Making LLaMA SEE and Draw with SEED Tokenizer},
|
162 |
+
author={Ge, Yuying and Zhao, Sijie and Zeng, Ziyun and Ge, Yixiao and Li, Chen and Wang, Xintao and Shan, Ying},
|
163 |
+
journal={arXiv preprint arXiv:2310.01218},
|
164 |
+
year={2023}
|
165 |
+
}
|
166 |
+
|
167 |
+
@article{ge2023planting,
|
168 |
+
title={Planting a seed of vision in large language model},
|
169 |
+
author={Ge, Yuying and Ge, Yixiao and Zeng, Ziyun and Wang, Xintao and Shan, Ying},
|
170 |
+
journal={arXiv preprint arXiv:2307.08041},
|
171 |
+
year={2023}
|
172 |
+
}
|
173 |
+
```
|
174 |
+
|
175 |
+
The project is still in progress.
|
176 |
+
|
177 |
+
## License
|
178 |
+
`SEED` is released under [Apache License Version 2.0](License.txt).
|
179 |
+
|
180 |
+
`SEED-LLaMA` is released under the original [License](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) of [LLaMA2](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf).
|
181 |
+
|
182 |
+
## Acknowledgement
|
183 |
+
We thank the great work from [unCLIP SD](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip) and [BLIP2](https://github.com/salesforce/LAVIS).
|
184 |
+
|
SEED-1.md
ADDED
@@ -0,0 +1,93 @@
1 |
+
# SEED Tokenizer v1
|
2 |
+
[[arXiv]](https://arxiv.org/abs/2307.08041)
|
3 |
+
|
4 |
+
![image](paper_images/teaser.jpg)
|
5 |
+
## Abstract
|
6 |
+
We present SEED, an elaborate image tokenizer that empowers Large Language
|
7 |
+
Models (LLMs) with the emergent ability to **SEE** and **D**raw at the same time.
|
8 |
+
Research on image tokenizers has previously reached an impasse, as frameworks
|
9 |
+
employing quantized visual tokens have lost prominence due to subpar performance and convergence in multimodal comprehension (compared to BLIP-2, etc.)
|
10 |
+
or generation (compared to Stable Diffusion, etc.). Despite the limitations, we
|
11 |
+
remain confident in its natural capacity to unify visual and textual representations,
|
12 |
+
facilitating scalable multimodal training with LLM’s original recipe. In this study,
|
13 |
+
we identify two crucial principles for the architecture and training of SEED that
|
14 |
+
effectively ease subsequent alignment with LLMs. (1) Image tokens should be
|
15 |
+
independent of 2D physical patch positions and instead be produced with a 1D
|
16 |
+
causal dependency, exhibiting intrinsic interdependence that aligns with the left-to-right autoregressive prediction mechanism in LLMs. (2) Image tokens should
|
17 |
+
capture high-level semantics consistent with the degree of semantic abstraction in
|
18 |
+
words, and be optimized for both discriminativeness and reconstruction during the
|
19 |
+
tokenizer training phase. As a result, the off-the-shelf LLM is able to perform both
|
20 |
+
image-to-text and text-to-image generation by incorporating our SEED through
|
21 |
+
efficient LoRA tuning. Comprehensive multimodal pretraining and instruction
|
22 |
+
tuning, which may yield improved results, are reserved for future investigation.
|
23 |
+
This version of SEED was trained in 5.7 days using only 64 V100 GPUs and 5M
|
24 |
+
publicly available image-text pairs. Our preliminary study emphasizes the great
|
25 |
+
potential of discrete visual tokens in versatile multimodal LLMs and the importance
|
26 |
+
of proper image tokenizers in broader research.
|
27 |
+
|
28 |
+
## SEED Tokenizer for Image Reconstruction
|
29 |
+
![image](paper_images/reconstruction.jpg)
|
30 |
+
|
31 |
+
## SEED-OPT<sub>2.7B </sub> for Multimodal Comprehension
|
32 |
+
![image](paper_images/vqa.jpg)
|
33 |
+
|
34 |
+
## SEED-OPT<sub>2.7B </sub> for Multimodal Generation
|
35 |
+
![image](paper_images/generation.jpg)
|
36 |
+
|
37 |
+
## Dependencies and Installation
|
38 |
+
- Python >= 3.8 ([Anaconda](https://www.anaconda.com/download/#linux) is recommended)
|
39 |
+
- [PyTorch >= 1.11.0](https://pytorch.org/)
|
40 |
+
- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
|
41 |
+
### Installation
|
42 |
+
1. Clone repo
|
43 |
+
|
44 |
+
```bash
|
45 |
+
git clone https://github.com/AILab-CVC/SEED.git
|
46 |
+
cd SEED
|
47 |
+
```
|
48 |
+
|
49 |
+
2. Install dependent packages
|
50 |
+
|
51 |
+
```bash
|
52 |
+
sh install.sh
|
53 |
+
```
|
54 |
+
|
55 |
+
## Model Weights
|
56 |
+
We release the pre-trained SEED Visual Tokenizer in [google drive](https://drive.google.com/drive/folders/1xmVXuttQfBPBOe4ZR96Wu1X34uzPkxsS?usp=drive_link).
|
57 |
+
|
58 |
+
## Inference
|
59 |
+
To discretize an image to 1D vision codes with causal dependency, and reconstruct the image
|
60 |
+
from the vision codes using stable diffusion UNet,
|
61 |
+
|
62 |
+
1. Download the pre-trained SEED Visual Tokenizer and stable diffusion model in [google drive](https://drive.google.com/drive/folders/1xmVXuttQfBPBOe4ZR96Wu1X34uzPkxsS?usp=drive_link) and put them under the folder "pretrained".
|
63 |
+
2. Run the inference code.
|
64 |
+
```bash
|
65 |
+
python demo_recon.py
|
66 |
+
```
|
67 |
+
|
68 |
+
## To Do
|
69 |
+
- [x] Release SEED Tokenizer
|
70 |
+
|
71 |
+
## License
|
72 |
+
SEED is released under Apache License Version 2.0.
|
73 |
+
|
74 |
+
## Acknowledgement
|
75 |
+
We utilize Stable Diffusion to decode images from our visual codes, and use its implementation and pre-trained model in https://github.com/CompVis/stable-diffusion.git.
|
76 |
+
|
77 |
+
Our code is based on the implementation of BLIP-2 in https://github.com/salesforce/LAVIS.git.
|
78 |
+
|
79 |
+
|
80 |
+
## Citation
|
81 |
+
If you find the work helpful, please consider citing:
|
82 |
+
```
|
83 |
+
@misc{ge2023planting,
|
84 |
+
title={Planting a SEED of Vision in Large Language Model},
|
85 |
+
author={Yuying Ge and Yixiao Ge and Ziyun Zeng and Xintao Wang and Ying Shan},
|
86 |
+
year={2023},
|
87 |
+
eprint={2307.08041},
|
88 |
+
archivePrefix={arXiv},
|
89 |
+
primaryClass={cs.CV}
|
90 |
+
}
|
91 |
+
```
|
92 |
+
|
93 |
+
The project is still in progress. Stay tuned for more updates!
|
configs/llm/seed_llama_14b.yaml
ADDED
@@ -0,0 +1,5 @@
+_target_: models.model_tools.get_pretrained_llama_causal_model
+pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_llama_14b_sft
+
+torch_dtype: fp16
+low_cpu_mem_usage: True
configs/llm/seed_llama_14b_8bit.yaml
ADDED
@@ -0,0 +1,5 @@
+_target_: transformers.LlamaForCausalLM.from_pretrained
+pretrained_model_name_or_path: AILab-CVC/seed-llama-14b-sft
+load_in_8bit: True
+# device_map: auto
+low_cpu_mem_usage: True
configs/llm/seed_llama_8b.yaml
ADDED
@@ -0,0 +1,5 @@
+_target_: models.model_tools.get_pretrained_llama_causal_model
+pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_llama_8b_sft
+
+torch_dtype: fp16
+low_cpu_mem_usage: True
configs/llm/seed_llama_8b_8bit.yaml
ADDED
@@ -0,0 +1,5 @@
+_target_: transformers.LlamaForCausalLM.from_pretrained
+pretrained_model_name_or_path: AILab-CVC/seed-llama-8b-sft
+load_in_8bit: True
+# device_map: auto
+low_cpu_mem_usage: True
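For illustration, a minimal sketch (not from the repo) of how a config with a `_target_` key is typically turned into a model object. It assumes the project resolves these YAML files with OmegaConf/Hydra, which the `_target_` and `${oc.env:PROJECT_ROOT}` syntax suggests:

```python
# Instantiate the 8-bit SEED-LLaMA-14B config; `instantiate` calls the
# `_target_` callable with the remaining keys as keyword arguments.
from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.load("configs/llm/seed_llama_14b_8bit.yaml")
# Resolves to transformers.LlamaForCausalLM.from_pretrained(
#     "AILab-CVC/seed-llama-14b-sft", load_in_8bit=True, low_cpu_mem_usage=True)
llm = instantiate(cfg)
```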
configs/tokenizer/seed_llama_tokenizer.yaml
ADDED
@@ -0,0 +1,4 @@
+_target_: models.seed_llama_tokenizer.SeedLlamaTokenizer.from_pretrained
+pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_tokenizer
+fp16: True
+load_diffusion: True
configs/tokenizer/seed_llama_tokenizer_hf.yaml
ADDED
@@ -0,0 +1,6 @@
+_target_: models.seed_llama_tokenizer.SeedLlamaTokenizer.from_pretrained
+pretrained_model_name_or_path: AILab-CVC/seed-tokenizer-2
+fp16: True
+load_diffusion: False
+encoder_url: https://huggingface.co/AILab-CVC/seed-tokenizer-2/resolve/main/seed_quantizer.pt
+diffusion_path: stabilityai/stable-diffusion-2-1-unclip
configs/transform/clip_transform.yaml
ADDED
@@ -0,0 +1,4 @@
+_target_: models.transforms.get_transform
+type: clip
+image_size: 224
+keep_ratio: False
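A companion sketch (again assuming OmegaConf/Hydra, as above) of composing the transform and tokenizer configs to preprocess one image. Whether the tokenizer consumes the transformed tensor directly is an assumption; see `gradio_demo/seed_llama_flask.py` for the actual wiring.

```python
# Build the CLIP-style transform and the SEED tokenizer from their configs,
# then preprocess a demo image.
from PIL import Image
from omegaconf import OmegaConf
from hydra.utils import instantiate

transform = instantiate(OmegaConf.load("configs/transform/clip_transform.yaml"))
tokenizer = instantiate(OmegaConf.load("configs/tokenizer/seed_llama_tokenizer_hf.yaml"))

# 3x224x224 tensor, resized without keeping the aspect ratio (keep_ratio: False)
pixel_values = transform(Image.open("images/cat.jpg").convert("RGB"))
```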
gradio_demo/conversation.py
ADDED
@@ -0,0 +1,190 @@
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
import io
|
6 |
+
import base64
|
7 |
+
import os
|
8 |
+
from PIL import Image
|
9 |
+
import copy
|
10 |
+
|
11 |
+
IMG_FLAG = '<image>'
|
12 |
+
|
13 |
+
|
14 |
+
class SeparatorStyle(Enum):
|
15 |
+
"""Different separator style."""
|
16 |
+
SINGLE = auto()
|
17 |
+
TWO = auto()
|
18 |
+
MPT = auto()
|
19 |
+
PLAIN = auto()
|
20 |
+
LLAMA_2 = auto()
|
21 |
+
|
22 |
+
|
23 |
+
def decode_image(encoded_image: str) -> Image:
|
24 |
+
decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
|
25 |
+
buffer = io.BytesIO(decoded_bytes)
|
26 |
+
image = Image.open(buffer)
|
27 |
+
return image
|
28 |
+
|
29 |
+
|
30 |
+
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
|
31 |
+
with io.BytesIO() as buffer:
|
32 |
+
image.save(buffer, format=format)
|
33 |
+
encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
34 |
+
return encoded_image
|
35 |
+
|
36 |
+
|
37 |
+
@dataclasses.dataclass
|
38 |
+
class Conversation:
|
39 |
+
"""A class that keeps all conversation history."""
|
40 |
+
system: str
|
41 |
+
roles: List[str]
|
42 |
+
messages: List[dict] # multi-turn -> user & assistant -> {'images': [PIL.Image,], 'text': str}
|
43 |
+
offset: int
|
44 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
45 |
+
sep: str = "###"
|
46 |
+
sep2: str = None
|
47 |
+
version: str = "Unknown"
|
48 |
+
|
49 |
+
skip_next: bool = False
|
50 |
+
|
51 |
+
def get_prompt(self):
|
52 |
+
messages = copy.deepcopy(self.messages)
|
53 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
54 |
+
if self.system is None or self.system == '':
|
55 |
+
text = ''
|
56 |
+
else:
|
57 |
+
text = self.system + self.sep
|
58 |
+
images = []
|
59 |
+
for message in messages:
|
60 |
+
text += message['role'] + ": " + message['message']['text'] + self.sep
|
61 |
+
for image_path, image_ids in zip(message['message']['images'], message['message']['images_ids']):
|
62 |
+
if image_ids is not None:
|
63 |
+
images.append(image_ids)
|
64 |
+
else:
|
65 |
+
image = Image.open(image_path).resize((256, 256))
|
66 |
+
image_base64 = encode_image(image)
|
67 |
+
images.append(image_base64)
|
68 |
+
|
69 |
+
text += self.roles[1] + ":"
|
70 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
71 |
+
b_token = "[INST] "
|
72 |
+
e_token = " [/INST]"
|
73 |
+
if self.system is None or self.system == '':
|
74 |
+
text = ''
|
75 |
+
else:
|
76 |
+
text = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n"
|
77 |
+
images = []
|
78 |
+
for idx, message in enumerate(messages):
|
79 |
+
# text += message['role'] + ": " + message['message']['text'] + self.sep
|
80 |
+
if idx % 2 == 0:
|
81 |
+
text += b_token + message['message']['text'] + e_token + self.sep
|
82 |
+
else:
|
83 |
+
text += message['message']['text'] + self.sep
|
84 |
+
|
85 |
+
for image_path, image_ids in zip(message['message']['images'], message['message']['images_ids']):
|
86 |
+
if image_ids is not None:
|
87 |
+
images.append(image_ids)
|
88 |
+
else:
|
89 |
+
image = Image.open(image_path).resize((256, 256))
|
90 |
+
image_base64 = encode_image(image)
|
91 |
+
images.append(image_base64)
|
92 |
+
else:
|
93 |
+
raise NotImplementedError
|
94 |
+
|
95 |
+
return {'text': text, 'images': images}
|
96 |
+
|
97 |
+
def update_image_ids(self, images_ids):
|
98 |
+
image_count = 0
|
99 |
+
for message in self.messages:
|
100 |
+
for idx in range(len(message['message']['images_ids'])):
|
101 |
+
if message['message']["images_ids"][idx] is None:
|
102 |
+
message['message']["images_ids"][idx] = images_ids[image_count]
|
103 |
+
image_count += 1
|
104 |
+
|
105 |
+
assert len(images_ids) == image_count, print(len(images_ids), image_count)
|
106 |
+
|
107 |
+
def append_message(self, role, message):
|
108 |
+
self.messages.append([role, message])
|
109 |
+
|
110 |
+
def to_gradio_chatbot(self):
|
111 |
+
dialog = []
|
112 |
+
for i, single_turn in enumerate(self.messages[self.offset:]):
|
113 |
+
single_turn = single_turn['message']
|
114 |
+
text_list = single_turn['text'].split(IMG_FLAG)
|
115 |
+
assert len(text_list) == len(single_turn['images']) + 1, print(text_list, len(single_turn['images']))
|
116 |
+
message = ''
|
117 |
+
for image_idx in range(len(single_turn['images'])):
|
118 |
+
# image = single_turn['images'][image_idx]
|
119 |
+
# image_base64 = encode_image(image)
|
120 |
+
# image_str = f'<img src="data:image/png;base64,{image_base64}" alt="user upload image" />'
|
121 |
+
image_path = single_turn['images'][image_idx]
|
122 |
+
if image_path == '':
|
123 |
+
message += text_list[image_idx] + '<corrupt_image>'
|
124 |
+
else:
|
125 |
+
message += text_list[image_idx] + f'![](file={image_path})'
|
126 |
+
message += text_list[-1]
|
127 |
+
|
128 |
+
if i % 2 == 0:
|
129 |
+
dialog.append([message, None])
|
130 |
+
else:
|
131 |
+
dialog[-1][-1] = message
|
132 |
+
|
133 |
+
return dialog
|
134 |
+
|
135 |
+
def copy(self):
|
136 |
+
return Conversation(system=self.system,
|
137 |
+
roles=self.roles,
|
138 |
+
messages=copy.deepcopy(self.messages),
|
139 |
+
offset=self.offset,
|
140 |
+
sep_style=self.sep_style,
|
141 |
+
sep=self.sep,
|
142 |
+
sep2=self.sep2,
|
143 |
+
version=self.version)
|
144 |
+
|
145 |
+
def dict(self):
|
146 |
+
messages = copy.deepcopy(self.messages)
|
147 |
+
for message in messages:
|
148 |
+
if 'images_ids' in message:
|
149 |
+
message.pop('images_ids')
|
150 |
+
for i in range(len(message['message']['images'])):
|
151 |
+
message['message']['images'][i] = os.path.basename(message['message']['images'][i])
|
152 |
+
return {
|
153 |
+
"system": self.system,
|
154 |
+
"roles": self.roles,
|
155 |
+
"messages": messages,
|
156 |
+
"offset": self.offset,
|
157 |
+
"sep": self.sep,
|
158 |
+
"sep2": self.sep2,
|
159 |
+
}
|
160 |
+
|
161 |
+
|
162 |
+
conv_seed_vicuna = Conversation(
|
163 |
+
system="",
|
164 |
+
roles=("USER", "ASSISTANT"),
|
165 |
+
version="v2",
|
166 |
+
messages=[],
|
167 |
+
offset=0,
|
168 |
+
sep_style=SeparatorStyle.SINGLE,
|
169 |
+
sep='\n',
|
170 |
+
)
|
171 |
+
|
172 |
+
conv_seed_vicuna_system = Conversation(
|
173 |
+
system="A chat between a curious user and an artificial intelligence assistant. ",
|
174 |
+
roles=("USER", "ASSISTANT"),
|
175 |
+
version="v2",
|
176 |
+
messages=[],
|
177 |
+
offset=0,
|
178 |
+
sep_style=SeparatorStyle.SINGLE,
|
179 |
+
sep='\n',
|
180 |
+
)
|
181 |
+
|
182 |
+
conv_seed_llama2 = Conversation(
|
183 |
+
system="",
|
184 |
+
roles=("[INST]", "[/INST]"),
|
185 |
+
version="v2",
|
186 |
+
messages=[],
|
187 |
+
offset=0,
|
188 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
189 |
+
sep='\n',
|
190 |
+
)
|
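For orientation, here is an illustrative sketch of how the Gradio demo fills a `Conversation` and converts it into the prompt payload and chatbot view; the question text is invented, and the image path simply reuses `images/cat.jpg` from this repo.

```python
# Illustrative sketch of how the demo fills a Conversation; the question text is
# invented, and the image path reuses images/cat.jpg from this repo.
from conversation import conv_seed_vicuna, IMG_FLAG

conv = conv_seed_vicuna.copy()
conv.messages.append({
    'role': conv.roles[0],                              # "USER"
    'message': {'text': f'{IMG_FLAG}What is in this image?',
                'images': ['images/cat.jpg'],
                'images_ids': [None]},                  # None until the server returns SEED codes
})

prompt = conv.get_prompt()       # {'text': 'USER: <image>...\nASSISTANT:', 'images': [base64 or id list]}
chat = conv.to_gradio_chatbot()  # [[user_markdown, None], ...] ready for gr.Chatbot
print(prompt['text'])
```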
gradio_demo/seed_llama_flask.py
ADDED
@@ -0,0 +1,230 @@
1 |
+
import hydra
|
2 |
+
|
3 |
+
import pyrootutils
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from flask import Flask, request
|
9 |
+
import json
|
10 |
+
from typing import Optional
|
11 |
+
import transformers
|
12 |
+
from dataclasses import dataclass, field
|
13 |
+
import io
|
14 |
+
import base64
|
15 |
+
from PIL import Image
|
16 |
+
import gc
|
17 |
+
|
18 |
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
19 |
+
|
20 |
+
BOI_TOKEN = '<img>'
|
21 |
+
EOI_TOKEN = '</img>'
|
22 |
+
IMG_TOKEN = '<img_{:05d}>'
|
23 |
+
|
24 |
+
IMG_FLAG = '<image>'
|
25 |
+
NUM_IMG_TOKNES = 32
|
26 |
+
NUM_IMG_CODES = 8192
|
27 |
+
|
28 |
+
app = Flask(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
def decode_image(encoded_image: str) -> Image:
|
32 |
+
decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
|
33 |
+
buffer = io.BytesIO(decoded_bytes)
|
34 |
+
image = Image.open(buffer)
|
35 |
+
return image
|
36 |
+
|
37 |
+
|
38 |
+
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
|
39 |
+
with io.BytesIO() as buffer:
|
40 |
+
image.save(buffer, format=format)
|
41 |
+
encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
42 |
+
return encoded_image
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class Arguments:
|
47 |
+
image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
|
48 |
+
tokenizer: Optional[str] = field(default=None, metadata={"help": "config path of tokenizer used to initialize tokenizer"})
|
49 |
+
model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
|
50 |
+
port: Optional[str] = field(default=80, metadata={"help": "network port"})
|
51 |
+
llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
|
52 |
+
tokenizer_device: Optional[str] = field(default='cuda:0', metadata={"help": "tokenizer device"})
|
53 |
+
offload_encoder: Optional[bool] = field(default=False, metadata={"help": "offload image tokenizer"})
|
54 |
+
offload_decoder: Optional[bool] = field(default=True, metadata={"help": "offload image tokenizer"})
|
55 |
+
|
56 |
+
|
57 |
+
parser = transformers.HfArgumentParser(Arguments)
|
58 |
+
args, = parser.parse_args_into_dataclasses()
|
59 |
+
|
60 |
+
|
61 |
+
class LLMService:
|
62 |
+
def __init__(self, args) -> None:
|
63 |
+
image_transform_cfg = OmegaConf.load(args.image_transform)
|
64 |
+
tokenizer_cfg = OmegaConf.load(args.tokenizer)
|
65 |
+
model_cfg = OmegaConf.load(args.model)
|
66 |
+
self.image_id_shift = 32000
|
67 |
+
|
68 |
+
self.image_transform = hydra.utils.instantiate(image_transform_cfg)
|
69 |
+
self.tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=args.tokenizer_device, load_diffusion=True)
|
70 |
+
|
71 |
+
if args.offload_encoder:
|
72 |
+
self.tokenizer.image_tokenizer.model.visual_encoder.to('cpu')
|
73 |
+
if args.offload_decoder:
|
74 |
+
self.tokenizer.image_tokenizer.diffusion_model.to('cpu')
|
75 |
+
|
76 |
+
# model = hydra.utils.instantiate(model_cfg, torch_dtype=torch.float16)
|
77 |
+
# self.model = model.eval().to(args.llm_device)
|
78 |
+
model = hydra.utils.instantiate(model_cfg, device_map=args.llm_device).eval()
|
79 |
+
self.model = model
|
80 |
+
print(model.get_memory_footprint())
|
81 |
+
self.llm_device = args.llm_device
|
82 |
+
self.tokenizer_device = args.tokenizer_device
|
83 |
+
self.offload_encoder = args.offload_encoder
|
84 |
+
self.offload_decoder = args.offload_decoder
|
85 |
+
self.boi_token_id = self.tokenizer(BOI_TOKEN, add_special_tokens=False).input_ids[0]
|
86 |
+
self.eoi_token_id = self.tokenizer(EOI_TOKEN, add_special_tokens=False).input_ids[0]
|
87 |
+
print('Init Done...')
|
88 |
+
|
89 |
+
|
90 |
+
service = LLMService(args)
|
91 |
+
|
92 |
+
|
93 |
+
@app.route('/generate', methods=['GET', 'POST'])
|
94 |
+
def generate():
|
95 |
+
|
96 |
+
request_info = request.get_json()
|
97 |
+
|
98 |
+
text_list = request_info['text'].split(IMG_FLAG)
|
99 |
+
image_list = request_info['images']
|
100 |
+
temperature = request_info.get('temperature', 0.7)
|
101 |
+
num_beams = request_info.get('num_beams', 1)
|
102 |
+
max_new_tokens = request_info.get('max_new_tokens', 256)
|
103 |
+
top_p = request_info.get('top_p', 0.5)
|
104 |
+
force_boi = request_info.get('force_boi', False)
|
105 |
+
|
106 |
+
assert len(text_list) == len(image_list) + 1
|
107 |
+
|
108 |
+
if len(image_list) > 0:
|
109 |
+
images_tensor_list = []
|
110 |
+
images_tensor_indices = []
|
111 |
+
images_ids_list = []
|
112 |
+
images_ids_indices = []
|
113 |
+
for idx, image_item in enumerate(image_list):
|
114 |
+
if isinstance(image_item, str):
|
115 |
+
image = decode_image(image_item)
|
116 |
+
image_tensor = service.image_transform(image)
|
117 |
+
images_tensor_list.append(image_tensor)
|
118 |
+
images_tensor_indices.append(idx)
|
119 |
+
else:
|
120 |
+
images_ids_list.append(image_item)
|
121 |
+
images_ids_indices.append(idx)
|
122 |
+
|
123 |
+
if len(images_tensor_list) > 0:
|
124 |
+
images_tensor = torch.stack(images_tensor_list, dim=0).to(service.tokenizer_device)
|
125 |
+
if service.offload_encoder:
|
126 |
+
service.tokenizer.image_tokenizer.model.visual_encoder.to(service.tokenizer_device)
|
127 |
+
|
128 |
+
images_ids_1 = service.tokenizer.encode_image(image_torch=images_tensor).cpu()
|
129 |
+
if args.offload_encoder:
|
130 |
+
service.tokenizer.image_tokenizer.model.visual_encoder.to('cpu')
|
131 |
+
torch.cuda.empty_cache()
|
132 |
+
gc.collect()
|
133 |
+
num_image_ids = images_ids_1.shape[-1]
|
134 |
+
else:
|
135 |
+
num_image_ids = len(images_ids_list[-1])
|
136 |
+
images_ids_2 = torch.tensor(images_ids_list, dtype=torch.long)
|
137 |
+
|
138 |
+
images_ids = torch.zeros((len(image_list), num_image_ids), dtype=torch.long)
|
139 |
+
if len(images_tensor_indices) > 0:
|
140 |
+
images_ids[images_tensor_indices, :] = images_ids_1
|
141 |
+
if len(images_ids_indices) > 0:
|
142 |
+
images_ids[images_ids_indices, :] = images_ids_2
|
143 |
+
|
144 |
+
input_text = ''
|
145 |
+
for i in range(images_ids.shape[0]):
|
146 |
+
single_image_ids = images_ids[i].view(-1).tolist()
|
147 |
+
image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in single_image_ids]) + EOI_TOKEN
|
148 |
+
input_text += text_list[i] + image_tokens
|
149 |
+
|
150 |
+
input_text = service.tokenizer.bos_token + input_text + text_list[-1]
|
151 |
+
|
152 |
+
images_ids_list = images_ids.tolist()
|
153 |
+
else:
|
154 |
+
|
155 |
+
input_text = service.tokenizer.bos_token + ''.join(text_list)
|
156 |
+
images_ids_list = []
|
157 |
+
|
158 |
+
if force_boi:
|
159 |
+
input_text += BOI_TOKEN
|
160 |
+
|
161 |
+
print(input_text)
|
162 |
+
input_ids = service.tokenizer(input_text, add_special_tokens=False, return_tensors='pt').input_ids
|
163 |
+
input_ids = input_ids.to(service.llm_device)
|
164 |
+
generation_config = {
|
165 |
+
'temperature': temperature,
|
166 |
+
'num_beams': num_beams,
|
167 |
+
'max_new_tokens': max_new_tokens,
|
168 |
+
'top_p': top_p,
|
169 |
+
'do_sample': True
|
170 |
+
}
|
171 |
+
|
172 |
+
generate_ids = service.model.generate(input_ids=input_ids, **generation_config)
|
173 |
+
|
174 |
+
if force_boi:
|
175 |
+
generate_ids = generate_ids[0][input_ids.shape[1] - 1:]
|
176 |
+
else:
|
177 |
+
generate_ids = generate_ids[0][input_ids.shape[1]:]
|
178 |
+
print('generated_ids: ', generate_ids)
|
179 |
+
boi_indices = torch.where(generate_ids == service.boi_token_id)[0].tolist()
|
180 |
+
eoi_indices = torch.where(generate_ids == service.eoi_token_id)[0].tolist()
|
181 |
+
# assert len(boi_indices) == len(eoi_indices)
|
182 |
+
|
183 |
+
generated_image_base64_list = []
|
184 |
+
text_mask = torch.ones_like(generate_ids, dtype=torch.bool)
|
185 |
+
|
186 |
+
error_msg = []
|
187 |
+
if len(boi_indices) != len(eoi_indices):
|
188 |
+
error_msg.append(
|
189 |
+
f'Num of BOI (begain of image) tokens: {len(boi_indices)} is not equal to EOI(end of image tokens): {len(eoi_indices)}, some image Some images will fail to decode.'
|
190 |
+
)
|
191 |
+
|
192 |
+
num_images = min(len(boi_indices), len(eoi_indices))
|
193 |
+
for idx in range(num_images):
|
194 |
+
boi_index, eoi_index = boi_indices[idx], eoi_indices[idx]
|
195 |
+
# for boi_index, eoi_index in zip(boi_indices, eoi_indices):
|
196 |
+
image_ids = generate_ids[boi_index + 1:eoi_index].unsqueeze(0).to(service.tokenizer_device)
|
197 |
+
image_ids = image_ids - service.image_id_shift
|
198 |
+
if image_ids.shape[-1] != NUM_IMG_TOKNES:
|
199 |
+
error_msg.append(f'Len(image_ids) {image_ids.shape[-1]} is not equal to {NUM_IMG_TOKNES}')
|
200 |
+
image_base64 = ''
|
201 |
+
elif (image_ids < 0).any() or (image_ids >= NUM_IMG_CODES).any():
|
202 |
+
error_msg.append(f'Some image_id out of range: [0, {NUM_IMG_CODES})')
|
203 |
+
image_base64 = ''
|
204 |
+
else:
|
205 |
+
if service.offload_decoder:
|
206 |
+
service.tokenizer.image_tokenizer.diffusion_model.to(service.tokenizer_device)
|
207 |
+
image = service.tokenizer.decode_image(image_ids)[0]
|
208 |
+
if service.offload_decoder:
|
209 |
+
service.tokenizer.image_tokenizer.diffusion_model.to('cpu')
|
210 |
+
torch.cuda.empty_cache()
|
211 |
+
gc.collect()
|
212 |
+
image_base64 = encode_image(image)
|
213 |
+
|
214 |
+
generated_image_base64_list.append(image_base64)
|
215 |
+
text_mask[boi_index + 1:eoi_index] = False
|
216 |
+
images_ids_list.append(image_ids.view(-1).tolist())
|
217 |
+
generate_ids = generate_ids[text_mask]
|
218 |
+
|
219 |
+
# print('generate_ids: ', generate_ids)
|
220 |
+
# generate_text = service.tokenizer.decode(generate_ids, skip_special_tokens=True)
|
221 |
+
generate_text = service.tokenizer.decode(generate_ids, skip_special_tokens=False)
|
222 |
+
# print('generate_text before: ', generate_text)
|
223 |
+
generate_text = generate_text.replace(BOI_TOKEN + ' ' + EOI_TOKEN + ' ', IMG_FLAG)
|
224 |
+
generate_text = generate_text.replace(service.tokenizer.eos_token, '')
|
225 |
+
print('generate_text: ', generate_text)
|
226 |
+
return {'text': generate_text, 'images': generated_image_base64_list, 'images_ids': images_ids_list, 'error_msg': error_msg}
|
227 |
+
|
228 |
+
|
229 |
+
if __name__ == '__main__':
|
230 |
+
app.run(host='0.0.0.0', port=args.port)
|
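The Gradio front end talks to this Flask service purely over HTTP, so it can also be exercised with a standalone client. A minimal sketch follows, mirroring the payload fields the `/generate` handler above reads; the endpoint URL matches the demo's default `request_address`, and the image and prompt are illustrative.

```python
# Minimal client sketch for the /generate endpoint above. The URL, image path,
# and prompt are illustrative; field names mirror what the handler reads.
import base64
import io
import requests
from PIL import Image

def encode_image(image, format='PNG'):
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

image_b64 = encode_image(Image.open('images/cat.jpg').convert('RGB').resize((256, 256)))

payload = {
    'text': '<image>Add sunglasses to the animal.',  # one <image> flag per entry in `images`
    'images': [image_b64],                           # base64 strings or previously returned id lists
    'temperature': 0.7,
    'top_p': 0.5,
    'num_beams': 1,
    'max_new_tokens': 256,
    'force_boi': False,
}
result = requests.post('http://127.0.0.1:7890/generate', json=payload).json()
print(result['text'])           # generated text with <image> placeholders
print(len(result['images']))    # base64-encoded generated images, if any
```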
gradio_demo/seed_llama_gradio.py
ADDED
@@ -0,0 +1,497 @@
1 |
+
import hydra
|
2 |
+
|
3 |
+
import pyrootutils
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
|
7 |
+
import datetime
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
# from flask import Flask, request
|
10 |
+
import json
|
11 |
+
from typing import Optional
|
12 |
+
import transformers
|
13 |
+
from dataclasses import dataclass, field
|
14 |
+
import io
|
15 |
+
import base64
|
16 |
+
from PIL import Image
|
17 |
+
import gradio as gr
|
18 |
+
import random
|
19 |
+
import time
|
20 |
+
import hashlib
|
21 |
+
import requests
|
22 |
+
|
23 |
+
from utils import build_logger
|
24 |
+
from conversation import conv_seed_vicuna, conv_seed_llama2
|
25 |
+
# from conversation import conv_seed_llama
|
26 |
+
|
27 |
+
IMG_FLAG = '<image>'
|
28 |
+
|
29 |
+
# request_address = 'http://11.29.21.161:80/generate'
|
30 |
+
# request_address = 'http://0.0.0.0:7890/generate'
|
31 |
+
LOGDIR = 'log'
|
32 |
+
|
33 |
+
logger = build_logger("gradio_seed_llama", LOGDIR)
|
34 |
+
headers = {"User-Agent": "SEED LLaMA Client"}
|
35 |
+
|
36 |
+
no_change_btn = gr.Button.update()
|
37 |
+
enable_btn = gr.Button.update(interactive=True)
|
38 |
+
disable_btn = gr.Button.update(interactive=False)
|
39 |
+
|
40 |
+
@dataclass
|
41 |
+
class Arguments:
|
42 |
+
server_port: Optional[int] = field(default=7860, metadata={"help": "network port"})
|
43 |
+
server_name: Optional[str] = field(default='0.0.0.0', metadata={"help": "network address"})
|
44 |
+
request_address: Optional[str] = field(default='http://127.0.0.1:7890/generate', metadata={"help": "request address"})
|
45 |
+
model_type: Optional[str] = field(default='seed-llama-14b', metadata={"help": "choice: [seed-llama-8b, seed-llama-14b]"})
|
46 |
+
|
47 |
+
parser = transformers.HfArgumentParser(Arguments)
|
48 |
+
args, = parser.parse_args_into_dataclasses()
|
49 |
+
|
50 |
+
if args.model_type == 'seed-llama-8b':
|
51 |
+
conv_seed_llama = conv_seed_vicuna
|
52 |
+
elif args.model_type == 'seed-llama-14b':
|
53 |
+
conv_seed_llama = conv_seed_llama2
|
54 |
+
else:
|
55 |
+
raise ValueError
|
56 |
+
|
57 |
+
|
58 |
+
def decode_image(encoded_image: str) -> Image:
|
59 |
+
decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
|
60 |
+
# with io.BytesIO(decoded_bytes) as buffer:
|
61 |
+
# image = Image.open(buffer)
|
62 |
+
# return image
|
63 |
+
buffer = io.BytesIO(decoded_bytes)
|
64 |
+
image = Image.open(buffer)
|
65 |
+
return image
|
66 |
+
|
67 |
+
|
68 |
+
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
|
69 |
+
with io.BytesIO() as buffer:
|
70 |
+
image.save(buffer, format=format)
|
71 |
+
encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
72 |
+
return encoded_image
|
73 |
+
|
74 |
+
|
75 |
+
def get_conv_log_filename():
|
76 |
+
t = datetime.datetime.now()
|
77 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
78 |
+
return name
|
79 |
+
|
80 |
+
|
81 |
+
def get_conv_image_dir():
|
82 |
+
name = os.path.join(LOGDIR, 'images')
|
83 |
+
os.makedirs(name, exist_ok=True)
|
84 |
+
return name
|
85 |
+
|
86 |
+
|
87 |
+
def get_image_name(image, image_dir=None):
|
88 |
+
buffer = io.BytesIO()
|
89 |
+
image.save(buffer, format='PNG')
|
90 |
+
image_bytes = buffer.getvalue()
|
91 |
+
md5 = hashlib.md5(image_bytes).hexdigest()
|
92 |
+
|
93 |
+
if image_dir is not None:
|
94 |
+
image_name = os.path.join(image_dir, md5 + '.png')
|
95 |
+
else:
|
96 |
+
image_name = md5 + '.png'
|
97 |
+
|
98 |
+
return image_name
|
99 |
+
|
100 |
+
|
101 |
+
def resize_image(image, max_size=512):
|
102 |
+
width, height = image.size
|
103 |
+
aspect_ratio = float(width) / float(height)
|
104 |
+
|
105 |
+
if width > height:
|
106 |
+
new_width = max_size
|
107 |
+
new_height = int(new_width / aspect_ratio)
|
108 |
+
else:
|
109 |
+
new_height = max_size
|
110 |
+
new_width = int(new_height * aspect_ratio)
|
111 |
+
|
112 |
+
resized_image = image.resize((new_width, new_height))
|
113 |
+
return resized_image
|
114 |
+
|
115 |
+
|
116 |
+
def center_crop_image(image, max_aspect_ratio=1.5):
|
117 |
+
width, height = image.size
|
118 |
+
aspect_ratio = max(width, height) / min(width, height)
|
119 |
+
|
120 |
+
if aspect_ratio >= max_aspect_ratio:
|
121 |
+
if width > height:
|
122 |
+
new_width = int(height * max_aspect_ratio)
|
123 |
+
left = (width - new_width) // 2
|
124 |
+
right = (width + new_width) // 2
|
125 |
+
top = 0
|
126 |
+
bottom = height
|
127 |
+
else:
|
128 |
+
new_height = int(width * max_aspect_ratio)
|
129 |
+
left = 0
|
130 |
+
right = width
|
131 |
+
top = (height - new_height) // 2
|
132 |
+
bottom = (height + new_height) // 2
|
133 |
+
|
134 |
+
cropped_image = image.crop((left, top, right, bottom))
|
135 |
+
return cropped_image
|
136 |
+
else:
|
137 |
+
return image
|
138 |
+
|
139 |
+
def vote_last_response(state, vote_type, request: gr.Request):
|
140 |
+
with open(get_conv_log_filename(), "a") as fout:
|
141 |
+
data = {
|
142 |
+
"tstamp": round(time.time(), 4),
|
143 |
+
"type": vote_type,
|
144 |
+
"state": state.dict(),
|
145 |
+
"ip": request.client.host,
|
146 |
+
}
|
147 |
+
fout.write(json.dumps(data) + "\n")
|
148 |
+
|
149 |
+
|
150 |
+
def upvote_last_response(state, request: gr.Request):
|
151 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
152 |
+
vote_last_response(state, "upvote", request)
|
153 |
+
return (disable_btn, ) * 2
|
154 |
+
|
155 |
+
|
156 |
+
def downvote_last_response(state, request: gr.Request):
|
157 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
158 |
+
vote_last_response(state, "downvote", request)
|
159 |
+
return (disable_btn, ) * 2
|
160 |
+
|
161 |
+
|
162 |
+
def regenerate(dialog_state, request: gr.Request):
|
163 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
164 |
+
if dialog_state.messages[-1]['role'] == dialog_state.roles[1]:
|
165 |
+
dialog_state.messages.pop()
|
166 |
+
return (
|
167 |
+
dialog_state,
|
168 |
+
dialog_state.to_gradio_chatbot(),
|
169 |
+
) + (disable_btn, ) * 4
|
170 |
+
|
171 |
+
|
172 |
+
def clear_history(request: gr.Request):
|
173 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
174 |
+
# state = None
|
175 |
+
# return (state, [], "") + (disable_btn, ) * 5
|
176 |
+
dialog_state = conv_seed_llama.copy()
|
177 |
+
input_state = init_input_state()
|
178 |
+
return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn, ) * 4
|
179 |
+
|
180 |
+
|
181 |
+
def init_input_state():
|
182 |
+
return {'images': [], 'text': '', 'images_ids': []}
|
183 |
+
|
184 |
+
|
185 |
+
def add_text(dialog_state, input_state, text, request: gr.Request):
|
186 |
+
logger.info(f"add_text. ip: {request.client.host}.")
|
187 |
+
# if len(input_state['text']) == 0:
|
188 |
+
if text is None or len(text) == 0:
|
189 |
+
# dialog_state.skip_next = True
|
190 |
+
return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn, ) * 4
|
191 |
+
input_state['text'] += text
|
192 |
+
|
193 |
+
# dialog_state.skip_next = False
|
194 |
+
|
195 |
+
if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
|
196 |
+
dialog_state.messages[-1]['message'] = input_state
|
197 |
+
else:
|
198 |
+
dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
|
199 |
+
print('add_text: ', dialog_state.to_gradio_chatbot())
|
200 |
+
|
201 |
+
return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn, ) * 4
|
202 |
+
|
203 |
+
|
204 |
+
def add_image(dialog_state, input_state, image, request: gr.Request):
|
205 |
+
logger.info(f"add_image. ip: {request.client.host}.")
|
206 |
+
if image is None:
|
207 |
+
return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn, ) * 4
|
208 |
+
|
209 |
+
image = image.convert('RGB')
|
210 |
+
image = resize_image(image, max_size=512)
|
211 |
+
image = center_crop_image(image, max_aspect_ratio=1.3)
|
212 |
+
image_dir = get_conv_image_dir()
|
213 |
+
image_path = get_image_name(image=image, image_dir=image_dir)
|
214 |
+
if not os.path.exists(image_path):
|
215 |
+
image.save(image_path)
|
216 |
+
|
217 |
+
input_state['images'].append(image_path)
|
218 |
+
input_state['text'] += IMG_FLAG
|
219 |
+
input_state['images_ids'].append(None)
|
220 |
+
|
221 |
+
if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
|
222 |
+
dialog_state.messages[-1]['message'] = input_state
|
223 |
+
else:
|
224 |
+
dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
|
225 |
+
|
226 |
+
print('add_image:', dialog_state)
|
227 |
+
|
228 |
+
return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn, ) * 4
|
229 |
+
|
230 |
+
|
231 |
+
def http_bot_test(dialog_state, input_state, temperature, top_p, max_new_tokens, num_beams, max_turns, force_image_gen, request: gr.Request):
|
232 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
233 |
+
output_state = {}
|
234 |
+
output_state['text'] = 'This is test for frontend!'
|
235 |
+
output_state['images'] = []
|
236 |
+
if len(dialog_state.messages) > 0 and len(dialog_state.messages[-1]['message']['images']) != 0:
|
237 |
+
image = random.choice(dialog_state.messages[-1]['message']['images'])
|
238 |
+
output_state['images'].append(image)
|
239 |
+
output_state['text'] += IMG_FLAG
|
240 |
+
|
241 |
+
dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
|
242 |
+
input_state = init_input_state()
|
243 |
+
|
244 |
+
print('http_bot: ', dialog_state.to_gradio_chatbot())
|
245 |
+
|
246 |
+
return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (enable_btn, ) * 4
|
247 |
+
|
248 |
+
|
249 |
+
def update_error_msg(chatbot, error_msg):
|
250 |
+
if len(error_msg) > 0:
|
251 |
+
info = '\n-------------\nSome errors occurred during response, please clear history and restart.\n' + '\n'.join(
|
252 |
+
error_msg)
|
253 |
+
chatbot[-1][-1] = chatbot[-1][-1] + info
|
254 |
+
|
255 |
+
return chatbot
|
256 |
+
|
257 |
+
|
258 |
+
def http_bot(dialog_state, input_state, temperature, top_p, max_new_tokens, num_beams, max_turns, force_image_gen, request: gr.Request):
|
259 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
260 |
+
print('input_state:', input_state)
|
261 |
+
|
262 |
+
if len(dialog_state.messages) == 0 or dialog_state.messages[-1]['role'] != dialog_state.roles[0] or len(
|
263 |
+
dialog_state.messages[-1]['message']['text'].strip(' ?.;!/')) == 0:
|
264 |
+
# if len(input_state['text']) == 0:
|
265 |
+
# dialog_state.skip_next = True
|
266 |
+
return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn, ) * 4
|
267 |
+
|
268 |
+
if len(dialog_state.messages) > max_turns * 2:
|
269 |
+
output_state = init_input_state()
|
270 |
+
output_state['text'] = 'Error: History exceeds maximum rounds, please clear history and restart.'
|
271 |
+
dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
|
272 |
+
input_state = init_input_state()
|
273 |
+
return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn, ) * 3 + (enable_btn, )
|
274 |
+
|
275 |
+
prompt = dialog_state.get_prompt()
|
276 |
+
payload = {
|
277 |
+
'text': prompt['text'],
|
278 |
+
'temperature': float(temperature),
|
279 |
+
'top_p': float(top_p),
|
280 |
+
'max_new_tokens': int(max_new_tokens),
|
281 |
+
'num_beams': int(num_beams),
|
282 |
+
'images': prompt['images'],
|
283 |
+
'force_boi': force_image_gen,
|
284 |
+
}
|
285 |
+
|
286 |
+
print(
|
287 |
+
'request: ', {
|
288 |
+
'text': prompt['text'],
|
289 |
+
'temperature': float(temperature),
|
290 |
+
'top_p': float(top_p),
|
291 |
+
'max_new_tokens': int(max_new_tokens),
|
292 |
+
'num_beams': int(num_beams)
|
293 |
+
})
|
294 |
+
print('request_address', args.request_address)
|
295 |
+
response = requests.request(method="POST", url=args.request_address, headers=headers, json=payload)
|
296 |
+
results = response.json()
|
297 |
+
print('response: ', {'text': results['text'], 'images_ids': results['images_ids'], 'error_msg': results['error_msg']})
|
298 |
+
|
299 |
+
output_state = init_input_state()
|
300 |
+
image_dir = get_conv_image_dir()
|
301 |
+
output_state['text'] = results['text']
|
302 |
+
|
303 |
+
for image_base64 in results['images']:
|
304 |
+
if image_base64 == '':
|
305 |
+
image_path = ''
|
306 |
+
else:
|
307 |
+
image = decode_image(image_base64)
|
308 |
+
image = image.convert('RGB')
|
309 |
+
image_path = get_image_name(image=image, image_dir=image_dir)
|
310 |
+
if not os.path.exists(image_path):
|
311 |
+
image.save(image_path)
|
312 |
+
output_state['images'].append(image_path)
|
313 |
+
output_state['images_ids'].append(None)
|
314 |
+
|
315 |
+
dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
|
316 |
+
dialog_state.update_image_ids(results['images_ids'])
|
317 |
+
|
318 |
+
vote_last_response(dialog_state, 'common', request)
|
319 |
+
input_state = init_input_state()
|
320 |
+
chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg'])
|
321 |
+
return (dialog_state, input_state, chatbot) + (enable_btn, ) * 4
|
322 |
+
|
323 |
+
|
324 |
+
def load_demo(request: gr.Request):
|
325 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
326 |
+
dialog_state = conv_seed_llama.copy()
|
327 |
+
input_state = init_input_state()
|
328 |
+
return dialog_state, input_state
|
329 |
+
|
330 |
+
|
331 |
+
title = ("""
|
332 |
+
# SEED-LLaMA
|
333 |
+
[[Project Page]](https://ailab-cvc.github.io/seed/seed_llama.html) [[Paper]](https://arxiv.org/pdf/2310.01218.pdf) [[Code]](https://github.com/AILab-CVC/SEED/tree/main)
|
334 |
+
|
335 |
+
## Tips:
|
336 |
+
* Check out the conversation examples (at the bottom) for inspiration.
|
337 |
+
|
338 |
+
* You can adjust "Max History Rounds" to try a conversation with up to five rounds. For more turns, you can download our checkpoints from GitHub and deploy them locally for inference.
|
339 |
+
|
340 |
+
* Our demo supports a mix of images and text as input. You can freely upload an image or enter text, then click "Add Image" or "Add Text". You can repeat these steps as many times as you like, and finally click "Submit" to run model inference.
|
341 |
+
|
342 |
+
* If you are not satisfied with the output, especially the generated image, you may click on "Regenerate" for another chance.
|
343 |
+
|
344 |
+
* You can click "Force Image Generation" to compel the model to produce images when necessary. For example, our model might struggle to generate images when there is an excessive amount of text-only context.
|
345 |
+
* SEED-LLaMA was trained on English-only data. It may handle other languages to some extent thanks to the inherent capabilities of LLaMA, but the results might not be stable.
|
346 |
+
""")
|
347 |
+
|
348 |
+
css = """
|
349 |
+
img {
|
350 |
+
font-family: 'Helvetica';
|
351 |
+
font-weight: 300;
|
352 |
+
line-height: 2;
|
353 |
+
text-align: center;
|
354 |
+
|
355 |
+
width: auto;
|
356 |
+
height: auto;
|
357 |
+
display: block;
|
358 |
+
position: relative;
|
359 |
+
}
|
360 |
+
|
361 |
+
img:before {
|
362 |
+
content: " ";
|
363 |
+
display: block;
|
364 |
+
|
365 |
+
position: absolute;
|
366 |
+
top: -10px;
|
367 |
+
left: 0;
|
368 |
+
height: calc(100% + 10px);
|
369 |
+
width: 100%;
|
370 |
+
background-color: rgb(230, 230, 230);
|
371 |
+
border: 2px dotted rgb(200, 200, 200);
|
372 |
+
border-radius: 5px;
|
373 |
+
}
|
374 |
+
|
375 |
+
img:after {
|
376 |
+
content: " ";
|
377 |
+
display: block;
|
378 |
+
font-size: 16px;
|
379 |
+
font-style: normal;
|
380 |
+
font-family: FontAwesome;
|
381 |
+
color: rgb(100, 100, 100);
|
382 |
+
|
383 |
+
position: absolute;
|
384 |
+
top: 5px;
|
385 |
+
left: 0;
|
386 |
+
width: 100%;
|
387 |
+
text-align: center;
|
388 |
+
}
|
389 |
+
|
390 |
+
"""
|
391 |
+
|
392 |
+
if __name__ == '__main__':
|
393 |
+
|
394 |
+
examples_mix = [
|
395 |
+
['images/cat.jpg', 'Add sunglasses to the animal.'],
|
396 |
+
['images/eagle.jpg', 'Transform this image into cartoon style'],
|
397 |
+
[None, 'Generate an image of dog on green grass.'],
|
398 |
+
[None, 'Draw a painting of sunflowers in Van Gogh style.'],
|
399 |
+
['images/dogs_4.jpg', 'How many dogs in the image?'],
|
400 |
+
['images/spongebob.png', 'Who are they?'],
|
401 |
+
['images/star.jpg', 'Do you know this painting?'],
|
402 |
+
]
|
403 |
+
|
404 |
+
examples_conv = [
|
405 |
+
['images/demo_example1.jpg'],
|
406 |
+
['images/demo_example2.jpg'],
|
407 |
+
['images/demo_example3.jpg'],
|
408 |
+
['images/demo_example7.jpg'],
|
409 |
+
['images/demo_example5.jpg'],
|
410 |
+
['images/demo_example6.jpg'],
|
411 |
+
]
|
412 |
+
|
413 |
+
with gr.Blocks(css=css) as demo:
|
414 |
+
gr.Markdown(title)
|
415 |
+
dialog_state = gr.State()
|
416 |
+
input_state = gr.State()
|
417 |
+
with gr.Row():
|
418 |
+
with gr.Column(scale=3):
|
419 |
+
with gr.Row():
|
420 |
+
image = gr.Image(type='pil', label='input_image')
|
421 |
+
with gr.Row():
|
422 |
+
text = gr.Textbox(lines=5,
|
423 |
+
show_label=False,
|
424 |
+
label='input_text',
|
425 |
+
elem_id='textbox',
|
426 |
+
placeholder="Enter text or add image, and press submit,").style(container=False)
|
427 |
+
with gr.Row():
|
428 |
+
add_image_btn = gr.Button("Add Image")
|
429 |
+
add_text_btn = gr.Button("Add Text")
|
430 |
+
|
431 |
+
submit_btn = gr.Button("Submit")
|
432 |
+
|
433 |
+
with gr.Row():
|
434 |
+
num_beams = gr.Slider(minimum=1, maximum=4, value=1, step=1, interactive=True, label="Num of Beams")
|
435 |
+
max_new_tokens = gr.Slider(minimum=64,
|
436 |
+
maximum=1024,
|
437 |
+
value=256,
|
438 |
+
step=64,
|
439 |
+
interactive=True,
|
440 |
+
label="Max New Tokens")
|
441 |
+
temperature = gr.Slider(minimum=0.0,
|
442 |
+
maximum=1.0,
|
443 |
+
value=1.0,
|
444 |
+
step=0.1,
|
445 |
+
interactive=True,
|
446 |
+
label="Temperature")
|
447 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.1, interactive=True, label="Top P")
|
448 |
+
max_turns = gr.Slider(minimum=1, maximum=5, value=3, step=1, interactive=True, label="Max History Rounds")
|
449 |
+
force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
|
450 |
+
|
451 |
+
with gr.Column(scale=7):
|
452 |
+
chatbot = gr.Chatbot(elem_id='chatbot', label="SEED LLaMA").style(height=700)
|
453 |
+
with gr.Row():
|
454 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
455 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
456 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
457 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
458 |
+
|
459 |
+
# with gr.Row():
|
460 |
+
# gr.Examples(examples=examples_image, label='Image examples', inputs=[image])
|
461 |
+
with gr.Row():
|
462 |
+
# with gr.Column(scale=6):
|
463 |
+
gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text])
|
464 |
+
# with gr.Column(scale=0.4):
|
465 |
+
# gr.Examples(examples=examples_text, inputs=[text])
|
466 |
+
|
467 |
+
|
468 |
+
# with gr.Row():
|
469 |
+
# gr.Examples(examples=examples_2, inputs=[image])
|
470 |
+
|
471 |
+
with gr.Row():
|
472 |
+
# gr.Gallery(value=[Image.open(e[0]) for e in examples_conv], show_label=True, label="Example Conversations", elem_id="gallery",height=1400, object_fit='contain').style(grid=[3], height='auto')
|
473 |
+
gr.Gallery(value=[Image.open(e[0]) for e in examples_conv], show_label=True, label="Example Conversations", elem_id="gallery",height=1500, columns=[3], rows=[2])
|
474 |
+
|
475 |
+
# Register listeners
|
476 |
+
btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn]
|
477 |
+
upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
|
478 |
+
downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
|
479 |
+
regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
|
480 |
+
http_bot, [dialog_state, input_state, temperature, top_p, max_new_tokens, num_beams, max_turns, force_img_gen],
|
481 |
+
[dialog_state, input_state, chatbot] + btn_list)
|
482 |
+
add_image_btn.click(add_image, [dialog_state, input_state, image],
|
483 |
+
[dialog_state, input_state, image, chatbot] + btn_list)
|
484 |
+
|
485 |
+
add_text_btn.click(add_text, [dialog_state, input_state, text], [dialog_state, input_state, text, chatbot] + btn_list)
|
486 |
+
|
487 |
+
submit_btn.click(
|
488 |
+
add_image, [dialog_state, input_state, image], [dialog_state, input_state, image, chatbot] + btn_list).then(
|
489 |
+
add_text, [dialog_state, input_state, text],
|
490 |
+
[dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
|
491 |
+
http_bot, [dialog_state, input_state, temperature, top_p, max_new_tokens, num_beams, max_turns, force_img_gen],
|
492 |
+
[dialog_state, input_state, chatbot] + btn_list)
|
493 |
+
clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
|
494 |
+
|
495 |
+
demo.load(load_demo, None, [dialog_state, input_state])
|
496 |
+
|
497 |
+
demo.launch(server_name=args.server_name, server_port=args.server_port, enable_queue=True)
|
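To make the callback chain above easier to follow, here is a small illustrative sketch of the `input_state` dict that `add_image`/`add_text` build up before `http_bot` fires; the md5-style file name is a placeholder, not a real path from the repo.

```python
# Sketch of the per-turn state the demo threads through its callbacks; the file
# name below is an illustrative placeholder for the md5-named copy that
# add_image() writes under log/images/.
IMG_FLAG = '<image>'

input_state = {'images': [], 'text': '', 'images_ids': []}   # init_input_state()

# add_image(): resize/crop the upload, save it, then record its path, an <image>
# flag in the text, and a None placeholder id.
input_state['images'].append('log/images/0123abcd.png')
input_state['text'] += IMG_FLAG
input_state['images_ids'].append(None)

# add_text(): append the user's words after the image flag.
input_state['text'] += 'What breed is this dog?'

# http_bot() then turns dialog_state.get_prompt() into a POST to the Flask service
# and appends the assistant reply (text plus decoded image paths) to dialog_state.
print(input_state)
```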
gradio_demo/utils.py
ADDED
@@ -0,0 +1,82 @@
import datetime
import logging
import logging.handlers
import os
import sys

handler = None


def build_logger(logger_name, logger_dir):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(logger_dir, exist_ok=True)
        filename = os.path.join(logger_dir, logger_name + '.log')
        handler = logging.handlers.TimedRotatingFileHandler(filename, when='D', utc=True)
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''
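A short usage sketch for `build_logger`; the component name and directory are arbitrary examples. Note that calling it redirects `sys.stdout`/`sys.stderr` into loggers, so plain `print` output also lands in the rotating log file.

```python
# Small usage sketch for build_logger; the name and directory are arbitrary examples.
from utils import build_logger

logger = build_logger("my_component", "log")   # writes log/my_component.log, rotated daily (UTC)
logger.info("service started")
print("plain prints are also captured")        # stdout is redirected to the 'stdout' logger
```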
images/cat.jpg
ADDED
Git LFS Details
|
images/demo_example1.jpg
ADDED
Git LFS Details
|
images/demo_example2.jpg
ADDED
Git LFS Details
|
images/demo_example3.jpg
ADDED
Git LFS Details
|
images/demo_example4.jpg
ADDED
Git LFS Details
|
images/demo_example5.jpg
ADDED
Git LFS Details
|
images/demo_example6.jpg
ADDED
Git LFS Details
|
images/demo_example7.jpg
ADDED
Git LFS Details
|
images/dogs_4.jpg
ADDED
Git LFS Details
|
images/eagle.jpg
ADDED
Git LFS Details
|
images/flower.png
ADDED
Git LFS Details
|
images/spongebob.png
ADDED
Git LFS Details
|
images/star.jpg
ADDED
Git LFS Details
|
models/__init__.py
ADDED
File without changes
|
models/llama_xformer.py
ADDED
@@ -0,0 +1,906 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
""" PyTorch LLaMA model."""
|
21 |
+
from typing import List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
+
|
28 |
+
from transformers.activations import ACT2FN
|
29 |
+
from transformers.modeling_outputs import (
|
30 |
+
BaseModelOutputWithPast,
|
31 |
+
CausalLMOutputWithPast,
|
32 |
+
SequenceClassifierOutputWithPast,
|
33 |
+
)
|
34 |
+
from transformers.modeling_utils import PreTrainedModel
|
35 |
+
from transformers.utils import (
|
36 |
+
add_start_docstrings,
|
37 |
+
add_start_docstrings_to_model_forward,
|
38 |
+
logging,
|
39 |
+
replace_return_docstrings,
|
40 |
+
)
|
41 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
42 |
+
import xformers.ops as xops
|
43 |
+
|
44 |
+
logger = logging.get_logger(__name__)
|
45 |
+
|
46 |
+
_CONFIG_FOR_DOC = "LlamaConfig"
|
47 |
+
|
48 |
+
|
49 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
50 |
+
def _make_causal_mask(
|
51 |
+
input_ids_shape: torch.Size,
|
52 |
+
dtype: torch.dtype,
|
53 |
+
device: torch.device,
|
54 |
+
past_key_values_length: int = 0,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
Make causal mask used for bi-directional self-attention.
|
58 |
+
"""
|
59 |
+
bsz, tgt_len = input_ids_shape
|
60 |
+
mask = torch.full(
|
61 |
+
(tgt_len, tgt_len),
|
62 |
+
torch.tensor(torch.finfo(dtype).min, device=device),
|
63 |
+
device=device,
|
64 |
+
)
|
65 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
66 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
67 |
+
mask = mask.to(dtype)
|
68 |
+
|
69 |
+
if past_key_values_length > 0:
|
70 |
+
mask = torch.cat(
|
71 |
+
[
|
72 |
+
torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
|
73 |
+
mask,
|
74 |
+
],
|
75 |
+
dim=-1,
|
76 |
+
)
|
77 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
78 |
+
|
79 |
+
|
80 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
81 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
82 |
+
"""
|
83 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
84 |
+
"""
|
85 |
+
bsz, src_len = mask.size()
|
86 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
87 |
+
|
88 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
89 |
+
|
90 |
+
inverted_mask = 1.0 - expanded_mask
|
91 |
+
|
92 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
93 |
+
|
94 |
+
|
95 |
+
class LlamaRMSNorm(nn.Module):
|
96 |
+
|
97 |
+
def __init__(self, hidden_size, eps=1e-6):
|
98 |
+
"""
|
99 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
100 |
+
"""
|
101 |
+
super().__init__()
|
102 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
103 |
+
self.variance_epsilon = eps
|
104 |
+
|
105 |
+
def forward(self, hidden_states):
|
106 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
107 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
108 |
+
|
109 |
+
# convert into half-precision if necessary
|
110 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
111 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
112 |
+
|
113 |
+
return self.weight * hidden_states
|
114 |
+
|
115 |
+
|
116 |
+
class LlamaRotaryEmbedding(torch.nn.Module):
|
117 |
+
|
118 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
119 |
+
super().__init__()
|
120 |
+
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2).float().to(device) / dim))
|
121 |
+
self.register_buffer("inv_freq", inv_freq)
|
122 |
+
|
123 |
+
# Build here to make `torch.jit.trace` work.
|
124 |
+
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class LlamaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {self.num_heads}).")
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        if self.training:
            attn_output = xops.memory_efficient_attention(
                query_states,
                key_states,
                value_states,
                attn_bias=xops.LowerTriangularMask(),
            )
        else:
            attn_output = xops.memory_efficient_attention(
                query_states,
                key_states,
                value_states,
                attn_bias=None if attention_mask.sum() == 0 else xops.LowerTriangularMask(),
            )
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
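
# NOTE: xformers' memory_efficient_attention expects query/key/value laid out as
# [batch, seq_len, num_heads, head_dim], which is why q/k/v are transposed back with
# transpose(1, 2) right before the call above. Causal masking is supplied through
# xops.LowerTriangularMask(): always during training, and at inference only when the
# prepared additive attention mask is non-zero (e.g. it sums to zero for single-token
# decoding with a cache). Attention weights are never materialized on this path, so
# `attn_weights` is only defined when `output_attentions` is False.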


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, )

        if output_attentions:
            outputs += (self_attn_weights, )

        if use_cache:
            outputs += (present_key_value, )

        return outputs


LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value


LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
                                              tgt_len=input_shape[-1]).to(inputs_embeds.device)
            combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1], )

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class LlamaForCausalLM(LlamaPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you consciours? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        })
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past), )
        return reordered_past


@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits, ) + transformer_outputs[1:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
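For reference, here is a minimal shape-check of the rotary position embedding helpers defined above. It is a sketch, not part of the Space: it assumes the repository root is on `PYTHONPATH` and that `transformers` and `xformers` are installed (the module imports them at load time), but it needs no GPU and no pretrained weights.

```python
import torch

from models.llama_xformer import LlamaRotaryEmbedding, apply_rotary_pos_emb

# Illustrative sizes only.
bsz, n_heads, seq_len, head_dim = 2, 4, 16, 32
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)

rotary = LlamaRotaryEmbedding(head_dim, max_position_embeddings=2048)
cos, sin = rotary(q, seq_len=seq_len)  # each [1, 1, seq_len, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)

# Rotation preserves the [batch, heads, seq, head_dim] layout of q and k.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```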
models/model_tools.py
ADDED
@@ -0,0 +1,18 @@
1	import torch
2	from .llama_xformer import LlamaForCausalLM
3	
4	
5	def get_pretrained_llama_causal_model(pretrained_model_name_or_path=None, torch_dtype='fp16', **kwargs):
6	    if torch_dtype == 'fp16' or torch_dtype == 'float16':
7	        torch_dtype = torch.float16
8	    elif torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
9	        torch_dtype = torch.bfloat16
10	    else:
11	        # fall back to full precision
12	        torch_dtype = torch.float32
13	    model = LlamaForCausalLM.from_pretrained(
14	        pretrained_model_name_or_path=pretrained_model_name_or_path,
15	        torch_dtype=torch_dtype,
16	        **kwargs,
17	    )
18	
19	    return model
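As a usage sketch (the checkpoint path below is a placeholder, not a path shipped with this Space), the helper simply maps a dtype string to a `torch.dtype` and forwards everything else to `from_pretrained`:

```python
from models.model_tools import get_pretrained_llama_causal_model

# Placeholder checkpoint path; point this at the SEED-LLaMA weights you actually have.
model = get_pretrained_llama_causal_model(
    pretrained_model_name_or_path="pretrained/seed-llama-8b",
    torch_dtype="bf16",
    low_cpu_mem_usage=True,  # extra kwargs are passed through to from_pretrained
)
model.eval()
```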
models/pipeline_stable_unclip_img2img.py
ADDED
@@ -0,0 +1,794 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers.utils.import_utils import is_accelerate_available

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from io import BytesIO

        >>> from diffusers import StableUnCLIPImg2ImgPipeline

        >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
        ...     "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16
        ... )  # TODO update model path
        >>> pipe = pipe.to("cuda")

        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

        >>> response = requests.get(url)
        >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
        >>> init_image = init_image.resize((768, 512))

        >>> prompt = "A fantasy landscape, trending on artstation"

        >>> images = pipe(prompt, init_image).images
        >>> images[0].save("fantasy_landscape.png")
        ```
"""


class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    """
    Pipeline for text-guided image-to-image generation using stable unCLIP.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        feature_extractor ([`CLIPImageProcessor`]):
            Feature extractor for image pre-processing before being encoded.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            CLIP vision model for encoding images.
        image_normalizer ([`StableUnCLIPImageNormalizer`]):
            Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
            embeddings after the noise has been applied.
        image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
            Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
            by the `noise_level`.
        tokenizer (`~transformers.CLIPTokenizer`):
            A [`~transformers.CLIPTokenizer`)].
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen [`~transformers.CLIPTextModel`] text-encoder.
        unet ([`UNet2DConditionModel`]):
            A [`UNet2DConditionModel`] to denoise the encoded image latents.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
    """

    _exclude_from_cpu_offload = ["image_normalizer"]

    # image encoding components
    feature_extractor: CLIPImageProcessor
    image_encoder: CLIPVisionModelWithProjection

    # image noising components
    image_normalizer: StableUnCLIPImageNormalizer
    image_noising_scheduler: KarrasDiffusionSchedulers

    # regular denoising components
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModel
    unet: UNet2DConditionModel
    scheduler: KarrasDiffusionSchedulers

    vae: AutoencoderKL

    def __init__(
        self,
        # image encoding components
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection,
        # image noising components
        image_normalizer: StableUnCLIPImageNormalizer,
        image_noising_scheduler: KarrasDiffusionSchedulers,
        # regular denoising components
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        # vae
        vae: AutoencoderKL,
    ):
        super().__init__()

        self.register_modules(
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            image_normalizer=image_normalizer,
            image_noising_scheduler=image_noising_scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            vae=vae,
        )

        self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a
        time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.
        Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the
        iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: procecss multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1:-1])
                logger.warning("The following part of your input was truncated because CLIP can only handle sequences up to"
                               f" {self.tokenizer.model_max_length} tokens: {removed_text}")

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                                f" {type(prompt)}.")
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`.")
            else:
                uncond_tokens = negative_prompt

            # textual inversion: procecss multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def _encode_image(
        self,
        image,
        device,
        batch_size,
        num_images_per_prompt,
        do_classifier_free_guidance,
        noise_level,
        generator,
        image_embeds,
        negative_image_embeds,
    ):
        dtype = next(self.image_encoder.parameters()).dtype

        if isinstance(image, PIL.Image.Image):
            # the image embedding should repeated so it matches the total batch size of the prompt
            repeat_by = batch_size
        else:
            # assume the image input is already properly batched and just needs to be repeated so
            # it matches the num_images_per_prompt.
            #
            # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched
            # `image_embeds`. If those happen to be common use cases, let's think harder about
            # what the expected dimensions of inputs should be and how we handle the encoding.
            repeat_by = num_images_per_prompt

        if image_embeds is None:
            if not isinstance(image, torch.Tensor):
                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

            image = image.to(device=device, dtype=dtype)
            image_embeds = self.image_encoder(image).image_embeds

        image_embeds = self.noise_image_embeddings(
            image_embeds=image_embeds,
            noise_level=noise_level,
            generator=generator,
        )

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        image_embeds = image_embeds.unsqueeze(1)
        bs_embed, seq_len, _ = image_embeds.shape
        image_embeds = image_embeds.repeat(1, repeat_by, 1)
        image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1)
        image_embeds = image_embeds.squeeze(1)

        if negative_image_embeds is not None:
            negative_image_embeds = self.noise_image_embeddings(
                image_embeds=negative_image_embeds,
                noise_level=0,
                generator=generator,
            )
            # duplicate negative image embeddings for each generation per prompt, using mps friendly method
            negative_image_embeds = negative_image_embeds.unsqueeze(1)
            bs_embed, seq_len, _ = negative_image_embeds.shape
            negative_image_embeds = negative_image_embeds.repeat(1, repeat_by, 1)
            negative_image_embeds = negative_image_embeds.view(bs_embed * repeat_by, seq_len, -1)
            negative_image_embeds = negative_image_embeds.squeeze(1)

        if do_classifier_free_guidance:
            if negative_image_embeds is None:
                negative_image_embeds = torch.zeros_like(image_embeds)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeds = torch.cat([negative_image_embeds, image_embeds])

        return image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        image,
        height,
        width,
        callback_steps,
        noise_level,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        image_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (callback_steps is not None and
                                        (not isinstance(callback_steps, int) or callback_steps <= 0)):
            raise ValueError(f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                             f" {type(callback_steps)}.")

        if prompt is not None and prompt_embeds is not None:
            raise ValueError("Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two.")

        if prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.")

        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined."
            )

        if prompt is not None and negative_prompt is not None:
            if type(prompt) is not type(negative_prompt):
                raise TypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                                f" {type(prompt)}.")

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}.")

        if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
            raise ValueError(
                f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
            )

        if image is not None and image_embeds is not None:
            raise ValueError("Provide either `image` or `image_embeds`. Please make sure to define only one of the two.")
|
499 |
+
|
500 |
+
if image is None and image_embeds is None:
|
501 |
+
raise ValueError(
|
502 |
+
"Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined.")
|
503 |
+
|
504 |
+
if image is not None:
|
505 |
+
if (not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list)):
|
506 |
+
raise ValueError(
|
507 |
+
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
508 |
+
f" {type(image)}")
|
509 |
+
|
510 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
511 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
512 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
513 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
514 |
+
raise ValueError(
|
515 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
516 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators.")
|
517 |
+
|
518 |
+
if latents is None:
|
519 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
520 |
+
else:
|
521 |
+
latents = latents.to(device)
|
522 |
+
|
523 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
524 |
+
latents = latents * self.scheduler.init_noise_sigma
|
525 |
+
return latents
|
526 |
+
|
527 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
|
528 |
+
def noise_image_embeddings(
|
529 |
+
self,
|
530 |
+
image_embeds: torch.Tensor,
|
531 |
+
noise_level: int,
|
532 |
+
noise: Optional[torch.FloatTensor] = None,
|
533 |
+
generator: Optional[torch.Generator] = None,
|
534 |
+
):
|
535 |
+
"""
|
536 |
+
Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
|
537 |
+
`noise_level` increases the variance in the final un-noised images.
|
538 |
+
|
539 |
+
The noise is applied in two ways:
|
540 |
+
1. A noise schedule is applied directly to the embeddings.
|
541 |
+
2. A vector of sinusoidal time embeddings are appended to the output.
|
542 |
+
|
543 |
+
In both cases, the amount of noise is controlled by the same `noise_level`.
|
544 |
+
|
545 |
+
The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
|
546 |
+
"""
|
547 |
+
if noise is None:
|
548 |
+
noise = randn_tensor(image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype)
|
549 |
+
|
550 |
+
noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
|
551 |
+
|
552 |
+
self.image_normalizer.to(image_embeds.device)
|
553 |
+
image_embeds = self.image_normalizer.scale(image_embeds)
|
554 |
+
|
555 |
+
image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
|
556 |
+
|
557 |
+
image_embeds = self.image_normalizer.unscale(image_embeds)
|
558 |
+
|
559 |
+
noise_level = get_timestep_embedding(timesteps=noise_level,
|
560 |
+
embedding_dim=image_embeds.shape[-1],
|
561 |
+
flip_sin_to_cos=True,
|
562 |
+
downscale_freq_shift=0)
|
563 |
+
|
564 |
+
# `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
|
565 |
+
# but we might actually be running in fp16. so we need to cast here.
|
566 |
+
# there might be better ways to encapsulate this.
|
567 |
+
noise_level = noise_level.to(image_embeds.dtype)
|
568 |
+
|
569 |
+
image_embeds = torch.cat((image_embeds, noise_level), 1)
|
570 |
+
|
571 |
+
return image_embeds
|
572 |
+
|
573 |
+
@torch.no_grad()
|
574 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
575 |
+
def __call__(
|
576 |
+
self,
|
577 |
+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
578 |
+
prompt: Union[str, List[str]] = None,
|
579 |
+
height: Optional[int] = None,
|
580 |
+
width: Optional[int] = None,
|
581 |
+
num_inference_steps: int = 20,
|
582 |
+
guidance_scale: float = 10,
|
583 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
584 |
+
num_images_per_prompt: Optional[int] = 1,
|
585 |
+
eta: float = 0.0,
|
586 |
+
generator: Optional[torch.Generator] = None,
|
587 |
+
latents: Optional[torch.FloatTensor] = None,
|
588 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
589 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
590 |
+
output_type: Optional[str] = "pil",
|
591 |
+
return_dict: bool = True,
|
592 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
593 |
+
callback_steps: int = 1,
|
594 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
595 |
+
noise_level: int = 0,
|
596 |
+
image_embeds: Optional[torch.FloatTensor] = None,
|
597 |
+
negative_image_embeds: Optional[torch.FloatTensor] = None,
|
598 |
+
):
|
599 |
+
r"""
|
600 |
+
The call function to the pipeline for generation.
|
601 |
+
|
602 |
+
Args:
|
603 |
+
prompt (`str` or `List[str]`, *optional*):
|
604 |
+
The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
|
605 |
+
used or prompt is initialized to `""`.
|
606 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
607 |
+
`Image` or tensor representing an image batch. The image is encoded to its CLIP embedding which the
|
608 |
+
`unet` is conditioned on. The image is _not_ encoded by the `vae` and then used as the latents in the
|
609 |
+
denoising process like it is in the standard Stable Diffusion text-guided image variation process.
|
610 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
611 |
+
The height in pixels of the generated image.
|
612 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
613 |
+
The width in pixels of the generated image.
|
614 |
+
num_inference_steps (`int`, *optional*, defaults to 20):
|
615 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
616 |
+
expense of slower inference.
|
617 |
+
guidance_scale (`float`, *optional*, defaults to 10.0):
|
618 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
619 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
620 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
621 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
622 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
623 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
624 |
+
The number of images to generate per prompt.
|
625 |
+
eta (`float`, *optional*, defaults to 0.0):
|
626 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
627 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
628 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
629 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
630 |
+
generation deterministic.
|
631 |
+
latents (`torch.FloatTensor`, *optional*):
|
632 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
633 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
634 |
+
tensor is generated by sampling using the supplied random `generator`.
|
635 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
636 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
637 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
638 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
639 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
640 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
641 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
642 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
643 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
644 |
+
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
645 |
+
callback (`Callable`, *optional*):
|
646 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
647 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
648 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
649 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
650 |
+
every step.
|
651 |
+
cross_attention_kwargs (`dict`, *optional*):
|
652 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
653 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
654 |
+
noise_level (`int`, *optional*, defaults to `0`):
|
655 |
+
The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
|
656 |
+
the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details.
|
657 |
+
image_embeds (`torch.FloatTensor`, *optional*):
|
658 |
+
Pre-generated CLIP embeddings to condition the `unet` on. These latents are not used in the denoising
|
659 |
+
process. If you want to provide pre-generated latents, pass them to `__call__` as `latents`.
|
660 |
+
|
661 |
+
Examples:
|
662 |
+
|
663 |
+
Returns:
|
664 |
+
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
665 |
+
[`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning
|
666 |
+
a tuple, the first element is a list with the generated images.
|
667 |
+
"""
|
668 |
+
# 0. Default height and width to unet
|
669 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
670 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
671 |
+
|
672 |
+
if prompt is None and prompt_embeds is None:
|
673 |
+
prompt = len(image) * [""] if isinstance(image, list) else ""
|
674 |
+
|
675 |
+
# 1. Check inputs. Raise error if not correct
|
676 |
+
self.check_inputs(
|
677 |
+
prompt=prompt,
|
678 |
+
image=image,
|
679 |
+
height=height,
|
680 |
+
width=width,
|
681 |
+
callback_steps=callback_steps,
|
682 |
+
noise_level=noise_level,
|
683 |
+
negative_prompt=negative_prompt,
|
684 |
+
prompt_embeds=prompt_embeds,
|
685 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
686 |
+
image_embeds=image_embeds,
|
687 |
+
)
|
688 |
+
|
689 |
+
# 2. Define call parameters
|
690 |
+
if prompt is not None and isinstance(prompt, str):
|
691 |
+
batch_size = 1
|
692 |
+
elif prompt is not None and isinstance(prompt, list):
|
693 |
+
batch_size = len(prompt)
|
694 |
+
else:
|
695 |
+
batch_size = prompt_embeds.shape[0]
|
696 |
+
|
697 |
+
batch_size = batch_size * num_images_per_prompt
|
698 |
+
|
699 |
+
device = self._execution_device
|
700 |
+
|
701 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
702 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
703 |
+
# corresponds to doing no classifier free guidance.
|
704 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
705 |
+
|
706 |
+
# 3. Encode input prompt
|
707 |
+
text_encoder_lora_scale = (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)
|
708 |
+
prompt_embeds = self._encode_prompt(
|
709 |
+
prompt=prompt,
|
710 |
+
device=device,
|
711 |
+
num_images_per_prompt=num_images_per_prompt,
|
712 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
713 |
+
negative_prompt=negative_prompt,
|
714 |
+
prompt_embeds=prompt_embeds,
|
715 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
716 |
+
lora_scale=text_encoder_lora_scale,
|
717 |
+
)
|
718 |
+
|
719 |
+
# 4. Encoder input image
|
720 |
+
noise_level = torch.tensor([noise_level], device=device)
|
721 |
+
image_embeds = self._encode_image(
|
722 |
+
image=image,
|
723 |
+
device=device,
|
724 |
+
batch_size=batch_size,
|
725 |
+
num_images_per_prompt=num_images_per_prompt,
|
726 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
727 |
+
noise_level=noise_level,
|
728 |
+
generator=generator,
|
729 |
+
image_embeds=image_embeds,
|
730 |
+
negative_image_embeds=negative_image_embeds,
|
731 |
+
)
|
732 |
+
|
733 |
+
# 5. Prepare timesteps
|
734 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
735 |
+
timesteps = self.scheduler.timesteps
|
736 |
+
|
737 |
+
# 6. Prepare latent variables
|
738 |
+
num_channels_latents = self.unet.config.in_channels
|
739 |
+
latents = self.prepare_latents(
|
740 |
+
batch_size=batch_size,
|
741 |
+
num_channels_latents=num_channels_latents,
|
742 |
+
height=height,
|
743 |
+
width=width,
|
744 |
+
dtype=prompt_embeds.dtype,
|
745 |
+
device=device,
|
746 |
+
generator=generator,
|
747 |
+
latents=latents,
|
748 |
+
)
|
749 |
+
|
750 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
751 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
752 |
+
|
753 |
+
# 8. Denoising loop
|
754 |
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
755 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
756 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
757 |
+
|
758 |
+
# predict the noise residual
|
759 |
+
noise_pred = self.unet(
|
760 |
+
latent_model_input,
|
761 |
+
t,
|
762 |
+
encoder_hidden_states=prompt_embeds,
|
763 |
+
class_labels=image_embeds,
|
764 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
765 |
+
return_dict=False,
|
766 |
+
)[0]
|
767 |
+
|
768 |
+
# perform guidance
|
769 |
+
if do_classifier_free_guidance:
|
770 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
771 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
772 |
+
|
773 |
+
# compute the previous noisy sample x_t -> x_t-1
|
774 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
775 |
+
|
776 |
+
if callback is not None and i % callback_steps == 0:
|
777 |
+
callback(i, t, latents)
|
778 |
+
|
779 |
+
# 9. Post-processing
|
780 |
+
if not output_type == "latent":
|
781 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
782 |
+
else:
|
783 |
+
image = latents
|
784 |
+
|
785 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
786 |
+
|
787 |
+
# Offload last model to CPU
|
788 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
789 |
+
self.final_offload_hook.offload()
|
790 |
+
|
791 |
+
if not return_dict:
|
792 |
+
return (image, )
|
793 |
+
|
794 |
+
return ImagePipelineOutput(images=image)
|
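The pipeline above extends the unCLIP img2img call with `image_embeds` and `negative_image_embeds` conditioning. As a minimal sketch of how that entry point is typically driven (the checkpoint id only appears in a comment later in this diff, and the 1024-dim embedding shape is an assumption, not something this diff pins down):

import torch
from models.pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline

# Assumed checkpoint: any unCLIP-style Stable Diffusion model with an image_normalizer.
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16
).to("cuda")

# Condition directly on precomputed CLIP image embeddings instead of a PIL image.
image_embeds = torch.randn(1, 1024, device="cuda", dtype=torch.float16)  # assumed shape
negative_image_embeds = torch.zeros_like(image_embeds)                   # optional negative condition

images = pipe(
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
    guidance_scale=10,   # classifier-free guidance is active when > 1
    noise_level=0,       # extra noise is injected via noise_image_embeddings
    num_inference_steps=20,
).images

This mirrors how ImageTokenizer.decode in the next file calls the pipeline.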
models/seed_llama_tokenizer.py
ADDED
@@ -0,0 +1,213 @@
+import torch.nn as nn
+import torch
+# import math
+# from torchvision import transforms
+import os
+# from timm.models import create_model
+from typing import Any, Dict, List, Optional, Union
+from transformers import LlamaTokenizer
+from diffusers import DiffusionPipeline
+# from torchvision.transforms.functional import pil_to_tensor
+
+# import torch
+from PIL import Image
+from torchvision import transforms
+
+# from qformer.qformer_quantizer import Blip2QformerQuantizer
+# from diffusers import StableUnCLIPImg2ImgPipeline
+from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
+
+WEIGHTS_NAME = 'seed_quantizer.pt'
+DIFFUSION_NAME = 'diffusion_model'
+
+
+class ImageTokenizer(nn.Module):
+    def __init__(self,
+                 model_path,
+                 diffusion_model_path=None,
+                 load_diffusion=False,
+                 image_size=224,
+                 device='cuda',
+                 fp16=True,
+                 **kwargs):
+        super().__init__()
+        from .seed_qformer.qformer_quantizer import Blip2QformerQuantizer
+
+        model = Blip2QformerQuantizer.from_pretrained(pretrained_model_path=model_path,
+                                                      vit_precision='fp16' if fp16 else 'fp32',
+                                                      **kwargs).eval()
+        if diffusion_model_path is not None and load_diffusion:
+            # diffusion_model = DiffusionPipeline.from_pretrained(diffusion_model_path,
+            #                                                     torch_dtype=torch.float16 if fp16 else torch.float32)
+            diffusion_model = StableUnCLIPImg2ImgPipeline.from_pretrained(diffusion_model_path,
+                                                                          torch_dtype=torch.float16 if fp16 else torch.float32)
+            self.diffusion_model = diffusion_model.to(device)
+        else:
+            self.diffusion_model = None
+
+        model = model.to(device)
+
+        processor = transforms.Compose([
+            transforms.Resize((image_size, image_size), interpolation=3),
+            # transforms.Resize(image_size, interpolation=3),
+            # transforms.CenterCrop(image_size),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
+        ])
+
+        if fp16:
+            model = model.half()
+
+        shape_latents = torch.Size([1, 4, 96, 96])
+        self.latents = torch.randn(shape_latents, generator=None, device=device, dtype=torch.float16, layout=torch.strided)
+
+        shape_noise = torch.Size([1, 1024])
+        self.noise = torch.randn(shape_noise, generator=None, device=device, dtype=torch.float16, layout=torch.strided)
+
+        self.model = model
+        self.processor = processor
+        self.device = device
+        self.fp16 = fp16
+
+    def __len__(self):
+        return self.model.n_embed
+
+    def encode(self, image_torch):
+        '''Convert a batch of img to code
+        Args:
+            model: The tokenizer model.
+            img: [b, c, h, w]
+        '''
+        if len(image_torch.shape) == 3:
+            image_torch = image_torch.unsqueeze(0)
+
+        # img = image_torch.to(self.device)
+        img = image_torch
+        if self.fp16:
+            img = img.half()
+        with torch.no_grad():
+            id, _ = self.model.get_codebook_indices(img)
+        return id.view(img.shape[0], -1)
+
+    def decode(self, indices, negative_indices=None, guidance_scale=10, num_inference_steps=20):
+        image_embeds = self.model.get_codebook_entry(indices)
+        # image = self.diffusion_model(image_embeds=image_embed,
+        #                              noise_level=0,
+        #                              num_inference_steps=20,
+        #                              latents=self.latents,
+        #                              noise=self.noise).images
+        if negative_indices is not None:
+            assert indices.shape == negative_indices.shape, 'Negative indices must have the same shape with indices'
+            negative_image_embeds = self.model.get_codebook_entry(negative_indices)
+        else:
+            negative_image_embeds = None
+
+        image = self.diffusion_model(
+            image_embeds=image_embeds,
+            negative_image_embeds=negative_image_embeds,
+            guidance_scale=guidance_scale,
+            noise_level=0,
+            num_inference_steps=num_inference_steps,
+            latents=self.latents,
+        ).images
+        return image
+
+
+class SeedLlamaTokenizer(LlamaTokenizer):
+    def __init__(self,
+                 vocab_file,
+                 unk_token="<unk>",
+                 bos_token="<s>",
+                 eos_token="</s>",
+                 pad_token=None,
+                 sp_model_kwargs: Optional[Dict[str, Any]] = None,
+                 add_bos_token=True,
+                 add_eos_token=False,
+                 clean_up_tokenization_spaces=False,
+                 device='cuda',
+                 fp16=True,
+                 load_diffusion=False,
+                 encoder_url=None,
+                 diffusion_path=None,
+                 **kwargs):
+        super().__init__(vocab_file, unk_token, bos_token, eos_token, pad_token, sp_model_kwargs, add_bos_token, add_eos_token,
+                         clean_up_tokenization_spaces, **kwargs)
+        self.device = device
+        self.fp16 = fp16
+        self.pad_token = self.unk_token
+        self.load_diffusion = load_diffusion
+        self.encoder_url = encoder_url
+        self.diffusion_path = diffusion_path
+
+        self.load_image_tokenizer()
+
+    def load_image_tokenizer(self):
+        if not hasattr(self, '_image_tokenizer'):
+            if self.encoder_url is not None:
+                model_path = self.encoder_url
+            else:
+                assert hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path)
+                model_path = os.path.join(self.name_or_path, WEIGHTS_NAME)
+            # diffusion_model_path = os.path.join(self.name_or_path, DIFFUSION_NAME)
+            # diffusion_model_path = 'stabilityai/stable-diffusion-2-1-unclip'
+            self._image_tokenizer = ImageTokenizer(model_path=model_path,
+                                                   diffusion_model_path=self.diffusion_path,
+                                                   load_diffusion=self.load_diffusion,
+                                                   device=self.device,
+                                                   fp16=self.fp16)
+
+    @property
+    def image_tokenizer(self):
+        if not hasattr(self, '_image_tokenizer'):
+            if self.encoder_url is not None:
+                model_path = self.encoder_url
+            else:
+                assert hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path)
+                model_path = os.path.join(self.name_or_path, WEIGHTS_NAME)
+            # diffusion_model_path = os.path.join(self.name_or_path, DIFFUSION_NAME)
+            # diffusion_model_path = 'stabilityai/stable-diffusion-2-1-unclip'
+            self._image_tokenizer = ImageTokenizer(model_path=model_path,
+                                                   diffusion_model_path=self.diffusion_path,
+                                                   load_diffusion=self.load_diffusion,
+                                                   device=self.device,
+                                                   fp16=self.fp16)
+        return self._image_tokenizer
+
+    @property
+    def num_image_tokens(self):
+        return 8192  # self.image_tokenizer.num_tokens # allow not load
+
+    def to(self, device):
+        self.device = device
+        if hasattr(self, '_image_tokenizer'):
+            self._image_tokenizer.to(device=device)
+
+    def encode_image(
+        self,
+        image_path=None,
+        image_pil=None,
+        image_torch=None,
+        image_size: int = 224,
+    ):
+        assert (image_path is None) + (image_pil is None) + (image_torch is None) == 2
+
+        # need_norm_to_1 = False
+        if image_path is not None:
+            image_pil = Image.open(image_path).convert('RGB')
+
+        if image_pil is not None:
+            image_torch = self.image_tokenizer.processor(image_pil)
+
+        image_torch = image_torch.to(self.device)
+        return self.image_tokenizer.encode(image_torch)
+
+    def decode_image(self, indices, negative_indices=None, guidance_scale=10):
+        indices = indices.to(self.device)
+        if negative_indices is not None:
+            negative_indices = negative_indices.to(self.device)
+        image = self.image_tokenizer.decode(
+            indices,
+            negative_indices=negative_indices,
+            guidance_scale=guidance_scale,
+        )
+        return image
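A rough round-trip sketch of the tokenizer added above, assuming the weights are loaded via from_pretrained (the hub id, the example image path, and load_diffusion flags below are placeholders, not part of this diff):

from models.seed_llama_tokenizer import SeedLlamaTokenizer

# Hypothetical source for vocab, seed_quantizer.pt, and the unCLIP diffusion weights.
tokenizer = SeedLlamaTokenizer.from_pretrained(
    "AILab-CVC/seed-tokenizer-2",      # assumed repo id
    encoder_url=None,
    diffusion_path="stabilityai/stable-diffusion-2-1-unclip",
    load_diffusion=True,
    fp16=True,
    device="cuda",
)

image_ids = tokenizer.encode_image(image_path="example.jpg")  # discrete codebook indices, shape (1, num_codes)
recon = tokenizer.decode_image(image_ids)                      # list of PIL images from the diffusion decoder

encode_image accepts exactly one of image_path, image_pil, or image_torch; decode_image routes through ImageTokenizer.decode and therefore requires load_diffusion=True.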
models/seed_qformer/blip2.py
ADDED
@@ -0,0 +1,186 @@
+"""
+Copyright (c) 2023, salesforce.com, inc.
+All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
+For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+import contextlib
+import logging
+import os
+import time
+import datetime
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import torch.nn.functional as F
+
+
+from .qformer_causual import BertConfig, BertLMHeadModel
+
+from .utils import download_cached_file, get_rank, get_dist_info, get_world_size, main_process, is_dist_avail_and_initialized, is_url
+from .eva_vit import create_eva_vit_g
+from .clip_vit import create_clip_vit_L
+from transformers import BertTokenizer
+
+
+# class Blip2Base(BaseModel):
+class Blip2Base(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def device(self):
+        return list(self.parameters())[0].device
+
+    @classmethod
+    def init_tokenizer(cls, truncation_side="right"):
+        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side)
+        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+        return tokenizer
+
+    def maybe_autocast(self, dtype=torch.float16):
+        # if on cpu, don't use autocast
+        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+        enable_autocast = self.device != torch.device("cpu")
+
+        if enable_autocast:
+            return torch.cuda.amp.autocast(dtype=dtype)
+        else:
+            return contextlib.nullcontext()
+
+    @classmethod
+    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
+        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+        encoder_config.encoder_width = vision_width
+        # insert cross-attention layer every other block
+        encoder_config.add_cross_attention = True
+        encoder_config.cross_attention_freq = cross_attention_freq
+        encoder_config.query_length = num_query_token
+        Qformer = BertLMHeadModel.from_pretrained("bert-base-uncased", config=encoder_config)
+        query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
+        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+        return Qformer, query_tokens
+
+    def init_vision_encoder(self, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision):
+        assert model_name in [
+            "eva_clip_g",
+            "eva2_clip_L",
+            "clip_L",
+        ], "vit model must be eva_clip_g, eva2_clip_L or clip_L"
+        if model_name == "eva_clip_g":
+            visual_encoder = create_eva_vit_g(img_size, drop_path_rate, use_grad_checkpoint, precision)
+
+        elif model_name == "clip_L":
+            visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
+        ln_vision = LayerNorm(visual_encoder.num_features)
+        self.vit_name = model_name
+        return visual_encoder, ln_vision
+
+    def load_from_pretrained(self, url_or_filename):
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+            checkpoint = torch.load(cached_file, map_location="cpu")
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location="cpu")
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        state_dict = checkpoint["model"]
+
+        msg = self.load_state_dict(state_dict, strict=False)
+
+        # logging.info("Missing keys {}".format(msg.missing_keys))
+        logging.info("load checkpoint from %s" % url_or_filename)
+
+        return msg
+
+    def get_optimizer_params(self, weight_decay, lr_scale=1):
+        if self.vit_name == "eva_clip_g":
+            vit_num_layers = self.visual_encoder.get_num_layer()
+            lr_scales = list(lr_scale**(vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))
+
+            parameter_group_names = {}
+            parameter_group_vars = {}
+
+            for name, param in self.named_parameters():
+                if not param.requires_grad:
+                    continue  # frozen weights
+                if len(param.shape) == 1 or name.endswith(".bias"):
+                    group_name = "no_decay"
+                    this_weight_decay = 0.
+                else:
+                    group_name = "decay"
+                    this_weight_decay = weight_decay
+                if 'visual_encoder' in name:
+                    layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.', ''))
+                    group_name = "vit_layer_%d_%s" % (layer_id, group_name)
+                else:
+                    layer_id = None
+
+                if group_name not in parameter_group_names:
+                    if layer_id is not None:
+                        scale = lr_scales[layer_id]
+                    else:
+                        scale = 1
+                    parameter_group_names[group_name] = {"weight_decay": this_weight_decay, "params": [], "lr_scale": scale}
+                    parameter_group_vars[group_name] = {"weight_decay": this_weight_decay, "params": [], "lr_scale": scale}
+                parameter_group_vars[group_name]["params"].append(param)
+                parameter_group_names[group_name]["params"].append(name)
+            # import json
+            # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+            optim_params = list(parameter_group_vars.values())
+            return optim_params
+        else:
+            return super().get_optimizer_params(weight_decay, lr_scale)
+
+    def _lemmatize(self, answers):
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            except ImportError:
+                logging.error("""
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """)
+                exit(1)
+
+        return self._lemmatizer
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
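Blip2Base above is the helper the quantizer subclasses: init_Qformer builds a BERT-based Q-Former with a bank of learned query tokens, and maybe_autocast wraps GPU forward passes in fp16 autocast. A minimal sketch of how those helpers fit together (the 32 queries and 1408-dim vision width are illustrative numbers, not values fixed by this file; both init calls download bert-base-uncased on first use):

import torch
from models.seed_qformer.blip2 import Blip2Base

class QuantizerSkeleton(Blip2Base):
    # Illustrative sizes only: 32 learned queries over 1408-dim frozen ViT features.
    def __init__(self, num_query_token=32, vision_width=1408):
        super().__init__()
        self.tokenizer = self.init_tokenizer()
        self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, vision_width)

model = QuantizerSkeleton()
print(model.query_tokens.shape)   # (1, 32, 768) with the bert-base hidden size
with model.maybe_autocast():      # nullcontext on CPU, fp16 autocast on GPU
    pass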
models/seed_qformer/clip_vit.py
ADDED
@@ -0,0 +1,257 @@
1 |
+
from collections import OrderedDict
|
2 |
+
from itertools import repeat
|
3 |
+
import collections.abc
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
|
11 |
+
from .eva_vit import convert_weights_to_fp16
|
12 |
+
from .utils import download_cached_file
|
13 |
+
|
14 |
+
|
15 |
+
class Bottleneck(nn.Module):
|
16 |
+
expansion = 4
|
17 |
+
|
18 |
+
def __init__(self, inplanes, planes, stride=1):
|
19 |
+
super().__init__()
|
20 |
+
|
21 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
22 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
23 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
24 |
+
self.relu1 = nn.ReLU(inplace=True)
|
25 |
+
|
26 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
27 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
28 |
+
self.relu2 = nn.ReLU(inplace=True)
|
29 |
+
|
30 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
31 |
+
|
32 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
33 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
34 |
+
self.relu3 = nn.ReLU(inplace=True)
|
35 |
+
|
36 |
+
self.downsample = None
|
37 |
+
self.stride = stride
|
38 |
+
|
39 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
40 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
41 |
+
self.downsample = nn.Sequential(
|
42 |
+
OrderedDict([("-1", nn.AvgPool2d(stride)),
|
43 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
44 |
+
("1", nn.BatchNorm2d(planes * self.expansion))]))
|
45 |
+
|
46 |
+
def forward(self, x: torch.Tensor):
|
47 |
+
identity = x
|
48 |
+
|
49 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
50 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
51 |
+
out = self.avgpool(out)
|
52 |
+
out = self.bn3(self.conv3(out))
|
53 |
+
|
54 |
+
if self.downsample is not None:
|
55 |
+
identity = self.downsample(x)
|
56 |
+
|
57 |
+
out += identity
|
58 |
+
out = self.relu3(out)
|
59 |
+
return out
|
60 |
+
|
61 |
+
|
62 |
+
class AttentionPool2d(nn.Module):
|
63 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
64 |
+
super().__init__()
|
65 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
|
66 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
67 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
68 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
69 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
70 |
+
self.num_heads = num_heads
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
74 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
75 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
76 |
+
x, _ = F.multi_head_attention_forward(query=x,
|
77 |
+
key=x,
|
78 |
+
value=x,
|
79 |
+
embed_dim_to_check=x.shape[-1],
|
80 |
+
num_heads=self.num_heads,
|
81 |
+
q_proj_weight=self.q_proj.weight,
|
82 |
+
k_proj_weight=self.k_proj.weight,
|
83 |
+
v_proj_weight=self.v_proj.weight,
|
84 |
+
in_proj_weight=None,
|
85 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
86 |
+
bias_k=None,
|
87 |
+
bias_v=None,
|
88 |
+
add_zero_attn=False,
|
89 |
+
dropout_p=0,
|
90 |
+
out_proj_weight=self.c_proj.weight,
|
91 |
+
out_proj_bias=self.c_proj.bias,
|
92 |
+
use_separate_proj_weight=True,
|
93 |
+
training=self.training,
|
94 |
+
need_weights=False)
|
95 |
+
|
96 |
+
return x[0]
|
97 |
+
|
98 |
+
|
99 |
+
class LayerNorm(nn.LayerNorm):
|
100 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
101 |
+
def forward(self, x: torch.Tensor):
|
102 |
+
orig_type = x.dtype
|
103 |
+
ret = super().forward(x.type(torch.float32))
|
104 |
+
return ret.type(orig_type)
|
105 |
+
|
106 |
+
|
107 |
+
class QuickGELU(nn.Module):
|
108 |
+
def forward(self, x: torch.Tensor):
|
109 |
+
return x * torch.sigmoid(1.702 * x)
|
110 |
+
|
111 |
+
|
112 |
+
class ResidualAttentionBlock(nn.Module):
|
113 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
|
114 |
+
super().__init__()
|
115 |
+
|
116 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
117 |
+
self.ln_1 = LayerNorm(d_model)
|
118 |
+
self.mlp = nn.Sequential(
|
119 |
+
OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
|
120 |
+
("c_proj", nn.Linear(d_model * 4, d_model))]))
|
121 |
+
self.ln_2 = LayerNorm(d_model)
|
122 |
+
self.attn_mask = attn_mask
|
123 |
+
|
124 |
+
# if use_grad_checkpointing:
|
125 |
+
# self.attn = checkpoint_wrapper(self.attn)
|
126 |
+
# self.mlp = checkpoint_wrapper(self.mlp)
|
127 |
+
# raise NotImplementedError
|
128 |
+
|
129 |
+
def attention(self, x: torch.Tensor):
|
130 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
131 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
132 |
+
|
133 |
+
def forward(self, x: torch.Tensor):
|
134 |
+
x = x + self.attention(self.ln_1(x))
|
135 |
+
x = x + self.mlp(self.ln_2(x))
|
136 |
+
return x
|
137 |
+
|
138 |
+
|
139 |
+
class Transformer(nn.Module):
|
140 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
|
141 |
+
super().__init__()
|
142 |
+
self.width = width
|
143 |
+
self.layers = layers
|
144 |
+
self.resblocks = nn.Sequential(
|
145 |
+
*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i > 12) for i in range(layers)])
|
146 |
+
|
147 |
+
def forward(self, x: torch.Tensor):
|
148 |
+
return self.resblocks(x)
|
149 |
+
|
150 |
+
|
151 |
+
class VisionTransformer(nn.Module):
|
152 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int,
|
153 |
+
use_grad_checkpointing: bool):
|
154 |
+
super().__init__()
|
155 |
+
self.input_resolution = input_resolution
|
156 |
+
self.num_features = width
|
157 |
+
self.num_heads = heads
|
158 |
+
self.num_patches = (input_resolution // patch_size)**2
|
159 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
160 |
+
|
161 |
+
scale = width**-0.5
|
162 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
163 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
|
164 |
+
self.ln_pre = LayerNorm(width)
|
165 |
+
|
166 |
+
self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
|
167 |
+
|
168 |
+
# self.ln_final = LayerNorm(width)
|
169 |
+
|
170 |
+
def forward(self, x: torch.Tensor):
|
171 |
+
|
172 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
173 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
174 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
175 |
+
x = torch.cat(
|
176 |
+
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
|
177 |
+
dim=1) # shape = [*, grid ** 2 + 1, width]
|
178 |
+
x = x + self.positional_embedding.to(x.dtype)
|
179 |
+
x = self.ln_pre(x)
|
180 |
+
|
181 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
182 |
+
x = self.transformer(x)
|
183 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
184 |
+
|
185 |
+
# x = self.ln_final(x)
|
186 |
+
return x
|
187 |
+
|
188 |
+
|
189 |
+
# From PyTorch internals
|
190 |
+
def _ntuple(n):
|
191 |
+
def parse(x):
|
192 |
+
if isinstance(x, collections.abc.Iterable):
|
193 |
+
return x
|
194 |
+
return tuple(repeat(x, n))
|
195 |
+
|
196 |
+
return parse
|
197 |
+
|
198 |
+
|
199 |
+
to_2tuple = _ntuple(2)
|
200 |
+
|
201 |
+
|
202 |
+
def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1):
|
203 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
204 |
+
old_pos_embed = state_dict.get('positional_embedding', None)
|
205 |
+
|
206 |
+
grid_size = round((model.positional_embedding.shape[0] - 1)**0.5)
|
207 |
+
if old_pos_embed is None:
|
208 |
+
return
|
209 |
+
grid_size = to_2tuple(grid_size)
|
210 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
211 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
212 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
213 |
+
return
|
214 |
+
|
215 |
+
if extra_tokens:
|
216 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
217 |
+
else:
|
218 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
219 |
+
|
220 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
221 |
+
|
222 |
+
print('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
223 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
224 |
+
pos_emb_img = F.interpolate(
|
225 |
+
pos_emb_img,
|
226 |
+
size=grid_size,
|
227 |
+
mode=interpolation,
|
228 |
+
align_corners=True,
|
229 |
+
)
|
230 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
231 |
+
if pos_emb_tok is not None:
|
232 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
233 |
+
else:
|
234 |
+
new_pos_embed = pos_emb_img
|
235 |
+
state_dict['positional_embedding'] = new_pos_embed
|
236 |
+
|
237 |
+
|
238 |
+
def create_clip_vit_L(img_size=224, use_checkpoint=False, precision="fp16"):
|
239 |
+
model = VisionTransformer(
|
240 |
+
input_resolution=img_size,
|
241 |
+
patch_size=14,
|
242 |
+
width=1024,
|
243 |
+
layers=23,
|
244 |
+
heads=16,
|
245 |
+
use_grad_checkpointing=use_checkpoint,
|
246 |
+
)
|
247 |
+
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth"
|
248 |
+
cached_file = download_cached_file(url, check_hash=False, progress=True)
|
249 |
+
state_dict = torch.load(cached_file, map_location="cpu")
|
250 |
+
interpolate_pos_embed(model, state_dict)
|
251 |
+
|
252 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
253 |
+
# print(incompatible_keys)
|
254 |
+
|
255 |
+
if precision == "fp16":
|
256 |
+
convert_weights_to_fp16(model)
|
257 |
+
return model
|
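A quick way to exercise the create_clip_vit_L factory added above (a sketch; the first call downloads the LAVIS clip_vit_L checkpoint, and fp32 precision is chosen here only so a random float32 tensor can be fed straight in):

import torch
from models.seed_qformer.clip_vit import create_clip_vit_L

vit = create_clip_vit_L(img_size=224, precision="fp32").eval()

with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))
print(feats.shape)   # (1, 257, 1024): 16x16 patches of size 14 plus the class token, width 1024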
models/seed_qformer/eva_vit.py
ADDED
@@ -0,0 +1,486 @@
1 |
+
# Based on EVA, BEIT, timm and DeiT code bases
|
2 |
+
# https://github.com/baaivision/EVA
|
3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
4 |
+
# https://github.com/microsoft/unilm/tree/master/beit
|
5 |
+
# https://github.com/facebookresearch/deit/
|
6 |
+
# https://github.com/facebookresearch/dino
|
7 |
+
# --------------------------------------------------------'
|
8 |
+
import math
|
9 |
+
from functools import partial
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torch.utils.checkpoint as checkpoint
|
15 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
16 |
+
|
17 |
+
|
18 |
+
from .utils import download_cached_file
|
19 |
+
|
20 |
+
|
21 |
+
def _cfg(url='', **kwargs):
|
22 |
+
return {
|
23 |
+
'url': url,
|
24 |
+
'num_classes': 1000,
|
25 |
+
'input_size': (3, 224, 224),
|
26 |
+
'pool_size': None,
|
27 |
+
'crop_pct': .9,
|
28 |
+
'interpolation': 'bicubic',
|
29 |
+
'mean': (0.5, 0.5, 0.5),
|
30 |
+
'std': (0.5, 0.5, 0.5),
|
31 |
+
**kwargs
|
32 |
+
}
|
33 |
+
|
34 |
+
|
35 |
+
class DropPath(nn.Module):
|
36 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
37 |
+
"""
|
38 |
+
def __init__(self, drop_prob=None):
|
39 |
+
super(DropPath, self).__init__()
|
40 |
+
self.drop_prob = drop_prob
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
return drop_path(x, self.drop_prob, self.training)
|
44 |
+
|
45 |
+
def extra_repr(self) -> str:
|
46 |
+
return 'p={}'.format(self.drop_prob)
|
47 |
+
|
48 |
+
|
49 |
+
class Mlp(nn.Module):
|
50 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
51 |
+
super().__init__()
|
52 |
+
out_features = out_features or in_features
|
53 |
+
hidden_features = hidden_features or in_features
|
54 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
55 |
+
self.act = act_layer()
|
56 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
57 |
+
self.drop = nn.Dropout(drop)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
x = self.fc1(x)
|
61 |
+
x = self.act(x)
|
62 |
+
# x = self.drop(x)
|
63 |
+
# commit this for the orignal BERT implement
|
64 |
+
x = self.fc2(x)
|
65 |
+
x = self.drop(x)
|
66 |
+
return x
|
67 |
+
|
68 |
+
|
69 |
+
class Attention(nn.Module):
|
70 |
+
def __init__(self,
|
71 |
+
dim,
|
72 |
+
num_heads=8,
|
73 |
+
qkv_bias=False,
|
74 |
+
qk_scale=None,
|
75 |
+
attn_drop=0.,
|
76 |
+
proj_drop=0.,
|
77 |
+
window_size=None,
|
78 |
+
attn_head_dim=None):
|
79 |
+
super().__init__()
|
80 |
+
self.num_heads = num_heads
|
81 |
+
head_dim = dim // num_heads
|
82 |
+
if attn_head_dim is not None:
|
83 |
+
head_dim = attn_head_dim
|
84 |
+
all_head_dim = head_dim * self.num_heads
|
85 |
+
self.scale = qk_scale or head_dim**-0.5
|
86 |
+
|
87 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
88 |
+
if qkv_bias:
|
89 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
90 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
91 |
+
else:
|
92 |
+
self.q_bias = None
|
93 |
+
self.v_bias = None
|

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance,
                                                                         num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 init_values=None,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 window_size=None,
                 attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim,
                              num_heads=num_heads,
                              qkv_bias=qkv_bias,
                              qk_scale=qk_scale,
                              attn_drop=attn_drop,
                              proj_drop=drop,
                              window_size=window_size,
                              attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class RelativePositionBias(nn.Module):
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance,
                                                                     num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

        # trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self):
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww


class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 init_values=None,
                 use_abs_pos_emb=True,
                 use_rel_pos_bias=False,
                 use_shared_rel_pos_bias=False,
                 use_mean_pooling=True,
                 init_scale=0.001,
                 use_checkpoint=False):
        super().__init__()
        self.image_size = img_size
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None
        self.use_checkpoint = use_checkpoint

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  drop_path=dpr[i],
                  norm_layer=norm_layer,
                  init_values=init_values,
                  window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)
        ])
        # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        # trunc_normal_(self.mask_token, std=.02)
        # if isinstance(self.head, nn.Linear):
        #     trunc_normal_(self.head.weight, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)
        return x

    def get_intermediate_layers(self, x):
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        features = []
        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias)
            features.append(x)

        return features

    def get_num_layer(self, var_name=""):
        if var_name in ("cls_token", "mask_token", "pos_embed"):
            return 0
        elif var_name.startswith("patch_embed"):
            return 0
        elif var_name.startswith("rel_pos_bias"):
            return len(self.blocks) - 1
        elif var_name.startswith("blocks"):
            layer_id = int(var_name.split('.')[1])
            return layer_id + 1
        else:
            return len(self.blocks)


def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(pos_tokens,
                                                         size=(new_size, new_size),
                                                         mode='bicubic',
                                                         align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed


def convert_weights_to_fp16(model: nn.Module):
    """Convert applicable model parameters to fp16"""
    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

    model.apply(_convert_weights_to_fp16)


def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
    model = VisionTransformer(
        img_size=img_size,
        patch_size=14,
        use_mean_pooling=False,
        embed_dim=1408,
        depth=39,
        num_heads=1408 // 88,
        mlp_ratio=4.3637,
        qkv_bias=True,
        drop_path_rate=drop_path_rate,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_checkpoint=use_checkpoint,
    )
    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
    cached_file = download_cached_file(url, check_hash=False, progress=True)
    state_dict = torch.load(cached_file, map_location="cpu")
    interpolate_pos_embed(model, state_dict)

    incompatible_keys = model.load_state_dict(state_dict, strict=False)
    # print(incompatible_keys)

    if precision == "fp16":
        # model.to("cuda")
        convert_weights_to_fp16(model)
    return model
models/seed_qformer/qformer_causual.py
ADDED
@@ -0,0 +1,1169 @@
"""
 * Copyright (c) 2023, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
 * Based on huggingface code base
 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
"""

import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any

import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import numpy as np

from transformers.activations import ACT2FN
from transformers.file_utils import (
    ModelOutput, )
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig

# torch.set_printoptions(profile="full")
logger = logging.get_logger(__name__)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
                # print(query_embeds.shape, embeddings.shape)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
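The embedding module above prepends the learned query embeddings to any text embeddings, so every later attention layer sees one joint sequence. A standalone shape sketch, assuming the 32 query tokens this Q-Former variant uses elsewhere (the hidden size and text length below are illustrative only):

    import torch
    hidden_size, num_query_tokens, text_len = 768, 32, 16       # assumed sizes for illustration
    query_embeds = torch.zeros(1, num_query_tokens, hidden_size)
    text_embeds = torch.zeros(1, text_len, hidden_size)
    joint = torch.cat((query_embeds, text_embeds), dim=1)        # mirrors BertEmbeddings.forward
    assert joint.shape == (1, num_query_tokens + text_len, hidden_size)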


class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError("The hidden size (%d) is not a multiple of the number of attention "
                             "heads (%d)" % (config.hidden_size, config.num_attention_heads))

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            # print(key_layer.shape, value_layer.shape)
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
            # print(past_key_value[0].shape, key_layer.shape)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        # if past_key_value is not None:
        #     print(query_layer.shape)

        past_key_value = (key_layer, value_layer)
        # print(key_layer.shape, value_layer.shape)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # if is_cross_attention:
        #     if attention_scores.shape[2] == 32:
        #         attention_scores_save = attention_scores[0].detach().cpu().numpy()
        #         print(attention_scores_save.shape)
        #         np.save('attention_scores_causal_text_child.npy', attention_scores_save)

        if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"):
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = (attention_scores + relative_position_scores_query + relative_position_scores_key)

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = ((context_layer, attention_probs) if output_attentions else (context_layer, ))

        outputs = outputs + (past_key_value, )
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = (self.self.attention_head_size * self.self.num_attention_heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output, ) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if (self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0):
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (past_key_value[:2] if past_key_value is not None else None)
        # if past_key_value is not None:
        #     print(hidden_states.shape, attention_mask.shape)
        # print(hidden_states.shape, attention_mask.shape)
        # casual attention for query embeds with self attention
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        # print('attention_mask', attention_mask.shape)
        # if attention_mask.shape[-1] == 77:
        #     print('attention_mask', attention_mask[0])
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]
        # print(present_key_value[0].shape)

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (encoder_hidden_states is not None), "encoder_hidden_states must be given for cross-attention layers"
                # print(attention_mask.shape)
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (outputs + cross_attention_outputs[1:-1])  # add cross attentions if we output attention weights

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output, ) + outputs

        outputs = outputs + (present_key_value, )

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (() if output_attentions and self.config.add_cross_attention else None)

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None
            # if past_key_value is not None:
            #     print(past_key_value[0].shape, past_key_value[1].shape)

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    logger.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )
                # if past_key_value is not None:
                #     print(hidden_states.shape, attention_mask.shape)
                #     print(len(past_key_value))

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1], )
                # print(layer_outputs[-1][0].shape)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1], )
                all_cross_attentions = all_cross_attentions + (layer_outputs[2], )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if not return_dict:
            return tuple(v for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attentions,
                all_cross_attentions,
            ] if v is not None)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """
    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        is_casual: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        # print(attention_mask.dim())
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder or is_casual:
                batch_size, seq_length = input_shape
                # print(input_shape)
                if not is_decoder and seq_length > 32:
                    query_length = 32
                    text_length = seq_length - query_length
                    query_ids = torch.arange(query_length, device=device)
                    query_causal_mask = (query_ids[None, None, :].repeat(batch_size, query_length, 1) <=
                                         query_ids[None, :, None])
                    causal_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
                    causal_mask[:, :query_length, :query_length] = query_causal_mask
                    # print(query_causal_mask.shape, causal_mask.shape)
                    # print(causal_mask[0])

                else:
                    seq_ids = torch.arange(seq_length, device=device)
                    causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None])

                # add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)
                # if is_decoder:
                #     print(causal_mask.shape, attention_mask.shape)
                # print(causal_mask.shape, attention_mask.shape)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            axis=1,
                        )
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1], prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                # print(has_query, causal_mask.shape)
                # print(causal_mask[0])
                extended_attention_mask = (causal_mask[:, None, :, :] * attention_mask[:, None, None, :])
                # print(extended_attention_mask[0])
                # print('extended_attention_mask', extended_attention_mask.shape)
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
                # print(attention_mask.shape, extended_attention_mask.shape)
        else:
            raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                input_shape, attention_mask.shape))

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
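The mask built above is the distinctive piece of this causal Q-Former: when `is_casual` is set and the model is not acting as a decoder, the query positions attend to each other causally while query-to-text and text-to-text attention remain unmasked. A standalone sketch of that block structure, using small illustrative sizes rather than the 32 query tokens the model actually checks for:

    import torch
    batch_size, query_length, text_length = 1, 4, 3              # illustrative sizes only
    seq_length = query_length + text_length
    query_ids = torch.arange(query_length)
    query_causal = (query_ids[None, None, :].repeat(batch_size, query_length, 1) <= query_ids[None, :, None])
    mask = torch.ones(batch_size, seq_length, seq_length)
    mask[:, :query_length, :query_length] = query_causal         # lower-triangular block for the queries only
    # mask[0] now has a causal 4x4 block in the top-left corner and ones everywhere else.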

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
        output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)

        # use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is None:
            assert (query_embeds is not None), "You have to specify query_embeds when input_ids is None"

        # if query_embeds is not None:
        if query_embeds is not None and query_embeds.shape[1] == 32:
            is_casual = True
        else:
            is_casual = False
        past_key_values_length = (past_key_values[0][0].shape[2] -
                                  self.config.query_length if past_key_values is not None else 0)

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        # print('attention_mask', attention_mask)
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
        # print(seq_length, past_key_values_length)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            # print(attention_mask.shape, input_ids.shape)
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                is_casual,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_shape,
                device,
                is_decoder,
                is_casual,
            )
        # print(is_decoder, extended_attention_mask.shape)
        # if is_decoder:
        #     print(extended_attention_mask[0,0,:,32:])
        # if attention_mask is not None:
        #     print(input_ids, embedding_output.shape, extended_attention_mask.shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            # print(is_casual, extended_attention_mask.shape, encoder_attention_mask.shape, encoder_extended_attention_mask.shape)
        else:
            encoder_extended_attention_mask = None

        # if input_ids is not None and query_embeds is not None:
        #     print(extended_attention_mask.shape, encoder_extended_attention_mask.shape)
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # print(head_mask)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        # if is_decoder:
        #     print(embedding_output.shape, attention_mask.shape, len(past_key_values))
        # print(embedding_output.shape, extended_attention_mask.shape, encoder_hidden_states.shape, encoder_extended_attention_mask.shape)
        # print(extended_attention_mask[0], encoder_extended_attention_mask[0])

        # print(query_embeds.shape, encoder_hidden_states.shape)

        sequence_output = encoder_outputs[0]
        pooled_output = (self.pooler(sequence_output) if self.pooler is not None else None)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class BertLMHeadModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
|
937 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
938 |
+
|
939 |
+
def __init__(self, config):
|
940 |
+
super().__init__(config)
|
941 |
+
|
942 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
943 |
+
self.cls = BertOnlyMLMHead(config)
|
944 |
+
|
945 |
+
self.init_weights()
|
946 |
+
|
947 |
+
def get_output_embeddings(self):
|
948 |
+
return self.cls.predictions.decoder
|
949 |
+
|
950 |
+
def set_output_embeddings(self, new_embeddings):
|
951 |
+
self.cls.predictions.decoder = new_embeddings
|
952 |
+
|
953 |
+
def forward(
|
954 |
+
self,
|
955 |
+
input_ids=None,
|
956 |
+
attention_mask=None,
|
957 |
+
position_ids=None,
|
958 |
+
head_mask=None,
|
959 |
+
query_embeds=None,
|
960 |
+
encoder_hidden_states=None,
|
961 |
+
encoder_attention_mask=None,
|
962 |
+
labels=None,
|
963 |
+
past_key_values=None,
|
964 |
+
use_cache=True,
|
965 |
+
output_attentions=None,
|
966 |
+
output_hidden_states=None,
|
967 |
+
return_dict=None,
|
968 |
+
return_logits=False,
|
969 |
+
is_decoder=True,
|
970 |
+
reduction="mean",
|
971 |
+
):
|
972 |
+
r"""
|
973 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
974 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
975 |
+
the model is configured as a decoder.
|
976 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
977 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
978 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
979 |
+
- 1 for tokens that are **not masked**,
|
980 |
+
- 0 for tokens that are **masked**.
|
981 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
982 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
983 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
984 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
985 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
986 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
987 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
988 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
989 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
990 |
+
use_cache (:obj:`bool`, `optional`):
|
991 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
992 |
+
decoding (see :obj:`past_key_values`).
|
993 |
+
Returns:
|
994 |
+
Example::
|
995 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
996 |
+
>>> import torch
|
997 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
998 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
999 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
1000 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
1001 |
+
>>> outputs = model(**inputs)
|
1002 |
+
>>> prediction_logits = outputs.logits
|
1003 |
+
"""
|
1004 |
+
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
|
1005 |
+
if labels is not None:
|
1006 |
+
use_cache = False
|
1007 |
+
if past_key_values is not None:
|
1008 |
+
query_embeds = None
|
1009 |
+
#print(len(past_key_values))
|
1010 |
+
#print('attention_mask', attention_mask)
|
1011 |
+
outputs = self.bert(
|
1012 |
+
input_ids,
|
1013 |
+
attention_mask=attention_mask,
|
1014 |
+
position_ids=position_ids,
|
1015 |
+
head_mask=head_mask,
|
1016 |
+
query_embeds=query_embeds,
|
1017 |
+
encoder_hidden_states=encoder_hidden_states,
|
1018 |
+
encoder_attention_mask=encoder_attention_mask,
|
1019 |
+
past_key_values=past_key_values,
|
1020 |
+
use_cache=use_cache,
|
1021 |
+
output_attentions=output_attentions,
|
1022 |
+
output_hidden_states=output_hidden_states,
|
1023 |
+
return_dict=return_dict,
|
1024 |
+
is_decoder=is_decoder,
|
1025 |
+
)
|
1026 |
+
|
1027 |
+
sequence_output = outputs[0]
|
1028 |
+
if query_embeds is not None:
|
1029 |
+
sequence_output = outputs[0][:, query_embeds.shape[1]:, :]
|
1030 |
+
|
1031 |
+
prediction_scores = self.cls(sequence_output)
|
1032 |
+
|
1033 |
+
if return_logits:
|
1034 |
+
return prediction_scores[:, :-1, :].contiguous()
|
1035 |
+
|
1036 |
+
lm_loss = None
|
1037 |
+
if labels is not None:
|
1038 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1039 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
1040 |
+
labels = labels[:, 1:].contiguous()
|
1041 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
1042 |
+
lm_loss = loss_fct(
|
1043 |
+
shifted_prediction_scores.view(-1, self.config.vocab_size),
|
1044 |
+
labels.view(-1),
|
1045 |
+
)
|
1046 |
+
if reduction == "none":
|
1047 |
+
lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
|
1048 |
+
|
1049 |
+
if not return_dict:
|
1050 |
+
output = (prediction_scores, ) + outputs[2:]
|
1051 |
+
return ((lm_loss, ) + output) if lm_loss is not None else output
|
1052 |
+
|
1053 |
+
return CausalLMOutputWithCrossAttentions(
|
1054 |
+
loss=lm_loss,
|
1055 |
+
logits=prediction_scores,
|
1056 |
+
past_key_values=outputs.past_key_values,
|
1057 |
+
hidden_states=outputs.hidden_states,
|
1058 |
+
attentions=outputs.attentions,
|
1059 |
+
cross_attentions=outputs.cross_attentions,
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
|
1063 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
1064 |
+
if attention_mask is None:
|
1065 |
+
attention_mask = input_ids.new_ones(input_ids.shape)
|
1066 |
+
query_mask = input_ids.new_ones(query_embeds.shape[:-1])
|
1067 |
+
attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
|
1068 |
+
|
1069 |
+
# cut decoder_input_ids if past is used
|
1070 |
+
if past is not None:
|
1071 |
+
input_ids = input_ids[:, -1:]
|
1072 |
+
|
1073 |
+
return {
|
1074 |
+
"input_ids": input_ids,
|
1075 |
+
"query_embeds": query_embeds,
|
1076 |
+
"attention_mask": attention_mask,
|
1077 |
+
"past_key_values": past,
|
1078 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
1079 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
1080 |
+
"is_decoder": True,
|
1081 |
+
}
|
1082 |
+
|
1083 |
+
def _reorder_cache(self, past, beam_idx):
|
1084 |
+
reordered_past = ()
|
1085 |
+
for layer_past in past:
|
1086 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past), )
|
1087 |
+
return reordered_past
|
1088 |
+
|
1089 |
+
|
1090 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
1091 |
+
|
1092 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1093 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
1094 |
+
|
1095 |
+
def __init__(self, config):
|
1096 |
+
super().__init__(config)
|
1097 |
+
|
1098 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1099 |
+
self.cls = BertOnlyMLMHead(config)
|
1100 |
+
|
1101 |
+
self.init_weights()
|
1102 |
+
|
1103 |
+
def get_output_embeddings(self):
|
1104 |
+
return self.cls.predictions.decoder
|
1105 |
+
|
1106 |
+
def set_output_embeddings(self, new_embeddings):
|
1107 |
+
self.cls.predictions.decoder = new_embeddings
|
1108 |
+
|
1109 |
+
def forward(
|
1110 |
+
self,
|
1111 |
+
input_ids=None,
|
1112 |
+
attention_mask=None,
|
1113 |
+
position_ids=None,
|
1114 |
+
head_mask=None,
|
1115 |
+
query_embeds=None,
|
1116 |
+
encoder_hidden_states=None,
|
1117 |
+
encoder_attention_mask=None,
|
1118 |
+
labels=None,
|
1119 |
+
output_attentions=None,
|
1120 |
+
output_hidden_states=None,
|
1121 |
+
return_dict=None,
|
1122 |
+
return_logits=False,
|
1123 |
+
is_decoder=False,
|
1124 |
+
):
|
1125 |
+
r"""
|
1126 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1127 |
+
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
1128 |
+
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
1129 |
+
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
1130 |
+
"""
|
1131 |
+
|
1132 |
+
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
|
1133 |
+
|
1134 |
+
outputs = self.bert(
|
1135 |
+
input_ids,
|
1136 |
+
attention_mask=attention_mask,
|
1137 |
+
position_ids=position_ids,
|
1138 |
+
head_mask=head_mask,
|
1139 |
+
query_embeds=query_embeds,
|
1140 |
+
encoder_hidden_states=encoder_hidden_states,
|
1141 |
+
encoder_attention_mask=encoder_attention_mask,
|
1142 |
+
output_attentions=output_attentions,
|
1143 |
+
output_hidden_states=output_hidden_states,
|
1144 |
+
return_dict=return_dict,
|
1145 |
+
is_decoder=is_decoder,
|
1146 |
+
)
|
1147 |
+
|
1148 |
+
if query_embeds is not None:
|
1149 |
+
sequence_output = outputs[0][:, query_embeds.shape[1]:, :]
|
1150 |
+
prediction_scores = self.cls(sequence_output)
|
1151 |
+
|
1152 |
+
if return_logits:
|
1153 |
+
return prediction_scores
|
1154 |
+
|
1155 |
+
masked_lm_loss = None
|
1156 |
+
if labels is not None:
|
1157 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
1158 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
1159 |
+
|
1160 |
+
if not return_dict:
|
1161 |
+
output = (prediction_scores, ) + outputs[2:]
|
1162 |
+
return (((masked_lm_loss, ) + output) if masked_lm_loss is not None else output)
|
1163 |
+
|
1164 |
+
return MaskedLMOutput(
|
1165 |
+
loss=masked_lm_loss,
|
1166 |
+
logits=prediction_scores,
|
1167 |
+
hidden_states=outputs.hidden_states,
|
1168 |
+
attentions=outputs.attentions,
|
1169 |
+
)
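The `labels is not None` branch above is the whole of the causal objective: scores are shifted left by one position against the targets and scored with label smoothing. A minimal standalone sketch of that shift (the batch size, sequence length, and vocabulary size below are made up for illustration, not taken from this repo):

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(2, 6, 8192)            # (batch, seq_len, vocab) prediction scores
labels = torch.randint(0, 8192, (2, 6))     # target ids laid out like the input ids

shifted = logits[:, :-1, :].contiguous()    # position t predicts token t + 1
targets = labels[:, 1:].contiguous()
loss = CrossEntropyLoss(label_smoothing=0.1)(shifted.view(-1, 8192), targets.view(-1))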
|
models/seed_qformer/qformer_quantizer.py
ADDED
@@ -0,0 +1,375 @@
"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.cuda.amp import autocast as autocast
from torch.nn import functional as F
import numpy as np
from functools import partial
from einops import rearrange

from .blip2 import Blip2Base, disabled_train
from .vit import Block
from .utils import download_cached_file, is_url


class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    # def l2norm(self, t):
    #     return F.normalize(t, p = 2, dim = -1)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits is False, "Only for interface compatible with Gumbel"
        assert return_logits is False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        # z = rearrange(z, 'b c h w -> b h w c').contiguous()
        bz = z.shape[0]
        z_flattened = z.view(-1, self.e_dim)
        # print('z_flattened', z_flattened.shape)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z)**2) + torch.mean((z_q - z.detach())**2)
        else:
            loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        # z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
        z_q = z_q.reshape(bz, -1, z_q.shape[-1])
        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, min_encoding_indices

    def get_codebook_entry(self, indices, shape=None):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class Blip2QformerQuantizer(Blip2Base):
    """
    BLIP2 first-stage model with Q-former and ViT.
    Supported model types:
        - pretrained: pretrained model with vit-g
        - pretrain_vitL: pretrained model with vit-large
        - coco: finetuned model on coco
    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip2", "pretrain")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain": "configs/models/blip2/blip2_pretrain.yaml",
        "pretrain_vitL": "configs/models/blip2/blip2_pretrain_vitL.yaml",
        "coco": "configs/models/blip2/blip2_coco.yaml",
    }

    def __init__(self,
                 vit_model="eva_clip_g",
                 img_size=224,
                 drop_path_rate=0,
                 use_grad_checkpoint=False,
                 vit_precision="fp16",
                 freeze_vit=True,
                 num_query_token=32,
                 cross_attention_freq=2,
                 embed_dim=256,
                 max_txt_len=32,
                 codebook_embed_dim=32,
                 n_embed=8192,
                 recon_s=True,
                 blocks_for_image=True,
                 decode_depth=4,
                 use_recon_s_for_image=False,
                 use_qformer_image=False,
                 image_features_dim=1024):
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder, self.ln_vision = self.init_vision_encoder(vit_model, img_size, drop_path_rate, use_grad_checkpoint,
                                                                       vit_precision)
        if freeze_vit:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            logging.info("freeze vision encoder")
            self.ln_vision.weight.requires_grad = False
            self.ln_vision.bias.requires_grad = False

        self.codebook_embed_dim = codebook_embed_dim
        self.n_embed = n_embed
        self.recon_s = recon_s
        self.blocks_for_image = blocks_for_image
        self.use_recon_s_for_image = use_recon_s_for_image
        self.depth = decode_depth
        self.image_features_dim = image_features_dim
        self.use_qformer_image = use_qformer_image

        self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, self.visual_encoder.num_features)

        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        for name, param in self.Qformer.named_parameters():
            param.requires_grad = False
        self.query_tokens.requires_grad = False

        self.quantize = VectorQuantizer2(n_embed, codebook_embed_dim, beta=0.25, remap=None, sane_index_shape=False)

        self.encode_task_layer = nn.Sequential(
            nn.Linear(self.Qformer.config.hidden_size, self.Qformer.config.hidden_size),
            nn.Tanh(),
            nn.Linear(self.Qformer.config.hidden_size, codebook_embed_dim)  # for quantize
        )

        self.decode_task_layer = nn.Sequential(
            nn.Linear(codebook_embed_dim, codebook_embed_dim),
            nn.Tanh(),
            nn.Linear(codebook_embed_dim, self.Qformer.config.hidden_size)  # for quantize
        )

        self.quantize = self.quantize.eval()
        self.quantize.training = False
        for name, param in self.named_parameters():
            if 'quantize' in name or 'encode_task_layer' in name or 'decode_task_layer' in name:
                # print('freeze params', name)
                param.requires_grad = False

        if self.recon_s:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_query_token, self.Qformer.config.hidden_size))
            self.blocks = nn.ModuleList([
                Block(dim=self.Qformer.config.hidden_size,
                      num_heads=12,
                      mlp_ratio=4.0,
                      qkv_bias=True,
                      qk_scale=None,
                      drop=0.0,
                      attn_drop=0.0,
                      drop_path=0.0,
                      norm_layer=partial(nn.LayerNorm, eps=1e-6)) for i in range(self.depth)
            ])

        if self.blocks_for_image:
            self.pos_embed_image = nn.Parameter(torch.zeros(1, num_query_token, self.Qformer.config.hidden_size))
            self.blocks_image = nn.ModuleList([
                Block(dim=self.Qformer.config.hidden_size,
                      num_heads=12,
                      mlp_ratio=4.0,
                      qkv_bias=True,
                      qk_scale=None,
                      drop=0.0,
                      attn_drop=0.0,
                      drop_path=0.0,
                      norm_layer=partial(nn.LayerNorm, eps=1e-6)) for i in range(self.depth)
            ])

        if self.use_qformer_image:
            num_reverse_token = 1
            self.Reverse_Qformer, self.reverse_tokens = self.init_Qformer(num_reverse_token, self.Qformer.config.hidden_size)

            self.Reverse_Qformer.cls = None
            self.Reverse_Qformer.bert.embeddings.word_embeddings = None
            self.Reverse_Qformer.bert.embeddings.position_embeddings = None
            for layer in self.Reverse_Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.distill_image_proj = nn.Linear(self.Qformer.config.hidden_size, image_features_dim)

        else:
            self.image_down = nn.Sequential(
                nn.Linear(self.Qformer.config.hidden_size, 256, bias=False),
                nn.ReLU(),
                nn.Linear(256, 128, bias=False),
                nn.ReLU(),
                nn.Linear(128, 32, bias=False),
            )
            self.distill_image_proj = nn.Linear(num_query_token * 32, image_features_dim)

    def get_codebook_indices(self, image):
        with torch.no_grad():
            with self.maybe_autocast():
                image_embeds = self.ln_vision(self.visual_encoder(image))
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
                query_output = self.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )

                query_output_down = self.encode_task_layer(query_output.last_hidden_state)
                quant, loss_embed, embed_ind = self.quantize(query_output_down)
                embed_ind = embed_ind.reshape(quant.shape[0], -1)

                query_output_up = self.decode_task_layer(quant)

        return embed_ind, query_output_up

    def get_codebook_entry(self, indices):
        quant_embedding = self.quantize.get_codebook_entry(indices)
        # print('quant_embedding_shape: ', quant_embedding.shape)
        # print(self.decode_task_layer)
        # exit()
        query_output_up = self.decode_task_layer(quant_embedding)

        pos_embed_image = self.pos_embed_image.repeat(query_output_up.shape[0], 1, 1)
        query_output_up_pos_image = query_output_up + pos_embed_image
        for blk in self.blocks_image:
            query_output_up_pos_image = blk(query_output_up_pos_image)
        query_output_up = query_output_up_pos_image

        if self.use_qformer_image:
            query_atts = torch.ones(query_output_up.size()[:-1], dtype=torch.long).to(query_output_up.device)
            reverse_tokens = self.reverse_tokens.expand(query_output_up.shape[0], -1, -1)
            reverse_output = self.Reverse_Qformer.bert(
                query_embeds=reverse_tokens,
                encoder_hidden_states=query_output_up,
                encoder_attention_mask=query_atts,
                return_dict=True,
            )
            reverse_output = reverse_output.last_hidden_state
            reverse_output_proj = self.distill_image_proj(reverse_output).squeeze(1)
        else:
            reverse_output = self.image_down(query_output_up)
            reverse_output = reverse_output.reshape(reverse_output.shape[0], -1)
            reverse_output_proj = self.distill_image_proj(reverse_output)

        return reverse_output_proj

    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kwargs):
        vit_model = kwargs.get("vit_model", "eva_clip_g")
        img_size = kwargs.get("image_size", 224)
        num_query_token = kwargs.get("num_query_token", 32)
        cross_attention_freq = kwargs.get("cross_attention_freq", 2)

        drop_path_rate = kwargs.get("drop_path_rate", 0)
        use_grad_checkpoint = kwargs.get("use_grad_checkpoint", False)
        vit_precision = kwargs.get("vit_precision", "fp16")
        freeze_vit = kwargs.get("freeze_vit", True)

        max_txt_len = kwargs.get("max_txt_len", 32)

        model = cls(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            num_query_token=num_query_token,
            cross_attention_freq=cross_attention_freq,
            max_txt_len=max_txt_len,
        )

        if pretrained_model_path.startswith('http'):
            print('start download seed model...')
            cached_file = download_cached_file(pretrained_model_path, check_hash=False, progress=True)
            print(cached_file)
            ckpt = torch.load(cached_file, map_location="cpu")
        else:
            ckpt = torch.load(pretrained_model_path, map_location="cpu")
        missing, unexcepted = model.load_state_dict(ckpt, strict=False)
        print('missing keys: ', len(missing), 'unexpected keys:', len(unexcepted))
        return model
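The quantizer picks, for every query embedding, the nearest codebook row under squared Euclidean distance, expanded as ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e. A self-contained sketch of just that lookup (the codebook here is random and the shapes are illustrative; the real model uses the trained `VectorQuantizer2` embedding):

import torch

n_embed, e_dim = 8192, 32
codebook = torch.randn(n_embed, e_dim)            # stand-in for quantize.embedding.weight
z = torch.randn(4, 32, e_dim)                     # (batch, num_query_token, codebook_embed_dim)
z_flat = z.view(-1, e_dim)

d = (z_flat ** 2).sum(1, keepdim=True) + (codebook ** 2).sum(1) - 2 * z_flat @ codebook.t()
indices = d.argmin(dim=1).view(4, 32)             # discrete codes, as returned by get_codebook_indices
z_q = codebook[indices]                           # quantized embeddings fed to decode_task_layer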
models/seed_qformer/utils.py
ADDED
@@ -0,0 +1,138 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub
from urllib.parse import urlparse


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}, world {}): {}".format(args.rank, args.world_size, args.dist_url),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(days=365),  # allow auto-downloading and de-compressing
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def get_dist_info():
    if torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size


def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """
    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")
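A rough usage sketch of the two download helpers, under the assumption that a timm version exposing `timm.models.hub.download_cached_file` and `get_cache_dir` is installed and the URL is reachable (the URL below is hypothetical, not a real checkpoint):

url = "https://example.com/seed_quantizer.pt"   # hypothetical checkpoint URL
if is_url(url):
    local_path = download_cached_file(url, check_hash=False, progress=True)
    print(local_path)   # file lands in timm's hub cache directory; other ranks wait at the barrier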
models/seed_qformer/vit.py
ADDED
@@ -0,0 +1,395 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

Based on timm code base
https://github.com/rwightman/pytorch-image-models/tree/master/timm
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from timm.models.helpers import named_apply, adapt_input_conv


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, register_hook=False):
        B, N, C = x.shape
        qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4))
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        use_grad_checkpointing=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # if use_grad_checkpointing:
        #     self.attn = checkpoint_wrapper(self.attn)
        #     self.mlp = checkpoint_wrapper(self.mlp)

    def forward(self, x, register_hook=False):
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
    https://arxiv.org/abs/2010.11929
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        representation_size=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=None,
        use_grad_checkpointing=False,
        ckpt_layer=0,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()
        self.num_features = (self.embed_dim) = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                use_grad_checkpointing=(use_grad_checkpointing and i >= depth - ckpt_layer),
            ) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def forward(self, x, register_blk=-1):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed[:, :x.size(1), :]
        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x, register_blk == i)
        x = self.norm(x)

        return x

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        _load_weights(self, checkpoint_path, prefix)


@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""):
    """Load weights from .npz checkpoints for official Google Brain Flax implementation"""
    import numpy as np

    def _n2p(w, t=True):
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    if not prefix and "opt/target/embedding/kernel" in w:
        prefix = "opt/target/"

    if hasattr(model.patch_embed, "backbone"):
        # hybrid
        backbone = model.patch_embed.backbone
        stem_only = not hasattr(backbone, "stem")
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"])))
        stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"]))
        stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"]))
        if not stem_only:
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    bp = f"{prefix}block{i + 1}/unit{j + 1}/"
                    for r in range(3):
                        getattr(block, f"conv{r + 1}").weight.copy_(_n2p(w[f"{bp}conv{r + 1}/kernel"]))
                        getattr(block, f"norm{r + 1}").weight.copy_(_n2p(w[f"{bp}gn{r + 1}/scale"]))
                        getattr(block, f"norm{r + 1}").bias.copy_(_n2p(w[f"{bp}gn{r + 1}/bias"]))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f"{bp}conv_proj/kernel"]))
                        block.downsample.norm.weight.copy_(_n2p(w[f"{bp}gn_proj/scale"]))
                        block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"]))
        embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"])
    else:
        embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"]))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"]))
    model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False))
    pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w,
            model.pos_embed,
            getattr(model, "num_tokens", 1),
            model.patch_embed.grid_size,
        )
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"]))
    model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))
    # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
    #     model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
    #     model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
    #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
    #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f"{prefix}Transformer/encoderblock_{i}/"
        mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/"
        block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"]))
        block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"]))
        block.attn.qkv.weight.copy_(
            torch.cat([_n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T for n in ("query", "key", "value")]))
        block.attn.qkv.bias.copy_(
            torch.cat([_n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1) for n in ("query", "key", "value")]))
        block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"]))
        for r in range(2):
            getattr(block.mlp, f"fc{r + 1}").weight.copy_(_n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"]))
            getattr(block.mlp, f"fc{r + 1}").bias.copy_(_n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"]))
        block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"]))
        block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"]))


def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    print("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape)
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
    gs_old = int(math.sqrt(len(posemb_grid)))
    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(ntok_new))] * 2
    assert len(gs_new) >= 2
    print("Position embedding grid-size from %s to %s", [gs_old, gs_old], gs_new)
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode="bicubic", align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb


def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
    # interpolate position embedding
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = visual_encoder.patch_embed.num_patches
    num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches**0.5)

    if orig_size != new_size:
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        print("reshape position embedding from %d to %d" % (orig_size**2, new_size**2))

        return new_pos_embed
    else:
        return pos_embed_checkpoint
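A quick sketch of running this ViT forward pass on a dummy batch; the depth is deliberately small here to keep the sketch light, whereas the real checkpoints use the defaults above (this only requires torch and timm):

import torch

vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=2, num_heads=12)
images = torch.randn(1, 3, 224, 224)
tokens = vit(images)
print(tokens.shape)   # (1, 197, 768): one [CLS] token plus 14x14 patch tokens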
models/transforms.py
ADDED
@@ -0,0 +1,21 @@
from torchvision import transforms


def get_transform(type='clip', keep_ratio=True, image_size=224):
    if type == 'clip':
        transform = []
        if keep_ratio:
            transform.extend([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
            ])
        else:
            transform.append(transforms.Resize((image_size, image_size)))
        transform.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        ])

        return transforms.Compose(transform)
    else:
        raise NotImplementedError