yuyingge committed
Commit
590af54
1 Parent(s): 1264376

Add application file

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. License.txt +335 -0
  2. configs/.DS_Store +0 -0
  3. configs/clm_models/.DS_Store +0 -0
  4. configs/clm_models/agent_seed_x_i.yaml +23 -0
  5. configs/clm_models/llm_seed_x_i.yaml +3 -0
  6. configs/discrete_model/.DS_Store +0 -0
  7. configs/discrete_model/discrete_identity.yaml +1 -0
  8. configs/processer/.DS_Store +0 -0
  9. configs/processer/qwen_448_transform.yaml +4 -0
  10. configs/sdxl_adapter/.DS_Store +0 -0
  11. configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_full_with_latent_image_pretrain_no_normalize.yaml +20 -0
  12. configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml +18 -0
  13. configs/tokenizer/.DS_Store +0 -0
  14. configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml +2 -0
  15. configs/visual_encoder/.DS_Store +0 -0
  16. configs/visual_encoder/qwen_vitg_448.yaml +11 -0
  17. pretrained/QwenViT/qwen_vit_G.pt +3 -0
  18. requirements.txt +11 -0
  19. seed_x/arrow.jpg +0 -0
  20. seed_x/bank.png +0 -0
  21. src/.DS_Store +0 -0
  22. src/demo/__pycache__/conversation.cpython-311.pyc +0 -0
  23. src/demo/__pycache__/conversation.cpython-38.pyc +0 -0
  24. src/demo/__pycache__/utils.cpython-311.pyc +0 -0
  25. src/demo/__pycache__/utils.cpython-38.pyc +0 -0
  26. src/demo/configs/agent_13b_anyres_out_64_pretrain_merged.yaml +29 -0
  27. src/demo/configs/agent_13b_in100_out64_rs5_merged_pretrain.yaml +22 -0
  28. src/demo/configs/llama2chat13b_merged_100imgtokens.yaml +12 -0
  29. src/demo/conversation.py +182 -0
  30. src/demo/seed_llama_flask.py +379 -0
  31. src/demo/seed_llama_gradio.py +465 -0
  32. src/demo/utils.py +83 -0
  33. src/inference/.DS_Store +0 -0
  34. src/inference/__pycache__/any_res.cpython-311.pyc +0 -0
  35. src/inference/__pycache__/any_res.cpython-38.pyc +0 -0
  36. src/inference/any_res.py +257 -0
  37. src/inference/eval_img2edit_seed_x.py +155 -0
  38. src/inference/eval_img2text_seed_x.py +235 -0
  39. src/inference/eval_text2img_seed_x.py +94 -0
  40. src/models/detokenizer/__init__.py +1 -0
  41. src/models/detokenizer/__pycache__/__init__.cpython-311.pyc +0 -0
  42. src/models/detokenizer/__pycache__/__init__.cpython-38.pyc +0 -0
  43. src/models/detokenizer/__pycache__/adapter_modules.cpython-311.pyc +0 -0
  44. src/models/detokenizer/__pycache__/adapter_modules.cpython-38.pyc +0 -0
  45. src/models/detokenizer/__pycache__/attention_processor.cpython-38.pyc +0 -0
  46. src/models/detokenizer/__pycache__/ipa_utils.cpython-38.pyc +0 -0
  47. src/models/detokenizer/__pycache__/pipeline_stable_diffusion_t2i_edit.cpython-38.pyc +0 -0
  48. src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-311.pyc +0 -0
  49. src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-38.pyc +0 -0
  50. src/models/detokenizer/__pycache__/resampler.cpython-311.pyc +0 -0
License.txt ADDED
@@ -0,0 +1,335 @@
1
+ Tencent is pleased to support the open source community by making Seed-X available.
2
+
3
+ Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
4
+
5
+ Seed-X is licensed under the Apache License Version 2.0 except for the third-party components listed below.
6
+
7
+
8
+ Terms of the Apache License Version 2.0:
9
+ --------------------------------------------------------------------
10
+ Apache License
11
+
12
+ Version 2.0, January 2004
13
+
14
+ http://www.apache.org/licenses/
15
+
16
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
17
+ 1. Definitions.
18
+
19
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
20
+
21
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
22
+
23
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
28
+
29
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
30
+
31
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
32
+
33
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
34
+
35
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
36
+
37
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
38
+
39
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
40
+
41
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
42
+
43
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
44
+
45
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
46
+
47
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
48
+
49
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
50
+
51
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
52
+
53
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
54
+
55
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
56
+
57
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
58
+
59
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
60
+
61
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
62
+
63
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
64
+
65
+ END OF TERMS AND CONDITIONS
66
+
67
+
68
+
69
+ Other dependencies and licenses:
70
+
71
+
72
+ Open Source Software Licensed under the Apache License Version 2.0:
73
+ --------------------------------------------------------------------
74
+ 1. transformers
75
+ Copyright 2018- The Hugging Face team. All rights reserved.
76
+ Source code of this software can be obtained from: https://github.com/huggingface/transformers/blob/v4.30.2/
77
+
78
+ 2. diffusers
79
+ Copyright 2023 The HuggingFace Team. All rights reserved.
80
+ Source code of this software can be obtained from: https://github.com/huggingface/diffusers/blob/v0.25.0/
81
+
82
+ A copy of Apache 2.0 has been included in this file.
83
+
84
+
85
+
86
+ Open Source Software Licensed under the BSD 3-Clause License:
87
+ --------------------------------------------------------------------
88
+ 1. torchvision
89
+ Copyright (c) Soumith Chintala 2016,
90
+ All rights reserved.
91
+
92
+ Terms of the BSD 3-Clause License:
93
+ --------------------------------------------------------------------
94
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
95
+
96
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
97
+
98
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
99
+
100
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
101
+
102
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
103
+
104
+
105
+
106
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
107
+ --------------------------------------------------------------------
108
+ 1. numpy
109
+ Copyright (c) 2005-2021, NumPy Developers.
110
+ All rights reserved.
111
+
112
+ A copy of the BSD 3-Clause License is included in this file.
113
+
114
+ For the license of other third party components, please refer to the following URL:
115
+ https://github.com/numpy/numpy/blob/v1.20.1/LICENSES_bundled.txt
116
+
117
+
118
+
119
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
120
+ --------------------------------------------------------------------
121
+ 1. torch
122
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
123
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
124
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
125
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
126
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
127
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
128
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
129
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
130
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
131
+
132
+ A copy of the BSD 3-Clause License is included in this file.
133
+
134
+ For the license of other third party components, please refer to the following URL:
135
+ https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE
136
+
137
+
138
+
139
+ Open Source Software Licensed under the LLAMA 2 Community License:
140
+ --------------------------------------------------------------------
141
+ 1. Llama 2
142
+ Copyright (c) Meta Platforms, Inc. All Rights Reserved.
143
+
144
+
145
+ Terms of the LLAMA 2 COMMUNITY LICENSE AGREEMENT:
146
+ --------------------------------------------------------------------
147
+ LLAMA 2 COMMUNITY LICENSE AGREEMENT
148
+ Llama 2 Version Release Date: July 18, 2023
149
+
150
+ "Agreement" means the terms and conditions for use, reproduction, distribution and
151
+ modification of the Llama Materials set forth herein.
152
+
153
+ "Documentation" means the specifications, manuals and documentation
154
+ accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
155
+ libraries/llama-downloads/.
156
+
157
+ "Licensee" or "you" means you, or your employer or any other person or entity (if
158
+ you are entering into this Agreement on such person or entity's behalf), of the age
159
+ required under applicable laws, rules or regulations to provide legal consent and that
160
+ has legal authority to bind your employer or such other person or entity if you are
161
+ entering in this Agreement on their behalf.
162
+
163
+ "Llama 2" means the foundational large language models and software and
164
+ algorithms, including machine-learning model code, trained model weights,
165
+ inference-enabling code, training-enabling code, fine-tuning enabling code and other
166
+ elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
167
+ libraries/llama-downloads/.
168
+
169
+ "Llama Materials" means, collectively, Meta's proprietary Llama 2 and
170
+ Documentation (and any portion thereof) made available under this Agreement.
171
+
172
+ "Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
173
+ are an entity, your principal place of business is in the EEA or Switzerland) and Meta
174
+ Platforms, Inc. (if you are located outside of the EEA or Switzerland).
175
+
176
+ By clicking "I Accept" below or by using or distributing any portion or element of the
177
+ Llama Materials, you agree to be bound by this Agreement.
178
+
179
+ 1. License Rights and Redistribution.
180
+
181
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
182
+ transferable and royalty-free limited license under Meta's intellectual property or
183
+ other rights owned by Meta embodied in the Llama Materials to use, reproduce,
184
+ distribute, copy, create derivative works of, and make modifications to the Llama
185
+ Materials.
186
+
187
+ b. Redistribution and Use.
188
+
189
+ i. If you distribute or make the Llama Materials, or any derivative works
190
+ thereof, available to a third party, you shall provide a copy of this Agreement to such
191
+ third party.
192
+ ii. If you receive Llama Materials, or any derivative works thereof, from
193
+ a Licensee as part of an integrated end user product, then Section 2 of this
194
+ Agreement will not apply to you.
195
+
196
+ iii. You must retain in all copies of the Llama Materials that you
197
+ distribute the following attribution notice within a "Notice" text file distributed as a
198
+ part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
199
+ Copyright (c) Meta Platforms, Inc. All Rights Reserved."
200
+
201
+ iv. Your use of the Llama Materials must comply with applicable laws
202
+ and regulations (including trade compliance laws and regulations) and adhere to the
203
+ Acceptable Use Policy for the Llama Materials (available at
204
+ https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
205
+ this Agreement.
206
+
207
+ v. You will not use the Llama Materials or any output or results of the
208
+ Llama Materials to improve any other large language model (excluding Llama 2 or
209
+ derivative works thereof).
210
+
211
+ 2. Additional Commercial Terms. If, on the Llama 2 version release date, the
212
+ monthly active users of the products or services made available by or for Licensee,
213
+ or Licensee's affiliates, is greater than 700 million monthly active users in the
214
+ preceding calendar month, you must request a license from Meta, which Meta may
215
+ grant to you in its sole discretion, and you are not authorized to exercise any of the
216
+ rights under this Agreement unless or until Meta otherwise expressly grants you
217
+ such rights.
218
+
219
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
220
+ LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
221
+ PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
222
+ EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
223
+ WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
224
+ FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
225
+ FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
226
+ THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
227
+ USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
228
+
229
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
230
+ LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
231
+ NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
232
+ AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
233
+ CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
234
+ IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
235
+ ANY OF THE FOREGOING.
236
+
237
+ 5. Intellectual Property.
238
+
239
+ a. No trademark licenses are granted under this Agreement, and in
240
+ connection with the Llama Materials, neither Meta nor Licensee may use any name
241
+ or mark owned by or associated with the other or any of its affiliates, except as
242
+ required for reasonable and customary use in describing and redistributing the
243
+ Llama Materials.
244
+
245
+ b. Subject to Meta's ownership of Llama Materials and derivatives made by or
246
+ for Meta, with respect to any derivative works and modifications of the Llama
247
+ Materials that are made by you, as between you and Meta, you are and will be the
248
+ owner of such derivative works and modifications.
249
+
250
+ c. If you institute litigation or other proceedings against Meta or any entity
251
+ (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
252
+ Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
253
+ constitutes infringement of intellectual property or other rights owned or licensable
254
+ by you, then any licenses granted to you under this Agreement shall terminate as of
255
+ the date such litigation or claim is filed or instituted. You will indemnify and hold
256
+ harmless Meta from and against any claim by any third party arising out of or related
257
+ to your use or distribution of the Llama Materials.
258
+
259
+ 6. Term and Termination. The term of this Agreement will commence upon your
260
+ acceptance of this Agreement or access to the Llama Materials and will continue in
261
+ full force and effect until terminated in accordance with the terms and conditions
262
+ herein. Meta may terminate this Agreement if you are in breach of any term or
263
+ condition of this Agreement. Upon termination of this Agreement, you shall delete
264
+ and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
265
+ termination of this Agreement.
266
+
267
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and
268
+ construed under the laws of the State of California without regard to choice of law
269
+ principles, and the UN Convention on Contracts for the International Sale of Goods
270
+ does not apply to this Agreement. The courts of California shall have exclusive
271
+ jurisdiction of any dispute arising out of this Agreement.
272
+
273
+
274
+
275
+ Open Source Software Licensed under the Tongyi Qianwen LICENSE AGREEMENT:
276
+ --------------------------------------------------------------------
277
+ 1. Qwen-VL
278
+ Copyright (c) Alibaba Cloud. All Rights Reserved.
279
+
280
+
281
+ Terms of the Tongyi Qianwen LICENSE AGREEMENT:
282
+ --------------------------------------------------------------------
283
+ Tongyi Qianwen LICENSE AGREEMENT
284
+
285
+ Tongyi Qianwen Release Date: August 23, 2023
286
+
287
+ By clicking to agree or by using or distributing any portion or element of the Tongyi Qianwen Materials, you will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
288
+
289
+ 1. Definitions
290
+ a. This Tongyi Qianwen LICENSE AGREEMENT (this "Agreement") shall mean the terms and conditions for use, reproduction, distribution and modification of the Materials as defined by this Agreement.
291
+ b. "We"(or "Us") shall mean Alibaba Cloud.
292
+ c. "You" (or "Your") shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Materials for any purpose and in any field of use.
293
+ d. "Third Parties" shall mean individuals or legal entities that are not under common control with Us or You.
294
+ e. "Tongyi Qianwen" shall mean the large language models (including Qwen-VL model and Qwen-VL-Chat model), and software and algorithms, consisting of trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Us.
295
+ f. "Materials" shall mean, collectively, Alibaba Cloud's proprietary Tongyi Qianwen and Documentation (and any portion thereof) made available under this Agreement.
296
+ g. "Source" form shall mean the preferred form for making modifications, including but not limited to model source code, documentation source, and configuration files.
297
+ h. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation,
298
+ and conversions to other media types.
299
+
300
+ 2. Grant of Rights
301
+ You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Alibaba Cloud's intellectual property or other rights owned by Us embodied in the Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Materials.
302
+
303
+ 3. Redistribution
304
+ You may reproduce and distribute copies of the Materials or derivative works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
305
+ a. You shall give any other recipients of the Materials or derivative works a copy of this Agreement;
306
+ b. You shall cause any modified files to carry prominent notices stating that You changed the files;
307
+ c. You shall retain in all copies of the Materials that You distribute the following attribution notices within a "Notice" text file distributed as a part of such copies: "Tongyi Qianwen is licensed under the Tongyi Qianwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved."; and
308
+ d. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such derivative works as a whole, provided Your use, reproduction, and distribution of the work otherwise complies with the terms and conditions of this Agreement.
309
+
310
+ 4. Restrictions
311
+ If you are commercially using the Materials, and your product or service has more than 100 million monthly active users, You shall request a license from Us. You cannot exercise your rights under this Agreement without our express authorization.
312
+
313
+ 5. Rules of use
314
+ a. The Materials may be subject to export controls or restrictions in China, the United States or other countries or regions. You shall comply with applicable laws and regulations in your use of the Materials.
315
+ b. You can not use the Materials or any output therefrom to improve any other large language model (excluding Tongyi Qianwen or derivative works thereof).
316
+
317
+ 6. Intellectual Property
318
+ a. We retain ownership of all intellectual property rights in and to the Materials and derivatives made by or for Us. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by you, you are and will be the owner of such derivative works and modifications.
319
+ b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials.
320
+ c. If you commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any entity alleging that the Materials or any output therefrom, or any part of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licences granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is commenced or brought.
321
+
322
+ 7. Disclaimer of Warranty and Limitation of Liability
323
+
324
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tongyi Qianwen Materials or to grant any license thereto.
325
+ b. THE MATERIALS ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM.
326
+ c. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED.
327
+ d. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to your use or distribution of the Materials.
328
+
329
+ 8. Survival and Termination.
330
+ a. The term of this Agreement shall commence upon your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
331
+ b. We may terminate this Agreement if you breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, you must delete and cease use of the Materials. Sections 7 and 9 shall survive the termination of this Agreement.
332
+
333
+ 9. Governing Law and Jurisdiction.
334
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
335
+ b. The People's Courts in Hangzhou City shall have exclusive jurisdiction over any dispute arising out of this Agreement.
configs/.DS_Store ADDED
Binary file (8.2 kB).
 
configs/clm_models/.DS_Store ADDED
Binary file (6.15 kB).
 
configs/clm_models/agent_seed_x_i.yaml ADDED
@@ -0,0 +1,23 @@
+ _target_: src.models.mllm.seed_x.ContinuousLVLM.from_pretrained
+ input_resampler:
+   _target_: src.models.tokenizer.qwen_visual.Resampler
+   grid_size: 8
+   embed_dim: 5120
+   num_heads: 32
+   kv_dim: 4096
+
+ output_resampler:
+   _target_: src.models.tokenizer.qwen_visual.Resampler
+   grid_size: 8
+   embed_dim: 4096
+   num_heads: 32
+   kv_dim: 5120
+
+ add_patch_pos: True
+ vit_down: True
+ mse: True
+
+ lm_loss_scale: 1.0
+ rec_loss_scale: 6.0
+
+ pretrained_model_path: https://huggingface.co/AILab-CVC/SEED-X-17B/blob/main/seed_x_i/agent/pytorch_model.bin
configs/clm_models/llm_seed_x_i.yaml ADDED
@@ -0,0 +1,3 @@
+ _target_: src.models.mllm.modeling_llama_xformer.LlamaForCausalLM.from_pretrained
+ pretrained_model_name_or_path: https://huggingface.co/AILab-CVC/SEED-X-17B/tree/main/seed_x_i/llm
+ low_cpu_mem_usage: True
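
The agent and LLM configs above are consumed together through Hydra, following the same pattern as src/demo/seed_llama_flask.py added in this commit. A minimal sketch, assuming the Hugging Face URLs in the configs have been swapped for local paths to the downloaded SEED-X-17B weights:

import hydra
import torch
from omegaconf import OmegaConf

# Build the language model first (llm_seed_x_i.yaml), then pass it into the
# agent target (agent_seed_x_i.yaml) as a keyword argument, as the demo
# service does.
llm_cfg = OmegaConf.load('configs/clm_models/llm_seed_x_i.yaml')
llm = hydra.utils.instantiate(llm_cfg, torch_dtype=torch.float16)

agent_cfg = OmegaConf.load('configs/clm_models/agent_seed_x_i.yaml')
agent = hydra.utils.instantiate(agent_cfg, llm=llm)
agent.eval().to('cuda', dtype=torch.float16)
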
configs/discrete_model/.DS_Store ADDED
Binary file (6.15 kB).
 
configs/discrete_model/discrete_identity.yaml ADDED
@@ -0,0 +1 @@
+ _target_: src.models.tokenizer.discrete_models.DiscreteModleIdentity
configs/processer/.DS_Store ADDED
Binary file (6.15 kB).
 
configs/processer/qwen_448_transform.yaml ADDED
@@ -0,0 +1,4 @@
+ _target_: src.processer.transforms.get_transform
+ type: clip
+ image_size: 448
+ keep_ratio: False
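
This transform config builds the 448x448 CLIP-style preprocessing applied to input images before they reach the visual encoder. A small usage sketch (the sample image is one of the files added under seed_x/ in this commit):

import hydra
from PIL import Image
from omegaconf import OmegaConf

# Instantiate the preprocessing callable and turn a PIL image into the pixel
# tensor consumed by the Qwen ViT visual encoder.
transform_cfg = OmegaConf.load('configs/processer/qwen_448_transform.yaml')
image_transform = hydra.utils.instantiate(transform_cfg)
pixel_values = image_transform(Image.open('seed_x/arrow.jpg').convert('RGB'))
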
configs/sdxl_adapter/.DS_Store ADDED
Binary file (6.15 kB).
 
configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_full_with_latent_image_pretrain_no_normalize.yaml ADDED
@@ -0,0 +1,20 @@
+ _target_: src.models.detokenizer.adapter_modules.SDXLAdapterWithLatentImage.from_pretrained
+
+ resampler:
+   _target_: src.models.detokenizer.resampler.ResamplerXLV2
+   dim: 1024
+   depth: 4
+   dim_head: 64
+   heads: 16
+   num_queries: 64
+   embedding_dim: 4096
+   output1_dim: 768
+   output2_dim: 1280
+   ff_mult: 4
+   normalize: False
+
+ full_ft: True
+ set_trainable_late: False
+
+ vit_down: True
+ pretrained_model_path: pretrained/seed_detokenizer/second_stage/pytorch_model.bin
configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml ADDED
@@ -0,0 +1,18 @@
+ _target_: src.models.detokenizer.adapter_modules.SDXLAdapter.from_pretrained
+
+ resampler:
+   _target_: src.models.detokenizer.resampler.ResamplerXLV2
+   dim: 1024
+   depth: 4
+   dim_head: 64
+   heads: 16
+   num_queries: 64
+   embedding_dim: 4096
+   output1_dim: 768
+   output2_dim: 1280
+   ff_mult: 4
+   normalize: False
+
+ vit_down: True
+
+ pretrained_model_path: https://huggingface.co/AILab-CVC/SEED-X-17B/blob/main/seed_detokenizer/first_stage/pytorch_model.bin
configs/tokenizer/.DS_Store ADDED
Binary file (6.15 kB).
 
configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: transformers.LlamaTokenizer.from_pretrained
+ pretrained_model_name_or_path: https://huggingface.co/AILab-CVC/SEED-X-17B/tree/main/cvlm_llama2_tokenizer_100img_and_224loc_addpatch
configs/visual_encoder/.DS_Store ADDED
Binary file (6.15 kB).
 
configs/visual_encoder/qwen_vitg_448.yaml ADDED
@@ -0,0 +1,11 @@
+ _target_: src.models.tokenizer.qwen_visual.VisionTransformerWithAttnPool.from_pretrained
+ heads: 16
+ image_size: 448
+ image_start_id: 151857
+ layers: 48
+ mlp_ratio: 4.9231
+ output_dim: 4096
+ patch_size: 14
+ width: 1664
+
+ pretrained_model_path: pretrained/QwenViT/qwen_vit_G.pt
pretrained/QwenViT/qwen_vit_G.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d951083fc79b07bdb84be61944eb263b8e14572fe2dc4fa80b0447f83064463c
+ size 3871440281
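
The checkpoint itself is stored via Git LFS; the pointer above records its SHA-256 digest and size, which can be used to sanity-check a downloaded copy. A sketch, using the path referenced by configs/visual_encoder/qwen_vitg_448.yaml:

import hashlib
from pathlib import Path

path = Path('pretrained/QwenViT/qwen_vit_G.pt')
sha256 = hashlib.sha256()
with path.open('rb') as f:
    # Hash in 4 MiB chunks to avoid loading the ~3.9 GB file into memory at once.
    for chunk in iter(lambda: f.read(1 << 22), b''):
        sha256.update(chunk)
assert path.stat().st_size == 3871440281
assert sha256.hexdigest() == 'd951083fc79b07bdb84be61944eb263b8e14572fe2dc4fa80b0447f83064463c'
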
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch==2.0.1
+ hydra-core
+ transformers==4.30.2
+ diffusers==0.25.0
+ sentencepiece
+ opencv-python
+ deepspeed
+ pyrootutils
+ xformers>=0.0.20
+ accelerate
+ transformers_stream_generator
seed_x/arrow.jpg ADDED
seed_x/bank.png ADDED
src/.DS_Store ADDED
Binary file (10.2 kB).
 
src/demo/__pycache__/conversation.cpython-311.pyc ADDED
Binary file (8.21 kB).
 
src/demo/__pycache__/conversation.cpython-38.pyc ADDED
Binary file (4.43 kB).
 
src/demo/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.39 kB).
 
src/demo/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.32 kB).
 
src/demo/configs/agent_13b_anyres_out_64_pretrain_merged.yaml ADDED
@@ -0,0 +1,29 @@
+ _target_: src.models_clm.models.ContinuousLVLM.from_pretrained
+ input_resampler:
+   _target_: src.models.qwen_visual.Resampler
+   grid_size: 8
+   embed_dim: 5120
+   num_heads: 32
+   kv_dim: 4096
+
+ output_resampler:
+   _target_: src.models.qwen_visual.Resampler
+   grid_size: 8
+   embed_dim: 4096
+   num_heads: 32
+   kv_dim: 5120
+
+ add_patch_pos: True
+ vit_down: True
+ mse: True
+
+ lm_loss_scale: 1.0
+ rec_loss_scale: 6.0
+
+ #pretrained_model_path: /chat_sh/share_300719895/user/jinguozhu/codes/work_dirs/sft_exp_new_acc4/checkpoint-2000/pytorch_model.bin
+ #pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/03_27_any_res_sft_from_merged_10k/checkpoint-9000/pytorch_model.bin
+ #pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/03_27_any_res_sft_from_merged_10k/checkpoint-8000-merged/agent/pytorch_model.bin
+ #pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_09_any_res_sft_editing_from_merged_10k_32a100_new_data/checkpoint-6000-merged/agent/pytorch_model.bin
+ #pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_editing_from_merged_H800_23k_16_gpu_2_new/checkpoint-6000-merged/agent/pytorch_model.bin
+ #pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_com_gen_from_merged_H800_23k/checkpoint-15000-merged/agent/pytorch_model.bin
+ pretrained_model_path: /group/40034/yuyingge/SEED_X_inference/pretrained/seed_x_i/agent/pytorch_model.bin
src/demo/configs/agent_13b_in100_out64_rs5_merged_pretrain.yaml ADDED
@@ -0,0 +1,22 @@
+ _target_: src.models_clm.models.ContinuousLVLM.from_pretrained
+ input_resampler:
+   _target_: src.models.qwen_visual.Resampler
+   grid_size: 10
+   embed_dim: 5120
+   num_heads: 32
+   kv_dim: 4096
+
+ output_resampler:
+   _target_: src.models.qwen_visual.Resampler
+   grid_size: 16
+   embed_dim: 4096
+   num_heads: 32
+   kv_dim: 5120
+
+ lm_loss_scale: 1.0
+ rec_loss_scale: 5.0
+
+ # pretrained_model_path: /chat_sh/share_300719895/user/sijiezhao/Program/2023/DiscreteLearning/train_output_clm_sh/1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro/checkpoint-27000/pytorch_model.bin
+ # pretrained_model_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-10000-merged/agent/pytorch_model.bin
+ # pretrained_model_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-5000-merged/agent/pytorch_model.bin
+ pretrained_model_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1211_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_grounding_27k/ckpt-4000-merged/agent/pytorch_model.bin
src/demo/configs/llama2chat13b_merged_100imgtokens.yaml ADDED
@@ -0,0 +1,12 @@
+
+ _target_: src.models_clm.modeling_llama_xformer.LlamaForCausalLM.from_pretrained
+ # _target_: transformers.LlamaForCausalLM.from_pretrained
+ # pretrained_model_name_or_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-10000-merged/llm
+ # pretrained_model_name_or_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-5000-merged/llm
+ #pretrained_model_name_or_path: /chat_sh/share_300719895/user/jinguozhu/codes/work_dirs/pretraining_anyres_newexp_v2/checkpoint-10000-merged/llm
+ #pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/03_27_any_res_sft_from_merged_10k/checkpoint-8000-merged/llm
+ #pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_09_any_res_sft_editing_from_merged_10k_32a100_new_data/checkpoint-6000-merged/llm
+ #pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_editing_from_merged_H800_23k_16_gpu_2_new/checkpoint-6000-merged/llm
+ #pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_com_gen_from_merged_H800_23k/checkpoint-15000-merged/llm
+ pretrained_model_name_or_path: /group/40034/yuyingge/SEED_X_inference/pretrained/seed_x_i/llm
+ low_cpu_mem_usage: True
src/demo/conversation.py ADDED
@@ -0,0 +1,182 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+ import io
6
+ import base64
7
+ import os
8
+ from PIL import Image
9
+ import copy
10
+
11
+ IMG_FLAG = '<image>'
12
+
13
+
14
+ class SeparatorStyle(Enum):
15
+ """Different separator style."""
16
+ SINGLE = auto()
17
+ TWO = auto()
18
+ MPT = auto()
19
+ PLAIN = auto()
20
+ LLAMA_2 = auto()
21
+
22
+
23
+ def decode_image(encoded_image: str) -> Image:
24
+ decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
25
+ buffer = io.BytesIO(decoded_bytes)
26
+ image = Image.open(buffer)
27
+ return image
28
+
29
+
30
+ def encode_image(image: Image.Image, format: str = 'PNG') -> str:
31
+ with io.BytesIO() as buffer:
32
+ image.save(buffer, format=format)
33
+ encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
34
+ return encoded_image
35
+
36
+
37
+ @dataclasses.dataclass
38
+ class Conversation:
39
+ """A class that keeps all conversation history."""
40
+ system: str
41
+ roles: List[str]
42
+ messages: List[dict] # multi-turn -> user & assistant -> {'images': [PIL.Image,], 'text': str}
43
+ offset: int
44
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
45
+ sep: str = "###"
46
+ sep2: str = None
47
+ version: str = "Unknown"
48
+
49
+ skip_next: bool = False
50
+
51
+ def get_prompt(self):
52
+ messages = copy.deepcopy(self.messages)
53
+ if self.sep_style == SeparatorStyle.SINGLE:
54
+ if self.system is None or self.system == '':
55
+ text = ''
56
+ else:
57
+ text = self.system + self.sep
58
+ images = []
59
+ for message in messages:
60
+ text += message['role'] + ": " + message['message']['text'] + self.sep
61
+ for image_path in message['message']['images']:
62
+ image = Image.open(image_path).resize((256, 256))
63
+ image_base64 = encode_image(image)
64
+ images.append(image_base64)
65
+
66
+ text += self.roles[1] + ":"
67
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
68
+ b_token = "[INST] "
69
+ e_token = " [/INST]"
70
+ if self.system is None or self.system == '':
71
+ text = ''
72
+ else:
73
+ text = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n"
74
+ images = []
75
+ for idx, message in enumerate(messages):
76
+ # text += message['role'] + ": " + message['message']['text'] + self.sep
77
+ if idx % 2 == 0:
78
+ text += b_token + message['message']['text'] + e_token + self.sep
79
+ else:
80
+ text += message['message']['text'] + self.sep
81
+
82
+ for image_path in message['message']['images']:
83
+ image = Image.open(image_path)
84
+ image_base64 = encode_image(image)
85
+ images.append(image_base64)
86
+ else:
87
+ raise NotImplementedError
88
+
89
+ return {'text': text, 'images': images}
90
+
91
+ # def update_image_ids(self, images_ids):
92
+ # image_count = 0
93
+ # for message in self.messages:
94
+ # for idx in range(len(message['message']['images_ids'])):
95
+ # if message['message']["images_ids"][idx] is None:
96
+ # message['message']["images_ids"][idx] = images_ids[image_count]
97
+ # image_count += 1
98
+
99
+ # assert len(images_ids) == image_count, print(len(images_ids), image_count)
100
+
101
+ def append_message(self, role, message):
102
+ self.messages.append([role, message])
103
+
104
+ def to_gradio_chatbot(self):
105
+ dialog = []
106
+ for i, single_turn in enumerate(self.messages[self.offset:]):
107
+ single_turn = single_turn['message']
108
+ text_list = single_turn['text'].split(IMG_FLAG)
109
+ assert len(text_list) == len(single_turn['images']) + 1, print(text_list, len(single_turn['images']))
110
+ message = ''
111
+ for image_idx in range(len(single_turn['images'])):
112
+ # image = single_turn['images'][image_idx]
113
+ # image_base64 = encode_image(image)
114
+ # image_str = f'<img src="data:image/png;base64,{image_base64}" alt="user upload image" />'
115
+ image_path = single_turn['images'][image_idx]
116
+ if image_path == '':
117
+ message += text_list[image_idx] + '<corrupt_image>'
118
+ else:
119
+ message += text_list[image_idx] + f'![](file={image_path})'
120
+ message += text_list[-1]
121
+
122
+ if i % 2 == 0:
123
+ dialog.append([message, None])
124
+ else:
125
+ dialog[-1][-1] = message
126
+
127
+ return dialog
128
+
129
+ def copy(self):
130
+ return Conversation(system=self.system,
131
+ roles=self.roles,
132
+ messages=copy.deepcopy(self.messages),
133
+ offset=self.offset,
134
+ sep_style=self.sep_style,
135
+ sep=self.sep,
136
+ sep2=self.sep2,
137
+ version=self.version)
138
+
139
+ def dict(self):
140
+ messages = copy.deepcopy(self.messages)
141
+ for message in messages:
142
+ for i in range(len(message['message']['images'])):
143
+ message['message']['images'][i] = os.path.basename(message['message']['images'][i])
144
+ return {
145
+ "system": self.system,
146
+ "roles": self.roles,
147
+ "messages": messages,
148
+ "offset": self.offset,
149
+ "sep": self.sep,
150
+ "sep2": self.sep2,
151
+ }
152
+
153
+
154
+ conv_seed_vicuna = Conversation(
155
+ system="",
156
+ roles=("USER", "ASSISTANT"),
157
+ version="v2",
158
+ messages=[],
159
+ offset=0,
160
+ sep_style=SeparatorStyle.SINGLE,
161
+ sep='\n',
162
+ )
163
+
164
+ conv_seed_vicuna_system = Conversation(
165
+ system="A chat between a curious user and an artificial intelligence assistant. ",
166
+ roles=("USER", "ASSISTANT"),
167
+ version="v2",
168
+ messages=[],
169
+ offset=0,
170
+ sep_style=SeparatorStyle.SINGLE,
171
+ sep='\n',
172
+ )
173
+
174
+ conv_seed_llama2 = Conversation(
175
+ system="",
176
+ roles=("[INST]", "[/INST]"),
177
+ version="v2",
178
+ messages=[],
179
+ offset=0,
180
+ sep_style=SeparatorStyle.LLAMA_2,
181
+ sep='\n',
182
+ )
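
A brief usage sketch for the Conversation helper above: messages are built directly in the dict shape that get_prompt() and to_gradio_chatbot() read, and the image path points at one of the sample images added under seed_x/ in this commit:

from src.demo.conversation import conv_seed_llama2, IMG_FLAG

conv = conv_seed_llama2.copy()
conv.messages.append({
    'role': conv.roles[0],
    'message': {'text': f'{IMG_FLAG} What is shown in this image?',
                'images': ['seed_x/bank.png']},
})
prompt = conv.get_prompt()         # {'text': '[INST] ...', 'images': [base64-encoded PNGs]}
dialog = conv.to_gradio_chatbot()  # [[user_turn, assistant_turn_or_None], ...]
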
src/demo/seed_llama_flask.py ADDED
@@ -0,0 +1,379 @@
1
+ import hydra
2
+ import pyrootutils
3
+ import torch
4
+ import re
5
+ import time
6
+ from omegaconf import OmegaConf
7
+ from flask import Flask, request
8
+ from typing import Optional
9
+ import transformers
10
+ from dataclasses import dataclass, field
11
+ import io
12
+ import base64
13
+ from PIL import Image
14
+ import numpy as np
15
+ import cv2
16
+ from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler
17
+
18
+
19
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
20
+
21
+ from src.data.any_res import process_anyres_image
22
+
23
+ BOI_TOKEN = '<img>'
24
+ BOP_TOKEN = '<patch>'
25
+ EOI_TOKEN = '</img>'
26
+ EOP_TOKEN = '</patch>'
27
+ IMG_TOKEN = '<img_{:05d}>'
28
+
29
+ IMG_FLAG = '<image>'
30
+ num_img_in_tokens = 64
31
+ num_img_out_tokens = 64
32
+
33
+ resolution_grids = ['1x1', '1x2', '1x3', '1x4', '1x5', '1x6', '1x10', '2x1', '3x1', '4x1', '5x1', '6x1', '10x1', '2x2', '2x3', '3x2', '2x4', '4x2']
34
+ base_resolution = 448
35
+
36
+ app = Flask(__name__)
37
+
38
+
39
+ def decode_image(encoded_image: str) -> Image:
40
+ decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
41
+ buffer = io.BytesIO(decoded_bytes)
42
+ image = Image.open(buffer)
43
+ return image
44
+
45
+
46
+ def encode_image(image: Image.Image, format: str = 'PNG') -> str:
47
+ with io.BytesIO() as buffer:
48
+ image.save(buffer, format=format)
49
+ encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
50
+ return encoded_image
51
+
52
+
53
+ @dataclass
54
+ class Arguments:
55
+ image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
56
+ tokenizer: Optional[str] = field(default=None, metadata={"help": "config path of tokenizer used to initialize tokenizer"})
57
+ llm: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
58
+ visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
59
+ sd_adapter: Optional[str] = field(default=None, metadata={"help": "config path of sd adapter"})
60
+ agent: Optional[str] = field(default=None, metadata={"help": "config path of agent model"})
61
+ diffusion_path: Optional[str] = field(default=None, metadata={"help": "diffusion model path"})
62
+ has_bbox: Optional[bool] = field(default=False, metadata={"help": "visualize the box"})
63
+
64
+ port: Optional[str] = field(default=80, metadata={"help": "network port"})
65
+ llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
66
+ vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"})
67
+ dtype: Optional[str] = field(default='fp16', metadata={"help": "mixed precision"})
68
+
69
+ multi_resolution: Optional[bool] = field(default=False, metadata={"help": "multi resolution"})
70
+
71
+
72
+ parser = transformers.HfArgumentParser(Arguments)
73
+ args, = parser.parse_args_into_dataclasses()
74
+
75
+ def extract_box(output_str):
76
+ boxes = re.findall('(.*?)<box_end>', output_str)
77
+ if len(boxes) >0:
78
+ bboxes = [[int(num) for num in re.findall('<loc-(\d+)>', box)] for box in boxes]
79
+ else:
80
+ bboxes = None
81
+
82
+ return bboxes
83
+
84
+
85
+ def visualize_bbox(image, bboxes):
86
+ img_width, img_height = image.size
87
+ image = np.array(image)
88
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
89
+ for bbox in bboxes:
90
+ x_center, y_center, box_width, box_height = bbox
91
+
92
+ x_center = x_center / 224 * img_width
93
+ y_center = y_center / 224 * img_height
94
+
95
+ box_width = box_width /224 * img_width
96
+ box_height = box_height / 224 * img_height
97
+
98
+ x1 = int(x_center - box_width / 2)
99
+ y1 = int(y_center - box_height / 2)
100
+ x2 = int(x_center + box_width / 2)
101
+ y2 = int(y_center + box_height / 2)
102
+
103
+ cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 4)
104
+
105
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
106
+ image = Image.fromarray(image)
107
+
108
+
109
+ return image
110
+
111
+
112
+
113
+
114
+ class LLMService:
115
+
116
+ def __init__(self, args) -> None:
117
+
118
+ self.llm_device = args.llm_device
119
+ self.vit_sd_device = args.vit_sd_device
120
+
121
+ dtype = args.dtype
122
+ if dtype == 'fp16':
123
+ self.dtype = torch.float16
124
+ elif dtype == 'bf16':
125
+ self.dtype = torch.bfloat16
126
+ else:
127
+ raise ValueError
128
+
129
+ image_transform_cfg = OmegaConf.load(args.image_transform)
130
+ self.image_transform = hydra.utils.instantiate(image_transform_cfg)
131
+
132
+ tokenizer_cfg = OmegaConf.load(args.tokenizer)
133
+ self.tokenizer = hydra.utils.instantiate(tokenizer_cfg)
134
+
135
+ visual_encoder_cfg = OmegaConf.load(args.visual_encoder)
136
+ self.visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
137
+ self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype)
138
+ print('Init visual encoder done')
139
+
140
+ llm_cfg = OmegaConf.load(args.llm)
141
+ llm = hydra.utils.instantiate(llm_cfg, torch_dtype=self.dtype)
142
+ print('Init llm done.')
143
+
144
+ agent_cfg = OmegaConf.load(args.agent)
145
+ self.agent = hydra.utils.instantiate(agent_cfg, llm=llm)
146
+
147
+ self.agent.eval().to(self.llm_device, dtype=self.dtype)
148
+ print('Init agent model done.')
149
+
150
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler")
151
+
152
+ vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device, dtype=self.dtype)
153
+
154
+ unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(dtype=self.dtype)
155
+
156
+ sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
157
+
158
+ self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(dtype=self.dtype)
159
+
160
+ self.sd_adapter.init_pipe(vae=vae,
161
+ scheduler=noise_scheduler,
162
+ visual_encoder=self.visual_encoder.to("cpu"),
163
+ image_transform=self.image_transform,
164
+ discrete_model=None,
165
+ dtype=self.dtype,
166
+ device="cpu")
167
+
168
+ print('Init sd adapter pipe done.')
169
+
170
+ self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype)
171
+
172
+ self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
173
+ self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
174
+
175
+ self.bop_token_id = self.tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0]
176
+ self.eop_token_id = self.tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0]
177
+
178
+ self.multi_resolution = args.multi_resolution
179
+ if self.multi_resolution:
180
+ self.base_resolution = base_resolution
181
+ grid_pinpoints = []
182
+ for scale in resolution_grids:
183
+ s1, s2 = scale.split('x')
184
+ grid_pinpoints.append([int(s1)*base_resolution, int(s2)*base_resolution])
185
+ self.grid_pinpoints = grid_pinpoints
186
+
187
+
188
+ service = LLMService(args)
189
+
190
+
191
+ @app.route('/generate', methods=['GET', 'POST'])
192
+ def generate():
193
+ with torch.no_grad():
194
+ request_info = request.get_json()
195
+
196
+ text_list = request_info['text'].split(IMG_FLAG)
197
+ image_list = request_info['images']
198
+ max_new_tokens = request_info.get('max_new_tokens', 256)
199
+ top_p = 0.5
200
+ force_boi = request_info.get('force_boi', False)
201
+ force_bbox = request_info.get('force_bbox', False)
202
+
203
+ assert len(text_list) == len(image_list) + 1
204
+
205
+ image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
206
+
207
+ input_images = []
208
+ if len(image_list) > 0:
209
+ image_tensor_list = []
210
+ embeds_cmp_mask = []
211
+ embeds_gen_mask = []
212
+
213
+ if service.multi_resolution:
214
+ patch_pos = []
215
+ image_patch_length = []
216
+ image_size_list = []
217
+
218
+ for idx, image_item in enumerate(image_list):
219
+ if isinstance(image_item, str):
220
+ image = decode_image(image_item)
221
+ print('after decode image size:', image.size)
222
+ input_images.append(image)
223
+
224
+ if service.multi_resolution:
225
+ image_size_list.append(image.size)
226
+ print('image size:', image.size)
227
+ image_tensor, patch_pos_tensor = process_anyres_image(image, service.image_transform, service.grid_pinpoints, service.base_resolution)
228
+ image_tensor_list.append(image_tensor)
229
+ patch_pos.append(patch_pos_tensor)
230
+ image_patch_length.append(image_tensor.shape[0])
231
+ print('image_patch_length', image_patch_length)
232
+ embeds_cmp_mask.extend([True]*image_tensor.shape[0])
233
+ embeds_gen_mask.extend([False]*image_tensor.shape[0])
234
+
235
+ else:
236
+ image_tensor = service.image_transform(image)
237
+
238
+ image_tensor_list.append(image_tensor)
239
+ embeds_cmp_mask.append(True)
240
+ embeds_gen_mask.append(False)
241
+ else:
242
+ raise ValueError
243
+
244
+ if service.multi_resolution:
245
+ pixel_values = torch.cat(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
246
+ patch_position = torch.cat(patch_pos, dim=0)
247
+
248
+ image_tokens_list = []
249
+ for patch_length in image_patch_length:
250
+ image_tokens = ''
251
+ for _ in range(patch_length-1):
252
+ image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
253
+ image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN
254
+ image_tokens_list.append(image_tokens)
255
+ else:
256
+ pixel_values = torch.stack(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
257
+
258
+ image_embeds = service.visual_encoder(pixel_values)
259
+ image_embeds = image_embeds.to(service.llm_device)
260
+
261
+ embeds_cmp_mask = torch.tensor(embeds_cmp_mask, dtype=torch.bool).to(service.llm_device)
262
+ embeds_gen_mask = torch.tensor(embeds_gen_mask, dtype=torch.bool).to(service.llm_device)
263
+
264
+ else:
265
+ image_embeds = None
266
+ patch_position = 0
267
+ embeds_cmp_mask = None
268
+ embeds_gen_mask = None
269
+
270
+ if service.multi_resolution:
271
+ input_text = ''
272
+ for i, c in enumerate(text_list[:-1]):
273
+ input_text += c + image_tokens_list[i]
274
+ input_text += text_list[-1]
275
+
276
+ else:
277
+ input_text = image_tokens.join(text_list)
278
+
279
+ if force_boi:
280
+ input_text = input_text + BOI_TOKEN
281
+
282
+ if force_bbox:
283
+ input_text = input_text + '[[ <box_start>'
284
+ print('input_text:', input_text)
285
+ input_ids = service.tokenizer.encode(input_text, add_special_tokens=False)
286
+ input_ids = [service.tokenizer.bos_token_id] + input_ids
287
+
288
+ input_ids = torch.tensor(input_ids).to(service.llm_device, dtype=torch.long)
289
+ ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
290
+ ids_gen_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
291
+
292
+ if service.multi_resolution:
293
+ boi_indices = torch.where(torch.logical_or(input_ids == service.boi_token_id, input_ids == service.bop_token_id))[0].tolist()
294
+ eoi_indices = torch.where(torch.logical_or(input_ids == service.eoi_token_id, input_ids == service.eop_token_id))[0].tolist()
295
+
296
+ else:
297
+
298
+ boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist()
299
+ eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist()
300
+
301
+ for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
302
+ ids_cmp_mask[boi_idx + 1:eoi_idx] = True
303
+
304
+ input_ids = input_ids.unsqueeze(0)
305
+ ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
306
+ ids_gen_mask = ids_gen_mask.unsqueeze(0)
307
+
308
+ error_msg = []
309
+
310
+ if service.multi_resolution:
311
+ output = service.agent.generate(
312
+ tokenizer=service.tokenizer,
313
+ input_ids=input_ids,
314
+ image_embeds=image_embeds,
315
+ patch_positions=patch_position,
316
+ embeds_cmp_mask=embeds_cmp_mask,
317
+ ids_cmp_mask=ids_cmp_mask,
318
+ num_img_gen_tokens=num_img_out_tokens,
319
+ max_new_tokens=max_new_tokens,
320
+ dtype=service.dtype,
321
+ device=service.llm_device,
322
+ top_p=top_p,
323
+ )
324
+ else:
325
+ output = service.agent.generate(
326
+ tokenizer=service.tokenizer,
327
+ input_ids=input_ids,
328
+ image_embeds=image_embeds,
329
+ embeds_cmp_mask=embeds_cmp_mask,
330
+ ids_cmp_mask=ids_cmp_mask,
331
+ num_img_gen_tokens=num_img_out_tokens,
332
+ max_new_tokens=max_new_tokens,
333
+ dtype=service.dtype,
334
+ device=service.llm_device,
335
+ top_p=top_p,
336
+ )
337
+
338
+ gen_imgs_base64_list = []
339
+ generated_text = output['text']
340
+ generated_text = generated_text.replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '')
341
+
342
+ if output['has_img_output']:
343
+ print('offloading the LLM to CPU and moving the SD de-tokenizer to GPU')
344
+ a = time.time()
345
+ service.agent = service.agent.to("cpu")
346
+ service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
347
+ print("Loading finished: ", time.time() - a)
348
+
349
+ img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)
350
+
351
+ for img_idx in range(output['num_gen_imgs']):
352
+ img_feat = img_gen_feat[img_idx:img_idx + 1]
353
+ generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
354
+ image_base64 = encode_image(generated_image)
355
+ gen_imgs_base64_list.append(image_base64)
356
+
357
+ print('moving the visual encoder and LLM back to GPU, and the SD de-tokenizer to CPU')
358
+ a = time.time()
359
+ service.sd_adapter = service.sd_adapter.to("cpu")
360
+ service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype)
361
+ service.agent = service.agent.to(service.vit_sd_device, dtype=service.dtype)
362
+ print("Loading finished: ", time.time() - a)
363
+
364
+ if args.has_bbox:
365
+ bboxes = extract_box(generated_text)
366
+
367
+ if bboxes is not None and len(input_images) > 0:
368
+ image_viz = visualize_bbox(input_images[0], bboxes)
369
+ image_base64 = encode_image(image_viz)
370
+ gen_imgs_base64_list.append(image_base64)
371
+ generated_text = re.sub(r'\[\[ <box_start>.*?<box_end>.*?\]\]', 'the green bounding box', generated_text)
372
+ generated_text += IMG_FLAG
373
+ print(input_text + generated_text)
374
+
375
+ return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
376
+
377
+
378
+ if __name__ == '__main__':
379
+ app.run(host='0.0.0.0', port=args.port)
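The service above exposes a single `/generate` endpoint that takes the conversation text (with one `<image>` placeholder per attached image), a list of base64-encoded images, and the optional `force_boi` / `force_bbox` flags, and returns generated text plus base64-encoded generated images. Below is a minimal client sketch, assuming the service is running locally on port 7890 (the default `request_address` used by the Gradio front-end) and borrowing the `[INST]` prompt template from the inference scripts; the actual demo builds its prompt via `conversation.py`.

```python
# Minimal sketch of a client for the /generate endpoint defined above.
# Assumes the Flask service is reachable at http://127.0.0.1:7890/generate.
import base64
import io

import requests
from PIL import Image


def encode_image(image: Image.Image) -> str:
    """Encode a PIL image as base64 PNG, matching the server's decode_image."""
    with io.BytesIO() as buffer:
        image.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')


image = Image.open('seed_x/bank.png').convert('RGB')
payload = {
    # one '<image>' placeholder per entry in 'images'
    'text': '[INST] <image>Can I connect with an advisor on Sunday? [/INST]\n',
    'images': [encode_image(image)],
    'max_new_tokens': 512,
    'force_boi': False,   # set True to force image generation
    'force_bbox': False,  # set True to force a bounding-box answer
}

response = requests.post('http://127.0.0.1:7890/generate', json=payload).json()
print(response['text'])         # generated text; '<image>' marks generated images
print(len(response['images']))  # base64-encoded generated/visualized images
```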
src/demo/seed_llama_gradio.py ADDED
@@ -0,0 +1,465 @@
1
+ import os
2
+ import numpy as np
3
+ import datetime
4
+ import json
5
+ from typing import Optional
6
+ import transformers
7
+ from dataclasses import dataclass, field
8
+ import io
9
+ import base64
10
+ from PIL import Image
11
+ import gradio as gr
12
+ import time
13
+ import hashlib
14
+ import requests
15
+
16
+ from utils import build_logger
17
+ from conversation import conv_seed_llama2
18
+
19
+ IMG_FLAG = '<image>'
20
+ LOGDIR = 'log'
21
+
22
+ logger = build_logger("gradio_seed_x", LOGDIR)
23
+ headers = {"User-Agent": "SEED-X Client"}
24
+
25
+ no_change_btn = gr.Button.update()
26
+ enable_btn = gr.Button.update(interactive=True)
27
+ disable_btn = gr.Button.update(interactive=False)
28
+
29
+
30
+ @dataclass
31
+ class Arguments:
32
+ server_port: Optional[int] = field(default=7860, metadata={"help": "network port"})
33
+ server_name: Optional[str] = field(default='0.0.0.0', metadata={"help": "network address"})
34
+ request_address: Optional[str] = field(default='http://127.0.0.1:7890/generate',
35
+ metadata={"help": "request address"})
36
+
37
+
38
+ parser = transformers.HfArgumentParser(Arguments)
39
+ args, = parser.parse_args_into_dataclasses()
40
+ conv_seed_llama = conv_seed_llama2
41
+
42
+
43
+ def decode_image(encoded_image: str) -> Image:
44
+ decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
45
+ buffer = io.BytesIO(decoded_bytes)
46
+ image = Image.open(buffer)
47
+ return image
48
+
49
+
50
+ def encode_image(image: Image.Image, format: str = 'PNG') -> str:
51
+ with io.BytesIO() as buffer:
52
+ image.save(buffer, format=format)
53
+ encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
54
+ return encoded_image
55
+
56
+
57
+ def get_conv_log_filename():
58
+ t = datetime.datetime.now()
59
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
60
+ return name
61
+
62
+
63
+ def get_conv_image_dir():
64
+ name = os.path.join(LOGDIR, 'images')
65
+ os.makedirs(name, exist_ok=True)
66
+ return name
67
+
68
+
69
+ def get_image_name(image, image_dir=None):
70
+ buffer = io.BytesIO()
71
+ image.save(buffer, format='PNG')
72
+ image_bytes = buffer.getvalue()
73
+ md5 = hashlib.md5(image_bytes).hexdigest()
74
+
75
+ if image_dir is not None:
76
+ image_name = os.path.join(image_dir, md5 + '.png')
77
+ else:
78
+ image_name = md5 + '.png'
79
+
80
+ return image_name
81
+
82
+
83
+ def resize_image_square(image, target_size=448):
84
+ resized_image = image.resize((target_size, target_size))
85
+ return resized_image
86
+
87
+
88
+ def resize_image(image, max_size=512):
89
+ width, height = image.size
90
+ aspect_ratio = float(width) / float(height)
91
+
92
+ if width > height:
93
+ new_width = max_size
94
+ new_height = int(new_width / aspect_ratio)
95
+ else:
96
+ new_height = max_size
97
+ new_width = int(new_height * aspect_ratio)
98
+
99
+ resized_image = image.resize((new_width, new_height))
100
+ return resized_image
101
+
102
+
103
+ def center_crop_image(image, max_aspect_ratio=1.5):
104
+ width, height = image.size
105
+ aspect_ratio = max(width, height) / min(width, height)
106
+
107
+ if aspect_ratio >= max_aspect_ratio:
108
+ if width > height:
109
+ new_width = int(height * max_aspect_ratio)
110
+ left = (width - new_width) // 2
111
+ right = (width + new_width) // 2
112
+ top = 0
113
+ bottom = height
114
+ else:
115
+ new_height = int(width * max_aspect_ratio)
116
+ left = 0
117
+ right = width
118
+ top = (height - new_height) // 2
119
+ bottom = (height + new_height) // 2
120
+
121
+ cropped_image = image.crop((left, top, right, bottom))
122
+ return cropped_image
123
+ else:
124
+ return image
125
+
126
+
127
+ def vote_last_response(state, vote_type, request: gr.Request):
128
+ with open(get_conv_log_filename(), "a") as fout:
129
+ data = {
130
+ "tstamp": round(time.time(), 4),
131
+ "type": vote_type,
132
+ "state": state.dict(),
133
+ "ip": request.client.host,
134
+ }
135
+ fout.write(json.dumps(data) + "\n")
136
+
137
+
138
+ def upvote_last_response(state, request: gr.Request):
139
+ logger.info(f"upvote. ip: {request.client.host}")
140
+ vote_last_response(state, "upvote", request)
141
+ return (disable_btn,) * 2
142
+
143
+
144
+ def downvote_last_response(state, request: gr.Request):
145
+ logger.info(f"downvote. ip: {request.client.host}")
146
+ vote_last_response(state, "downvote", request)
147
+ return (disable_btn,) * 2
148
+
149
+
150
+ def regenerate(dialog_state, request: gr.Request):
151
+ logger.info(f"regenerate. ip: {request.client.host}")
152
+ if dialog_state.messages[-1]['role'] == dialog_state.roles[1]:
153
+ dialog_state.messages.pop()
154
+ return (
155
+ dialog_state,
156
+ dialog_state.to_gradio_chatbot(),
157
+ ) + (disable_btn,) * 4
158
+
159
+
160
+ def clear_history(request: gr.Request):
161
+ logger.info(f"clear_history. ip: {request.client.host}")
162
+ dialog_state = conv_seed_llama.copy()
163
+ input_state = init_input_state()
164
+ return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
165
+
166
+
167
+ def init_input_state():
168
+ return {'images': [], 'text': ''}
169
+
170
+
171
+ def add_text(dialog_state, input_state, text, request: gr.Request):
172
+ logger.info(f"add_text. ip: {request.client.host}.")
173
+ # if len(input_state['text']) == 0:
174
+ if text is None or len(text) == 0:
175
+ # dialog_state.skip_next = True
176
+ return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
177
+ input_state['text'] += text
178
+
179
+
180
+ if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
181
+ dialog_state.messages[-1]['message'] = input_state
182
+ else:
183
+ dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
184
+ print('add_text: ', dialog_state.to_gradio_chatbot())
185
+
186
+ return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
187
+
188
+
189
+ def is_blank(image):
190
+ image_array = np.array(image)
191
+ unique_colors = np.unique(image_array)
192
+ print('unique_colors', len(unique_colors))
193
+ return len(unique_colors) == 1
194
+
195
+
196
+ def add_image(dialog_state, input_state, image, request: gr.Request):
197
+ logger.info(f"add_image. ip: {request.client.host}.")
198
+ if image is None:
199
+ return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
200
+
201
+ image = image.convert('RGB')
202
+
203
+ print('image size:', image.size)
204
+
205
+ image = center_crop_image(image, max_aspect_ratio=10)
206
+
207
+ image_dir = get_conv_image_dir()
208
+ image_path = get_image_name(image=image, image_dir=image_dir)
209
+ if not os.path.exists(image_path):
210
+ image.save(image_path)
211
+ input_state['images'].append(image_path)
212
+ input_state['text'] += IMG_FLAG
213
+
214
+ if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
215
+ dialog_state.messages[-1]['message'] = input_state
216
+ else:
217
+ dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
218
+
219
+ print('add_image:', dialog_state)
220
+
221
+ return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
222
+
223
+
224
+ def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_gen, force_bbox,
225
+ request: gr.Request):
226
+ logger.info(f"http_bot. ip: {request.client.host}")
227
+ print('input_state:', input_state)
228
+
229
+ if len(dialog_state.messages) == 0 or dialog_state.messages[-1]['role'] != dialog_state.roles[0] or len(
230
+ dialog_state.messages[-1]['message']['text'].strip(' ?.;!/')) == 0:
231
+ return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
232
+
233
+ if len(dialog_state.messages) > max_turns * 2:
234
+ output_state = init_input_state()
235
+ output_state['text'] = 'Error: History exceeds maximum rounds, please clear history and restart.'
236
+ dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
237
+ input_state = init_input_state()
238
+ return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 3 + (enable_btn,)
239
+
240
+ prompt = dialog_state.get_prompt()
241
+ payload = {
242
+ 'text': prompt['text'],
243
+ 'max_new_tokens': int(max_new_tokens),
244
+ 'images': prompt['images'],
245
+ 'force_boi': force_image_gen,
246
+ 'force_bbox': force_bbox,
247
+ }
248
+
249
+ print(
250
+ 'request: ', {
251
+ 'text': prompt['text'],
252
+ 'max_new_tokens': int(max_new_tokens),
253
+ })
254
+ print('request_address', args.request_address)
255
+ response = requests.request(method="POST", url=args.request_address, headers=headers, json=payload)
256
+ results = response.json()
257
+ print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})
258
+
259
+ output_state = init_input_state()
260
+ image_dir = get_conv_image_dir()
261
+ output_state['text'] = results['text']
262
+
263
+ for image_base64 in results['images']:
264
+ if image_base64 == '':
265
+ image_path = ''
266
+ else:
267
+ image = decode_image(image_base64)
268
+ image = image.convert('RGB')
269
+ image_path = get_image_name(image=image, image_dir=image_dir)
270
+ if not os.path.exists(image_path):
271
+ image.save(image_path)
272
+ output_state['images'].append(image_path)
273
+
274
+ dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
275
+
276
+ vote_last_response(dialog_state, 'common', request)
277
+ input_state = init_input_state()
278
+ chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg'])
279
+ return (dialog_state, input_state, chatbot) + (enable_btn,) * 4
280
+
281
+
282
+ def update_error_msg(chatbot, error_msg):
283
+ if len(error_msg) > 0:
284
+ info = '\n-------------\nSome errors occurred during response, please clear history and restart.\n' + '\n'.join(
285
+ error_msg)
286
+ chatbot[-1][-1] = chatbot[-1][-1] + info
287
+
288
+ return chatbot
289
+
290
+
291
+ def load_demo(request: gr.Request):
292
+ logger.info(f"load_demo. ip: {request.client.host}")
293
+ dialog_state = conv_seed_llama.copy()
294
+ input_state = init_input_state()
295
+ return dialog_state, input_state
296
+
297
+
298
+ title = ("""
299
+ # SEED-X-I
300
+ [[Paper]](https://arxiv.org/abs/2404.14396) [[Code]](https://github.com/AILab-CVC/SEED-X)
301
+
302
+ Demo of a general instruction-tuned model SEED-X-I (17B) from the foundation model SEED-X.
303
+
304
+ SEED-X-I can follow multimodal instructions (including images with **dynamic resolutions**) and respond with **images, text and bounding boxes** in multi-turn conversations.
305
+
306
+ SEED-X-I **does not support image manipulation**. If you want to experience **SEED-X-Edit** for high-precision image editing, please refer to [[Inference Code]](https://github.com/AILab-CVC/SEED-X).
307
+
308
+ Due to insufficient GPU memory, when generating images, we need to offload the LLM to the CPU and move the de-tokenizer to the GPU, which will **result in a long processing time**. If you want to experience the normal model inference speed, you can run [[Inference Code]](https://github.com/AILab-CVC/SEED-X) locally.
309
+
310
+
311
+ ## Tips:
312
+ * Check out the conversation examples (at the bottom) for inspiration.
313
+
314
+ * You can adjust "Max History Rounds" to try a conversation with up to five rounds. For more turns, you can download our checkpoints from GitHub and deploy them locally for inference.
315
+
316
+ * Our demo supports a mix of images and text as input. You can freely upload an image or enter text, then click "Add Image" or "Add Text". Repeat this step as many times as needed, and finally click "Submit" to run model inference.
317
+
318
+ * You can click "Force Image Generation" to compel the model to produce images when necessary. For example, our model might struggle to generate images when there is an excessive amount of text-only context.
319
+
320
+ * You can click "Force Bounding Box" to compel the model to produce bounding box for object detection.
321
+
322
+ * SEED-X was trained on English-only data. It may handle other languages thanks to the inherent capabilities of LLaMA, but the results might not be stable.
323
+
324
+ """)
325
+
326
+ css = """
327
+ img {
328
+ font-family: 'Helvetica';
329
+ font-weight: 300;
330
+ line-height: 2;
331
+ text-align: center;
332
+
333
+ width: auto;
334
+ height: auto;
335
+ display: block;
336
+ position: relative;
337
+ }
338
+
339
+ img:before {
340
+ content: " ";
341
+ display: block;
342
+
343
+ position: absolute;
344
+ top: -10px;
345
+ left: 0;
346
+ height: calc(100% + 10px);
347
+ width: 100%;
348
+ background-color: rgb(230, 230, 230);
349
+ border: 2px dotted rgb(200, 200, 200);
350
+ border-radius: 5px;
351
+ }
352
+
353
+ img:after {
354
+ content: " ";
355
+ display: block;
356
+ font-size: 16px;
357
+ font-style: normal;
358
+ font-family: FontAwesome;
359
+ color: rgb(100, 100, 100);
360
+
361
+ position: absolute;
362
+ top: 5px;
363
+ left: 0;
364
+ width: 100%;
365
+ text-align: center;
366
+ }
367
+
368
+ """
369
+
370
+ if __name__ == '__main__':
371
+
372
+ examples_mix = [
373
+ ['seed_x/bank.png', 'Can I connect with an advisor on Sunday?'],
374
+ ['seed_x/ground.png',
375
+ 'Is there anything in the image that can protect me from catching the flu virus when I go out? Show me the location.'],
376
+ ['seed_x/arrow.jpg', 'What is the object pointed by the red arrow?'],
377
+ ['seed_x/shanghai.png', 'Where was this image taken? Explain your answer.'],
378
+ ['seed_x/GPT4.png', 'How long does it take to make GPT-4 safer?'],
379
+ ['seed_x/twitter.png',
380
+ 'Please provide a comprehensive description of this image.'],
381
+ ]
382
+
383
+ examples_text = [
384
+ ['I want to build a two story cabin in the woods, with many commanding windows. Can you show me a picture?'],
385
+ ['Use your imagination to design a concept image for Artificial General Intelligence (AGI). Show me an image.'],
386
+ [
387
+ 'Can you design an illustration for “The Three-Body Problem” to depict a scene from the novel? Show me a picture.'],
388
+ [
389
+ 'My four year old son loves toy trains. Can you design a fancy birthday cake for him? Please generate a picture.'],
390
+ [
391
+ 'Generate an image of a portrait of a young nordic girl, age 25, freckled skin, neck tattoo, blue eyes, 35mm lens, photography, ultra details.'],
392
+ ['Generate an impressionist painting of an astronaut in a jungle.']
393
+ ]
394
+ with gr.Blocks(css=css) as demo:
395
+ gr.Markdown(title)
396
+ dialog_state = gr.State()
397
+ input_state = gr.State()
398
+ with gr.Row():
399
+ with gr.Column(scale=3):
400
+ with gr.Row():
401
+ image = gr.Image(type='pil', label='input_image')
402
+ with gr.Row():
403
+ text = gr.Textbox(lines=5,
404
+ show_label=False,
405
+ label='input_text',
406
+ elem_id='textbox',
407
+ placeholder="Enter text or add image, and press submit,").style(container=False)
408
+ with gr.Row():
409
+ add_image_btn = gr.Button("Add Image")
410
+ add_text_btn = gr.Button("Add Text")
411
+
412
+ submit_btn = gr.Button("Submit")
413
+
414
+ with gr.Row():
415
+ max_new_tokens = gr.Slider(minimum=64,
416
+ maximum=1024,
417
+ value=768,
418
+ step=64,
419
+ interactive=True,
420
+ label="Max Output Tokens")
421
+ max_turns = gr.Slider(minimum=1, maximum=9, value=3, step=1, interactive=True,
422
+ label="Max History Rounds")
423
+ force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
424
+ force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box')
425
+
426
+ with gr.Column(scale=7):
427
+ chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-X-I").style(height=700)
428
+ with gr.Row():
429
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
430
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
431
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
432
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
433
+
434
+ with gr.Row():
435
+ with gr.Column(scale=0.7):
436
+ gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text])
437
+ with gr.Column(scale=0.3):
438
+ gr.Examples(examples=examples_text, label='Input examples', inputs=[text])
439
+
440
+ # Register listeners
441
+ btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn]
442
+ upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
443
+ downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
444
+
445
+ regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
446
+ http_bot, [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox],
447
+ [dialog_state, input_state, chatbot] + btn_list)
448
+ add_image_btn.click(add_image, [dialog_state, input_state, image],
449
+ [dialog_state, input_state, image, chatbot] + btn_list)
450
+
451
+ add_text_btn.click(add_text, [dialog_state, input_state, text],
452
+ [dialog_state, input_state, text, chatbot] + btn_list)
453
+
454
+ submit_btn.click(
455
+ add_image, [dialog_state, input_state, image], [dialog_state, input_state, image, chatbot] + btn_list).then(
456
+ add_text, [dialog_state, input_state, text],
457
+ [dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
458
+ http_bot,
459
+ [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox],
460
+ [dialog_state, input_state, chatbot] + btn_list)
461
+ clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
462
+
463
+ demo.load(load_demo, None, [dialog_state, input_state])
464
+
465
+ demo.launch(server_name=args.server_name, server_port=args.server_port, enable_queue=True)
src/demo/utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ handler = None
8
+
9
+
10
+ def build_logger(logger_name, logger_dir):
11
+ global handler
12
+
13
+ formatter = logging.Formatter(
14
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
15
+ datefmt="%Y-%m-%d %H:%M:%S",
16
+ )
17
+
18
+ # Set the format of root handlers
19
+ if not logging.getLogger().handlers:
20
+ logging.basicConfig(level=logging.INFO)
21
+ logging.getLogger().handlers[0].setFormatter(formatter)
22
+
23
+ # Redirect stdout and stderr to loggers
24
+ stdout_logger = logging.getLogger("stdout")
25
+ stdout_logger.setLevel(logging.INFO)
26
+ sl = StreamToLogger(stdout_logger, logging.INFO)
27
+ sys.stdout = sl
28
+
29
+ stderr_logger = logging.getLogger("stderr")
30
+ stderr_logger.setLevel(logging.ERROR)
31
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
32
+ sys.stderr = sl
33
+
34
+ # Get logger
35
+ logger = logging.getLogger(logger_name)
36
+ logger.setLevel(logging.INFO)
37
+
38
+ # Add a file handler for all loggers
39
+ if handler is None:
40
+ os.makedirs(logger_dir, exist_ok=True)
41
+ filename = os.path.join(logger_dir, logger_name + '.log')
42
+ handler = logging.handlers.TimedRotatingFileHandler(filename, when='D', utc=True)
43
+ handler.setFormatter(formatter)
44
+
45
+ for name, item in logging.root.manager.loggerDict.items():
46
+ if isinstance(item, logging.Logger):
47
+ item.addHandler(handler)
48
+
49
+ return logger
50
+
51
+
52
+ class StreamToLogger(object):
53
+ """
54
+ Fake file-like stream object that redirects writes to a logger instance.
55
+ """
56
+
57
+ def __init__(self, logger, log_level=logging.INFO):
58
+ self.terminal = sys.stdout
59
+ self.logger = logger
60
+ self.log_level = log_level
61
+ self.linebuf = ''
62
+
63
+ def __getattr__(self, attr):
64
+ return getattr(self.terminal, attr)
65
+
66
+ def write(self, buf):
67
+ temp_linebuf = self.linebuf + buf
68
+ self.linebuf = ''
69
+ for line in temp_linebuf.splitlines(True):
70
+ # From the io.TextIOWrapper docs:
71
+ # On output, if newline is None, any '\n' characters written
72
+ # are translated to the system default line separator.
73
+ # By default sys.stdout.write() expects '\n' newlines and then
74
+ # translates them so this is still cross platform.
75
+ if line[-1] == '\n':
76
+ self.logger.log(self.log_level, line.rstrip())
77
+ else:
78
+ self.linebuf += line
79
+
80
+ def flush(self):
81
+ if self.linebuf != '':
82
+ self.logger.log(self.log_level, self.linebuf.rstrip())
83
+ self.linebuf = ''
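A minimal usage sketch for the helpers above: once `build_logger` has been called, ordinary `print` output is also captured in the daily-rotated log file, because `sys.stdout` and `sys.stderr` are replaced with `StreamToLogger` instances (this is how the Gradio demo writes its request logs).

```python
# Minimal sketch of how build_logger is used by the demo scripts.
from utils import build_logger

logger = build_logger("gradio_seed_x", "log")   # creates log/gradio_seed_x.log
logger.info("service started")                  # goes to the console and the file
print("plain prints are captured as well")      # redirected through StreamToLogger
```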
src/inference/.DS_Store ADDED
Binary file (6.15 kB)
src/inference/__pycache__/any_res.cpython-311.pyc ADDED
Binary file (12.2 kB)
src/inference/__pycache__/any_res.cpython-38.pyc ADDED
Binary file (7.47 kB)
src/inference/any_res.py ADDED
@@ -0,0 +1,257 @@
1
+ import base64
2
+ import torch
3
+ import math
4
+ import ast
5
+ from PIL import Image
6
+ from io import BytesIO
7
+
8
+
9
+ def select_best_resolution(original_size, possible_resolutions):
10
+ """
11
+ Selects the best resolution from a list of possible resolutions based on the original size.
12
+
13
+ Args:
14
+ original_size (tuple): The original size of the image in the format (width, height).
15
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
16
+
17
+ Returns:
18
+ tuple: The best fit resolution in the format (width, height).
19
+ """
20
+ original_width, original_height = original_size
21
+ best_fit = None
22
+ max_effective_resolution = 0
23
+ min_wasted_resolution = float('inf')
24
+
25
+ for width, height in possible_resolutions:
26
+ scale = min(width / original_width, height / original_height)
27
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
28
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
29
+ wasted_resolution = (width * height) - effective_resolution
30
+
31
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
32
+ max_effective_resolution = effective_resolution
33
+ min_wasted_resolution = wasted_resolution
34
+ best_fit = (width, height)
35
+
36
+ return best_fit
37
+
38
+
39
+ def select_best_resolution_v2(original_size, possible_resolutions):
40
+ """
41
+ Selects the best resolution from a list of possible resolutions based on the original size and aspect ratio.
42
+
43
+ Args:
44
+ original_size (tuple): The original size of the image in the format (width, height).
45
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
46
+
47
+ Returns:
48
+ tuple: The best fit resolution in the format (width, height).
49
+ """
50
+ original_width, original_height = original_size
51
+ original_aspect_ratio = original_height / original_width
52
+ original_area = original_width * original_height
53
+ best_fit = None
54
+ min_aspect_ratio_diff = float('inf')
55
+ min_area_ratio = float('inf')
56
+
57
+ for width, height in possible_resolutions:
58
+ aspect_ratio = height / width
59
+ area = width * height
60
+ aspect_ratio_diff = max(aspect_ratio, original_aspect_ratio) / min(aspect_ratio, original_aspect_ratio)
61
+ area_ratio = max(area, original_area) / min(area, original_area)
62
+
63
+ if aspect_ratio_diff < min_aspect_ratio_diff or (aspect_ratio_diff == min_aspect_ratio_diff and area_ratio < min_area_ratio):
64
+ min_aspect_ratio_diff = aspect_ratio_diff
65
+ min_area_ratio = area_ratio
66
+ best_fit = (width, height)
67
+
68
+ return best_fit
69
+
70
+
71
+ def resize_and_pad_image(image, target_resolution, keep_ratio=False):
72
+ """
73
+ Resize and pad an image to a target resolution
74
+
75
+ Args:
76
+ image (PIL.Image.Image): The input image.
77
+ target_resolution (tuple): The target resolution (width, height) of the image.
78
+
79
+ Returns:
80
+ PIL.Image.Image: The resized and padded image.
81
+ """
82
+ original_width, original_height = image.size
83
+ target_width, target_height = target_resolution
84
+
85
+ if keep_ratio:
86
+ # maintaining aspect ratio
87
+ scale_w = target_width / original_width
88
+ scale_h = target_height / original_height
89
+
90
+ if scale_w < scale_h:
91
+ new_width = target_width
92
+ new_height = min(math.ceil(original_height * scale_w), target_height)
93
+ else:
94
+ new_height = target_height
95
+ new_width = min(math.ceil(original_width * scale_h), target_width)
96
+
97
+ # Resize the image
98
+ resized_image = image.resize((new_width, new_height))
99
+
100
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
101
+ paste_x = (target_width - new_width) // 2
102
+ paste_y = (target_height - new_height) // 2
103
+ new_image.paste(resized_image, (paste_x, paste_y))
104
+ else:
105
+ # not maintaining aspect ratio
106
+ new_image = image.resize((target_width, target_height))
107
+
108
+ return new_image
109
+
110
+
111
+ def divide_to_patches(image, patch_size):
112
+ """
113
+ Divides an image into patches of a specified size.
114
+
115
+ Args:
116
+ image (PIL.Image.Image): The input image.
117
+ patch_size (int): The size of each patch.
118
+
119
+ Returns:
120
+ list: A list of PIL.Image.Image objects representing the patches.
121
+ """
122
+ patches = []
123
+ width, height = image.size
124
+ for i in range(0, height, patch_size):
125
+ for j in range(0, width, patch_size):
126
+ box = (j, i, j + patch_size, i + patch_size)
127
+ patch = image.crop(box)
128
+ patches.append(patch)
129
+
130
+ return patches
131
+
132
+
133
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
134
+ """
135
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
136
+
137
+ Args:
138
+ image_size (tuple): The size of the input image in the format (width, height).
139
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
140
+ patch_size (int): The size of each image patch.
141
+
142
+ Returns:
143
+ tuple: The shape of the image patch grid in the format (width, height).
144
+ """
145
+ if type(grid_pinpoints) is list:
146
+ possible_resolutions = grid_pinpoints
147
+ else:
148
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
149
+ width1, height1 = select_best_resolution(image_size, possible_resolutions)
150
+ width2, height2 = select_best_resolution_v2(image_size, possible_resolutions)
151
+ if width1*height1 > width2*height2:
152
+ width, height = width2, height2
153
+ else:
154
+ width, height = width1, height1
155
+ return width // patch_size, height // patch_size
156
+
157
+
158
+ def process_anyres_image(image, image_transform, grid_pinpoints, base_image_size):
159
+ """
160
+ Process an image with variable resolutions.
161
+
162
+ Args:
163
+ image (PIL.Image.Image): The input image to be processed.
164
+ image_transform: The image processor object.
165
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
166
+
167
+ Returns:
168
+ torch.Tensor: A tensor containing the processed image patches.
169
+ """
170
+ if type(grid_pinpoints) is list:
171
+ possible_resolutions = grid_pinpoints
172
+ else:
173
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
174
+ # best_resolution = select_best_resolution(image.size, possible_resolutions)
175
+ width1, height1 = select_best_resolution(image.size, possible_resolutions)
176
+ width2, height2 = select_best_resolution_v2(image.size, possible_resolutions)
177
+ if width1*height1 > width2*height2:
178
+ width, height = width2, height2
179
+ else:
180
+ width, height = width1, height1
181
+ best_resolution = [width, height]
182
+
183
+ image_padded = resize_and_pad_image(image, best_resolution)
184
+
185
+ patches = divide_to_patches(image_padded, base_image_size)
186
+
187
+ image_original_resize = image.resize((base_image_size, base_image_size))
188
+
189
+ image_patches = patches + [image_original_resize] # add the original image as the last patch
190
+ image_patches = [image_transform(image_patch)
191
+ for image_patch in image_patches]
192
+
193
+ patch_grid = (best_resolution[0]//base_image_size, best_resolution[1]//base_image_size)
194
+ x_index = (torch.arange(patch_grid[0]).repeat(patch_grid[1], 1) + 0.5)/patch_grid[0]
195
+ y_index = (torch.arange(patch_grid[1]).unsqueeze(1).repeat(1, patch_grid[0]) + 0.5)/patch_grid[1]
196
+ patch_pos = torch.stack([x_index, y_index], dim=-1).flatten(0, 1) # h*w, 2
197
+
198
+ origin_pos = torch.tensor([[0.5, 0.5]])
199
+ patch_pos = torch.cat([patch_pos, origin_pos], dim=0) # h*w+1, 2
200
+
201
+ return torch.stack(image_patches, dim=0), patch_pos
202
+
203
+
204
+ def load_image_from_base64(image):
205
+ return Image.open(BytesIO(base64.b64decode(image)))
206
+
207
+
208
+ def anyres_data_collate(batch, tokenizer, dataset_name=None):
209
+ results = {}
210
+ keys = batch[0].keys()
211
+
212
+ for key in keys:
213
+ cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None]
214
+ if len(cur) == 0:
215
+ results[key] = None
216
+ elif isinstance(cur[0], torch.Tensor):
217
+ if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images', 'images_patch_length', 'patch_position', 'image_size']:
218
+ results[key] = torch.cat(cur, dim=0)
219
+ else:
220
+ if key in ['input_ids']:
221
+ results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=tokenizer.pad_token_id)
222
+ elif key in ['attention_mask']:
223
+ results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=0)
224
+ elif key in ['labels']:
225
+ results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=-100)
226
+ elif key in ['ids_gen_mask', 'ids_cmp_mask']:
227
+ results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=False)
228
+
229
+ else:
230
+ results[key] = torch.stack(cur, dim=0)
231
+ else:
232
+ results[key] = cur
233
+
234
+ results['dataset_name'] = dataset_name
235
+
236
+ return results
237
+
238
+
239
+ def anyres_data_collate_old(batch, dataset_name=None):
240
+ results = {}
241
+ keys = batch[0].keys()
242
+
243
+ for key in keys:
244
+ cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None]
245
+ if len(cur) == 0:
246
+ results[key] = None
247
+ elif isinstance(cur[0], torch.Tensor):
248
+ if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images', 'images_patch_length', 'patch_position', 'image_size']:
249
+ results[key] = torch.cat(cur, dim=0)
250
+ else:
251
+ results[key] = torch.stack(cur, dim=0)
252
+ else:
253
+ results[key] = cur
254
+
255
+ results['dataset_name'] = dataset_name
256
+
257
+ return results
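A short sketch of how `process_anyres_image` is driven by the inference scripts that follow: each resolution grid such as '1x2' is expanded into pixel pinpoints (448x896, etc.), and the function returns one transformed tensor per 448x448 patch plus a downscaled copy of the full image, together with the normalized (x, y) center of each patch. The torchvision transform below is only a stand-in for the Qwen-ViT transform loaded from `configs/processer/qwen_448_transform.yaml`.

```python
# Sketch of process_anyres_image usage, with a plain torchvision transform
# standing in for the Qwen-ViT transform from the yaml config.
from PIL import Image
from torchvision import transforms

from any_res import process_anyres_image

base_resolution = 448
resolution_grids = ['1x1', '1x2', '2x1', '2x2']
grid_pinpoints = [[int(s1) * base_resolution, int(s2) * base_resolution]
                  for s1, s2 in (scale.split('x') for scale in resolution_grids)]

image_transform = transforms.Compose([
    transforms.Resize((base_resolution, base_resolution)),
    transforms.ToTensor(),
])

image = Image.open('demo_images/car.jpg').convert('RGB')
image_tensor, patch_pos = process_anyres_image(
    image, image_transform, grid_pinpoints, base_resolution)

# (num_patches + 1, 3, 448, 448): local patches plus the resized full image
print(image_tensor.shape)
# (num_patches + 1, 2): normalized patch centers; the last row is (0.5, 0.5)
print(patch_pos.shape)
```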
src/inference/eval_img2edit_seed_x.py ADDED
@@ -0,0 +1,155 @@
1
+ import hydra
2
+ import torch
3
+ import os
4
+ import re
5
+ import pyrootutils
6
+ from PIL import Image
7
+ from omegaconf import OmegaConf
8
+ from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler, Transformer2DModel
9
+ from any_res import process_anyres_image
10
+
11
+ pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
12
+
13
+ BOI_TOKEN = '<img>'
14
+ BOP_TOKEN = '<patch>'
15
+ EOI_TOKEN = '</img>'
16
+ EOP_TOKEN = '</patch>'
17
+ IMG_TOKEN = '<img_{:05d}>'
18
+
19
+ resolution_grids = ['1x1']
20
+ base_resolution = 448
21
+
22
+ device = 'cuda:0'
23
+ device1 = 'cuda:1'
24
+ dtype = torch.float16
25
+ dtype_str = 'fp16'
26
+ num_img_in_tokens = 64
27
+ num_img_out_tokens = 64
28
+ instruction_prompt = '[INST] {instruction} [/INST]\n'
29
+
30
+ save_dir = 'vis'
31
+ os.makedirs(save_dir, exist_ok=True)
32
+
33
+ tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml'
34
+ image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml'
35
+ visual_encoder_cfg_path = 'configs/visual_encoder/qwen_vitg_448.yaml'
36
+ llm_cfg_path = 'configs/clm_models/llm_seed_x_edit.yaml'
37
+ agent_cfg_path = 'configs/clm_models/agent_seed_x_edit.yaml'
38
+ adapter_cfg_path = 'configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_full_with_latent_image_pretrain_no_normalize.yaml'
39
+ discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml'
40
+
41
+ diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0'
42
+
43
+ tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
44
+ tokenizer = hydra.utils.instantiate(tokenizer_cfg)
45
+
46
+ image_transform_cfg = OmegaConf.load(image_transform_cfg_path)
47
+ image_transform = hydra.utils.instantiate(image_transform_cfg)
48
+
49
+ visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path)
50
+ visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
51
+ visual_encoder.eval().to(device1, dtype=dtype)
52
+ print('Init visual encoder done')
53
+
54
+ llm_cfg = OmegaConf.load(llm_cfg_path)
55
+ llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype)
56
+ print('Init llm done.')
57
+
58
+ agent_model_cfg = OmegaConf.load(agent_cfg_path)
59
+ agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm)
60
+
61
+ agent_model.eval().to(device, dtype=dtype)
62
+ print('Init agent model done.')
63
+
64
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler")
65
+ print('init vae')
66
+ vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device1, dtype=dtype)
67
+ print('init unet')
68
+ unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device1, dtype=dtype)
69
+
70
+ adapter_cfg = OmegaConf.load(adapter_cfg_path)
71
+ adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device1, dtype=dtype).eval()
72
+
73
+ discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path)
74
+ discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device1).eval()
75
+ print('Init adapter done')
76
+
77
+ adapter.init_pipe(vae=vae,
78
+ scheduler=noise_scheduler,
79
+ visual_encoder=visual_encoder,
80
+ image_transform=image_transform,
81
+ dtype=dtype,
82
+ device=device1)
83
+
84
+ print('Init adapter pipe done')
85
+ boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
86
+ eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
87
+
88
+ bop_token_id = tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0]
89
+ eop_token_id = tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0]
90
+
91
+ grid_pinpoints = []
92
+ for scale in resolution_grids:
93
+ s1, s2 = scale.split('x')
94
+ grid_pinpoints.append([int(s1)*base_resolution, int(s2)*base_resolution])
95
+ grid_pinpoints = grid_pinpoints
96
+
97
+
98
+ image_path = 'demo_images/car.jpg'
99
+ instruction = 'Make it under the sunset'
100
+
101
+ image = Image.open(image_path).convert('RGB')
102
+ source_image = image.resize((1024, 1024))
103
+
104
+ image_tensor, patch_pos_tensor = process_anyres_image(image, image_transform, grid_pinpoints, base_resolution)
105
+ embeds_cmp_mask = torch.tensor([True]*image_tensor.shape[0]).to(device, dtype=torch.bool)
106
+
107
+ patch_pos = [patch_pos_tensor]
108
+ patch_position = torch.cat(patch_pos, dim=0)
109
+
110
+ image_tensor = image_tensor.to(device1, dtype=dtype)
111
+
112
+ patch_length = image_tensor.shape[0]
113
+ image_tokens = ''
114
+ for _ in range(patch_length-1):
115
+ image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
116
+ image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN
117
+
118
+ prompt = instruction_prompt.format_map({'instruction': image_tokens + instruction})
119
+
120
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False)
121
+ input_ids = [tokenizer.bos_token_id] + input_ids
122
+
123
+ input_ids = torch.tensor(input_ids).to(device, dtype=torch.long)
124
+
125
+ ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
126
+
127
+ boi_indices = torch.where(torch.logical_or(input_ids == boi_token_id, input_ids == bop_token_id))[0].tolist()
128
+ eoi_indices = torch.where(torch.logical_or(input_ids == eoi_token_id, input_ids == eop_token_id))[0].tolist()
129
+
130
+ for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
131
+ ids_cmp_mask[boi_idx + 1:eoi_idx] = True
132
+
133
+ input_ids = input_ids.unsqueeze(0)
134
+ ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
135
+
136
+ with torch.no_grad():
137
+ image_embeds = visual_encoder(image_tensor)
138
+ image_embeds = image_embeds.to(device)
139
+ output = agent_model.generate(tokenizer=tokenizer,
140
+ input_ids=input_ids,
141
+ image_embeds=image_embeds,
142
+ embeds_cmp_mask=embeds_cmp_mask,
143
+ patch_positions=patch_position,
144
+ ids_cmp_mask=ids_cmp_mask,
145
+ max_new_tokens=512,
146
+ num_img_gen_tokens=num_img_out_tokens)
147
+ text = re.sub('<[^>]*>', '', output['text'])
148
+ print(text)
149
+
150
+ if output['has_img_output']:
151
+ images = adapter.generate(image_embeds=output['img_gen_feat'].to(device1), latent_image=source_image, num_inference_steps=50)
152
+
153
+ save_path = os.path.join(save_dir, str(len(os.listdir(save_dir))) + '_' + instruction + '.jpg')
154
+ images[0].save(save_path)
155
+ torch.cuda.empty_cache()
src/inference/eval_img2text_seed_x.py ADDED
@@ -0,0 +1,235 @@
1
+ import hydra
2
+ import torch
3
+ import os
4
+ import pyrootutils
5
+ from PIL import Image
6
+ import re
7
+ import cv2
8
+ import numpy as np
9
+ from omegaconf import OmegaConf
10
+ from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler
11
+ from any_res import process_anyres_image
12
+
13
+
14
+ pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
15
+
16
+ def visualize_bbox(image, bboxes, save_path):
17
+ img_width, img_height = image.size
18
+ image = np.array(image)
19
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
20
+ for bbox in bboxes:
21
+ x_center, y_center, box_width, box_height = bbox
22
+
23
+ x_center = x_center / 224 * img_width
24
+ y_center = y_center / 224 * img_height
25
+
26
+ box_width = box_width /224 * img_width
27
+ box_height = box_height / 224 * img_height
28
+
29
+ x1 = int(x_center - box_width / 2)
30
+ y1 = int(y_center - box_height / 2)
31
+ x2 = int(x_center + box_width / 2)
32
+ y2 = int(y_center + box_height / 2)
33
+
34
+ cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
35
+
36
+ cv2.imwrite(save_path, image)
37
+
38
+
39
+ def extract_box(output_str):
40
+ boxes = re.findall('<box_start>(.*?)<box_end>', output_str)
41
+ if len(boxes) >0:
42
+ bboxes = [[int(num) for num in re.findall('<loc-(\d+)>', box)] for box in boxes]
43
+ else:
44
+ bboxes = None
45
+
46
+ return bboxes
47
+
48
+
49
+ BOI_TOKEN = '<img>'
50
+ BOP_TOKEN = '<patch>'
51
+ EOI_TOKEN = '</img>'
52
+ EOP_TOKEN = '</patch>'
53
+ IMG_TOKEN = '<img_{:05d}>'
54
+
55
+ instruction_prompt = '[INST] {instruction} [/INST]\n'
56
+
57
+ resolution_grids = ['1x1', '1x2', '1x3', '2x1', '3x1', '1x4', '4x1', '2x2']
58
+ base_resolution = 448
59
+
60
+ device = 'cuda:0'
61
+ device1 = 'cuda:1'
62
+ dtype = torch.float16
63
+ dtype_str = 'fp16'
64
+ num_img_in_tokens = 64
65
+ num_img_out_tokens = 64
66
+
67
+ tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml'
68
+ image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml'
69
+ visual_encoder_cfg_path = 'configs/visual_encoder/qwen_vitg_448.yaml'
70
+ llm_cfg_path = 'configs/clm_models/llm_seed_x_i.yaml'
71
+ agent_cfg_path = 'configs/clm_models/agent_seed_x_i.yaml'
72
+ adapter_cfg_path = 'configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml'
73
+ discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml'
74
+
75
+ diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0'
76
+
77
+ tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
78
+ tokenizer = hydra.utils.instantiate(tokenizer_cfg)
79
+
80
+ image_transform_cfg = OmegaConf.load(image_transform_cfg_path)
81
+ image_transform = hydra.utils.instantiate(image_transform_cfg)
82
+
83
+ visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path)
84
+ visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
85
+ visual_encoder.eval().to(device1, dtype=dtype)
86
+ print('Init visual encoder done')
87
+
88
+ llm_cfg = OmegaConf.load(llm_cfg_path)
89
+ llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype)
90
+ print('Init llm done.')
91
+
92
+ agent_model_cfg = OmegaConf.load(agent_cfg_path)
93
+ agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm)
94
+
95
+ agent_model.eval().to(device, dtype=dtype)
96
+ print('Init agent model done.')
97
+
98
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler")
99
+ print('init vae')
100
+ vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device1, dtype=dtype)
101
+ print('init unet')
102
+ unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device1, dtype=dtype)
103
+
104
+ adapter_cfg = OmegaConf.load(adapter_cfg_path)
105
+ adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device1, dtype=dtype).eval()
106
+
107
+ discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path)
108
+ discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device1).eval()
109
+ print('Init adapter done')
110
+
111
+ adapter.init_pipe(vae=vae,
112
+ scheduler=noise_scheduler,
113
+ visual_encoder=visual_encoder,
114
+ image_transform=image_transform,
115
+ discrete_model=discrete_model,
116
+ dtype=dtype,
117
+ device=device1)
118
+
119
+ print('Init adapter pipe done')
120
+ boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
121
+ eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
122
+
123
+ bop_token_id = tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0]
124
+ eop_token_id = tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0]
125
+
126
+ grid_pinpoints = []
127
+ for scale in resolution_grids:
128
+ s1, s2 = scale.split('x')
129
+ grid_pinpoints.append([int(s1)*base_resolution, int(s2)*base_resolution])
130
+ grid_pinpoints = grid_pinpoints
131
+
132
+ # image comprehension
133
+ image_path = 'demo_images/advisor.png'
134
+ image = Image.open(image_path).convert('RGB')
135
+ image_tensor, patch_pos_tensor = process_anyres_image(image, image_transform, grid_pinpoints, base_resolution)
136
+ embeds_cmp_mask = torch.tensor([True]*image_tensor.shape[0]).to(device, dtype=torch.bool)
137
+
138
+ patch_pos = [patch_pos_tensor]
139
+ patch_position = torch.cat(patch_pos, dim=0)
140
+
141
+ image_tensor = image_tensor.to(device1, dtype=dtype)
142
+
143
+ patch_length = image_tensor.shape[0]
144
+ image_tokens = ''
145
+ for _ in range(patch_length-1):
146
+ image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
147
+ image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN
148
+
149
+ question = 'Can I conntect with an advisor on Sunday?'
150
+ prompt = instruction_prompt.format_map({'instruction': image_tokens + question})
151
+
152
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False)
153
+ input_ids = [tokenizer.bos_token_id] + input_ids
154
+
155
+ input_ids = torch.tensor(input_ids).to(device, dtype=torch.long)
156
+
157
+ ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
158
+
159
+ boi_indices = torch.where(torch.logical_or(input_ids == boi_token_id, input_ids == bop_token_id))[0].tolist()
160
+ eoi_indices = torch.where(torch.logical_or(input_ids == eoi_token_id, input_ids == eop_token_id))[0].tolist()
161
+
162
+ for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
163
+ ids_cmp_mask[boi_idx + 1:eoi_idx] = True
164
+
165
+ input_ids = input_ids.unsqueeze(0)
166
+ ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
167
+
168
+ with torch.no_grad():
169
+ image_embeds = visual_encoder(image_tensor)
170
+ image_embeds = image_embeds.to(device)
171
+ output = agent_model.generate(tokenizer=tokenizer,
172
+ input_ids=input_ids,
173
+ image_embeds=image_embeds,
174
+ embeds_cmp_mask=embeds_cmp_mask,
175
+ patch_positions=patch_position,
176
+ ids_cmp_mask=ids_cmp_mask,
177
+ max_new_tokens=512,
178
+ num_img_gen_tokens=num_img_out_tokens)
179
+
180
+ text = re.sub('<[^>]*>', '', output['text'])
181
+ print(text)
182
+
183
+ # detection
184
+ image_path = 'demo_images/ground.png'
185
+ image = Image.open(image_path).convert('RGB')
186
+ image_tensor, patch_pos_tensor = process_anyres_image(image, image_transform, grid_pinpoints, base_resolution)
187
+ embeds_cmp_mask = torch.tensor([True]*image_tensor.shape[0]).to(device, dtype=torch.bool)
188
+
189
+ patch_pos = [patch_pos_tensor]
190
+ patch_position = torch.cat(patch_pos, dim=0)
191
+
192
+ image_tensor = image_tensor.to(device1, dtype=dtype)
193
+
194
+ patch_length = image_tensor.shape[0]
195
+ image_tokens = ''
196
+ for _ in range(patch_length-1):
197
+ image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
198
+ image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN
199
+
200
+ question = 'Is there anything in the image that can protect me from catching the flu virus when I go out? Show me the location.'
201
+ prompt = instruction_prompt.format_map({'instruction': image_tokens + question})
202
+
203
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False)
204
+ input_ids = [tokenizer.bos_token_id] + input_ids
205
+
206
+ input_ids = torch.tensor(input_ids).to(device, dtype=torch.long)
207
+
208
+ ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
209
+
210
+ boi_indices = torch.where(torch.logical_or(input_ids == boi_token_id, input_ids == bop_token_id))[0].tolist()
211
+ eoi_indices = torch.where(torch.logical_or(input_ids == eoi_token_id, input_ids == eop_token_id))[0].tolist()
212
+
213
+ for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
214
+ ids_cmp_mask[boi_idx + 1:eoi_idx] = True
215
+
216
+ input_ids = input_ids.unsqueeze(0)
217
+ ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
218
+
219
+ with torch.no_grad():
220
+ image_embeds = visual_encoder(image_tensor)
221
+ image_embeds = image_embeds.to(device)
222
+ output = agent_model.generate(tokenizer=tokenizer,
223
+ input_ids=input_ids,
224
+ image_embeds=image_embeds,
225
+ embeds_cmp_mask=embeds_cmp_mask,
226
+ patch_positions=patch_position,
227
+ ids_cmp_mask=ids_cmp_mask,
228
+ max_new_tokens=512,
229
+ num_img_gen_tokens=num_img_out_tokens)
230
+ print(output['text'])
231
+ bbox = extract_box(output['text'])
232
+ if bbox is not None:
233
+ save_path = 'vis/ground.png'
234
+ visualize_bbox(image, bbox, save_path)
235
+
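For reference, the location tokens parsed by `extract_box` above encode (x_center, y_center, width, height) on a 224x224 grid, and `visualize_bbox` rescales them back to the original image size. A small worked sketch with an illustrative, made-up model output (the parsing below mirrors `extract_box` from the script above):

```python
# Illustrative example of the bounding-box token format handled by
# extract_box / visualize_bbox above. The <loc-*> values are made up.
import re


def extract_box(output_str):
    # same parsing as in eval_img2text_seed_x.py above
    boxes = re.findall('<box_start>(.*?)<box_end>', output_str)
    return [[int(num) for num in re.findall(r'<loc-(\d+)>', box)] for box in boxes] or None


output_str = 'Yes, the mask. [[ <box_start><loc-090><loc-120><loc-040><loc-060><box_end> ]]'
print(extract_box(output_str))
# [[90, 120, 40, 60]] -> x_center, y_center, width, height on a 224 grid

# visualize_bbox rescales these to the original image size; for a 896x672 image:
#   x_center = 90 / 224 * 896 = 360,  y_center = 120 / 224 * 672 = 360
#   width    = 40 / 224 * 896 = 160,  height   = 60 / 224 * 672 = 180
```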
src/inference/eval_text2img_seed_x.py ADDED
@@ -0,0 +1,94 @@
+ import hydra
+ import torch
+ import os
+ import pyrootutils
+ from PIL import Image
+ from omegaconf import OmegaConf
+ from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler
+
+
+ pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
+
+ BOI_TOKEN = '<img>'
+ EOI_TOKEN = '</img>'
+ IMG_TOKEN = '<img_{:05d}>'
+
+ device = 'cuda:0'
+ device_2 = 'cuda:1'
+ dtype = torch.float16
+ dtype_str = 'fp16'
+ num_img_in_tokens = 64
+ num_img_out_tokens = 64
+
+ instruction_prompt = '[INST] Generate an image: {caption} [/INST]\n'
+
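+ # Config paths for the tokenizer, image transform, visual encoder, LLM, agent, SDXL adapter and discrete model.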
+ tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml'
+ image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml'
+ visual_encoder_cfg_path = 'configs/visual_encoder/qwen_vitg_448.yaml'
+ llm_cfg_path = 'configs/clm_models/llm_seed_x_i.yaml'
+ agent_cfg_path = 'configs/clm_models/agent_seed_x_i.yaml'
+ adapter_cfg_path = 'configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml'
+ discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml'
+
+ diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0'
+
+ save_dir = 'vis'
+ os.makedirs(save_dir, exist_ok=True)
+
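+ # Instantiate the tokenizer, image transform, visual encoder, LLM and SEED-X agent from their configs.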
+ tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
+ tokenizer = hydra.utils.instantiate(tokenizer_cfg)
+
+ image_transform_cfg = OmegaConf.load(image_transform_cfg_path)
+ image_transform = hydra.utils.instantiate(image_transform_cfg)
+
+ visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path)
+ visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
+ visual_encoder.eval().to(device_2, dtype=dtype)
+ print('Init visual encoder done')
+
+ llm_cfg = OmegaConf.load(llm_cfg_path)
+ llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype)
+ print('Init llm done')
+
+ agent_model_cfg = OmegaConf.load(agent_cfg_path)
+ agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm)
+
+ agent_model.eval().to(device, dtype=dtype)
+ print('Init agent model done')
+
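+ # Build the SDXL de-tokenizer: noise scheduler, VAE, UNet and the resampler adapter that conditions SDXL on the generated visual features.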
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler")
+ print('Init vae')
+ vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device_2, dtype=dtype)
+ print('Init unet')
+ unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device_2, dtype=dtype)
+
+ adapter_cfg = OmegaConf.load(adapter_cfg_path)
+ adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device_2, dtype=dtype).eval()
+
+ discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path)
+ discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device_2).eval()
+ print('Init adapter done')
+
+ adapter.init_pipe(vae=vae,
+                   scheduler=noise_scheduler,
+                   visual_encoder=visual_encoder,
+                   image_transform=image_transform,
+                   discrete_model=discrete_model,
+                   dtype=dtype,
+                   device=device_2)
+
+ print('Init adapter pipe done')
+
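+ # Prompt the agent with a caption; if it emits image-generation tokens, decode them into an image with the adapter.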
+ caption = 'A cybernetic soldier, enhanced with advanced weapons systems and tactical analysis software, on a mission behind enemy lines.'
+ prompt = instruction_prompt.format_map({'caption': caption})
+ prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
+ input_ids = torch.tensor([tokenizer.bos_token_id] + prompt_ids).to(device, dtype=torch.long).unsqueeze(0)
+ output = agent_model.generate(tokenizer=tokenizer, input_ids=input_ids, num_img_gen_tokens=num_img_out_tokens)
+ print(output['has_img_output'])
+ print(output['text'])
+
+ if output['has_img_output']:
+     images = adapter.generate(image_embeds=output['img_gen_feat'].to(device_2), num_inference_steps=50)
+     save_path = os.path.join(save_dir, caption.replace('.', '') + '.png')
+     images[0].save(save_path)
+ torch.cuda.empty_cache()
src/models/detokenizer/__init__.py ADDED
@@ -0,0 +1 @@
+
src/models/detokenizer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (182 Bytes). View file
 
src/models/detokenizer/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (175 Bytes). View file
 
src/models/detokenizer/__pycache__/adapter_modules.cpython-311.pyc ADDED
Binary file (14 kB). View file
 
src/models/detokenizer/__pycache__/adapter_modules.cpython-38.pyc ADDED
Binary file (7.31 kB). View file
 
src/models/detokenizer/__pycache__/attention_processor.cpython-38.pyc ADDED
Binary file (7.4 kB). View file
 
src/models/detokenizer/__pycache__/ipa_utils.cpython-38.pyc ADDED
Binary file (397 Bytes). View file
 
src/models/detokenizer/__pycache__/pipeline_stable_diffusion_t2i_edit.cpython-38.pyc ADDED
Binary file (28.3 kB). View file
 
src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-311.pyc ADDED
Binary file (53 kB). View file
 
src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-38.pyc ADDED
Binary file (36.8 kB). View file
 
src/models/detokenizer/__pycache__/resampler.cpython-311.pyc ADDED
Binary file (16.2 kB). View file